diff options
author | 2018-08-23 13:42:01 -0700 | |
---|---|---|
committer | 2018-08-23 13:45:36 -0700 | |
commit | 9a774e4d2d31443ea694938bec41237b4d6bcf02 (patch) | |
tree | eba93b9e5a8c96c5c40b2aff448e01690eb8a0db /tensorflow/contrib/lite/testing | |
parent | 6fe361f80d4277ea879b3182e1d7148a65a8ca21 (diff) |
Remove 18-input/3-output LSTM in favor of 20-input/1-output LSTM that supports
state API.
PiperOrigin-RevId: 209991722
Diffstat (limited to 'tensorflow/contrib/lite/testing')
-rw-r--r-- | tensorflow/contrib/lite/testing/generate_examples.py | 2 | ||||
-rw-r--r-- | tensorflow/contrib/lite/testing/tflite_driver.cc | 22 |
2 files changed, 1 insertions, 23 deletions
diff --git a/tensorflow/contrib/lite/testing/generate_examples.py b/tensorflow/contrib/lite/testing/generate_examples.py index 599c82940e..50586a8e5d 100644 --- a/tensorflow/contrib/lite/testing/generate_examples.py +++ b/tensorflow/contrib/lite/testing/generate_examples.py @@ -2378,7 +2378,7 @@ def make_lstm_tests(zip_path): "time_step_size": [1], "input_vec_size": [3], "num_cells": [4], - "split_tflite_lstm_inputs": [True, False], + "split_tflite_lstm_inputs": [False], }, ] diff --git a/tensorflow/contrib/lite/testing/tflite_driver.cc b/tensorflow/contrib/lite/testing/tflite_driver.cc index 4dacf9c84b..1836eb53b9 100644 --- a/tensorflow/contrib/lite/testing/tflite_driver.cc +++ b/tensorflow/contrib/lite/testing/tflite_driver.cc @@ -302,28 +302,6 @@ bool TfLiteDriver::CheckResults() { void TfLiteDriver::ResetLSTMStateTensors() { interpreter_->ResetVariableTensorsToZero(); - - // Below is a workaround for initializing state tensors for LSTM. - // TODO(ycling): Remove the code below after nobody is using the 18-inputs - // definition. - for (auto node_index : interpreter_->execution_plan()) { - const auto& node_and_reg = interpreter_->node_and_registration(node_index); - const auto& node = node_and_reg->first; - const auto& registration = node_and_reg->second; - - if (registration.builtin_code == tflite::BuiltinOperator_LSTM) { - const auto* params = - reinterpret_cast<const TfLiteLSTMParams*>(node.builtin_data); - if (params->kernel_type == kTfLiteLSTMFullKernel && - node.inputs->size == 18 && node.outputs->size >= 2) { - // The first 2 outputs of LSTM are state tensors. - for (int i = 0; i < 2; ++i) { - int node_index = node.outputs->data[i]; - ResetTensor(node_index); - } - } - } - } } } // namespace testing |