aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/testing
diff options
context:
space:
mode:
authorGravatar Alan Chiao <alanchiao@google.com>2018-08-23 13:42:01 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-23 13:45:36 -0700
commit9a774e4d2d31443ea694938bec41237b4d6bcf02 (patch)
treeeba93b9e5a8c96c5c40b2aff448e01690eb8a0db /tensorflow/contrib/lite/testing
parent6fe361f80d4277ea879b3182e1d7148a65a8ca21 (diff)
Remove 18-input/3-output LSTM in favor of 20-input/1-output LSTM that supports
state API. PiperOrigin-RevId: 209991722
Diffstat (limited to 'tensorflow/contrib/lite/testing')
-rw-r--r--tensorflow/contrib/lite/testing/generate_examples.py2
-rw-r--r--tensorflow/contrib/lite/testing/tflite_driver.cc22
2 files changed, 1 insertions, 23 deletions
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py
index 599c82940e..50586a8e5d 100644
--- a/tensorflow/contrib/lite/testing/generate_examples.py
+++ b/tensorflow/contrib/lite/testing/generate_examples.py
@@ -2378,7 +2378,7 @@ def make_lstm_tests(zip_path):
"time_step_size": [1],
"input_vec_size": [3],
"num_cells": [4],
- "split_tflite_lstm_inputs": [True, False],
+ "split_tflite_lstm_inputs": [False],
},
]
diff --git a/tensorflow/contrib/lite/testing/tflite_driver.cc b/tensorflow/contrib/lite/testing/tflite_driver.cc
index 4dacf9c84b..1836eb53b9 100644
--- a/tensorflow/contrib/lite/testing/tflite_driver.cc
+++ b/tensorflow/contrib/lite/testing/tflite_driver.cc
@@ -302,28 +302,6 @@ bool TfLiteDriver::CheckResults() {
void TfLiteDriver::ResetLSTMStateTensors() {
interpreter_->ResetVariableTensorsToZero();
-
- // Below is a workaround for initializing state tensors for LSTM.
- // TODO(ycling): Remove the code below after nobody is using the 18-inputs
- // definition.
- for (auto node_index : interpreter_->execution_plan()) {
- const auto& node_and_reg = interpreter_->node_and_registration(node_index);
- const auto& node = node_and_reg->first;
- const auto& registration = node_and_reg->second;
-
- if (registration.builtin_code == tflite::BuiltinOperator_LSTM) {
- const auto* params =
- reinterpret_cast<const TfLiteLSTMParams*>(node.builtin_data);
- if (params->kernel_type == kTfLiteLSTMFullKernel &&
- node.inputs->size == 18 && node.outputs->size >= 2) {
- // The first 2 outputs of LSTM are state tensors.
- for (int i = 0; i < 2; ++i) {
- int node_index = node.outputs->data[i];
- ResetTensor(node_index);
- }
- }
- }
- }
}
} // namespace testing