diff options
author | 2018-05-22 16:31:32 -0700 | |
---|---|---|
committer | 2018-05-22 16:34:15 -0700 | |
commit | 12ea31462d02326f14475516f8290d6e224ee70d (patch) | |
tree | 0deaa1b46621dc5266daf4b964ea1c267947b030 | |
parent | 86aedb620a3a9de73b4c6e2d24763ff22aa45d03 (diff) |
Fix the LSTM test in TFLite.
PiperOrigin-RevId: 197643581
-rw-r--r-- | tensorflow/contrib/lite/build_def.bzl | 3 | ||||
-rw-r--r-- | tensorflow/contrib/lite/testing/tflite_driver.cc | 20 | ||||
-rw-r--r-- | tensorflow/contrib/lite/testing/tflite_driver.h | 2 |
3 files changed, 24 insertions, 1 deletions
diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl index 9bfc0a0fbe..c8820ab29b 100644 --- a/tensorflow/contrib/lite/build_def.bzl +++ b/tensorflow/contrib/lite/build_def.bzl @@ -212,12 +212,13 @@ def generated_test_models(): "global_batch_norm", "greater", "greater_equal", - "l2_pool", "l2norm", + "l2_pool", "less", "less_equal", "local_response_norm", "log_softmax", + "lstm", "max_pool", "maximum", "mean", diff --git a/tensorflow/contrib/lite/testing/tflite_driver.cc b/tensorflow/contrib/lite/testing/tflite_driver.cc index 58fe5bd6e4..1f07068aee 100644 --- a/tensorflow/contrib/lite/testing/tflite_driver.cc +++ b/tensorflow/contrib/lite/testing/tflite_driver.cc @@ -143,6 +143,7 @@ void TfLiteDriver::AllocateTensors() { Invalidate("Failed to allocate tensors"); return; } + ResetLSTMStateTensors(); must_allocate_tensors_ = false; } } @@ -281,5 +282,24 @@ bool TfLiteDriver::CheckResults() { return success; } +void TfLiteDriver::ResetLSTMStateTensors() { + // This is a workaround for initializing state tensors for LSTM. + // TODO(ycling): Refactoring and find a better way to initialize state + // tensors. Maybe write the reset instructions into the test data. + for (auto node_index : interpreter_->execution_plan()) { + const auto& node_and_reg = interpreter_->node_and_registration(node_index); + const auto& node = node_and_reg->first; + const auto& registration = node_and_reg->second; + if (registration.builtin_code == tflite::BuiltinOperator_LSTM && + node.outputs->size >= 2) { + // The first 2 outputs of LSTM are state tensors. + for (int i = 0; i < 2; ++i) { + int node_index = node.outputs->data[i]; + ResetTensor(node_index); + } + } + } +} + } // namespace testing } // namespace tflite diff --git a/tensorflow/contrib/lite/testing/tflite_driver.h b/tensorflow/contrib/lite/testing/tflite_driver.h index 02b7de1534..5493ba3631 100644 --- a/tensorflow/contrib/lite/testing/tflite_driver.h +++ b/tensorflow/contrib/lite/testing/tflite_driver.h @@ -48,6 +48,8 @@ class TfLiteDriver : public TestRunner { string ReadOutput(int id) override { return "no-op"; } private: + void ResetLSTMStateTensors(); + class Expectation; bool use_nnapi_ = false; |