aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Yu-Cheng Ling <ycling@google.com>2018-05-22 16:31:32 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-22 16:34:15 -0700
commit12ea31462d02326f14475516f8290d6e224ee70d (patch)
tree0deaa1b46621dc5266daf4b964ea1c267947b030
parent86aedb620a3a9de73b4c6e2d24763ff22aa45d03 (diff)
Fix the LSTM test in TFLite.
PiperOrigin-RevId: 197643581
-rw-r--r--tensorflow/contrib/lite/build_def.bzl3
-rw-r--r--tensorflow/contrib/lite/testing/tflite_driver.cc20
-rw-r--r--tensorflow/contrib/lite/testing/tflite_driver.h2
3 files changed, 24 insertions, 1 deletions
diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl
index 9bfc0a0fbe..c8820ab29b 100644
--- a/tensorflow/contrib/lite/build_def.bzl
+++ b/tensorflow/contrib/lite/build_def.bzl
@@ -212,12 +212,13 @@ def generated_test_models():
"global_batch_norm",
"greater",
"greater_equal",
- "l2_pool",
"l2norm",
+ "l2_pool",
"less",
"less_equal",
"local_response_norm",
"log_softmax",
+ "lstm",
"max_pool",
"maximum",
"mean",
diff --git a/tensorflow/contrib/lite/testing/tflite_driver.cc b/tensorflow/contrib/lite/testing/tflite_driver.cc
index 58fe5bd6e4..1f07068aee 100644
--- a/tensorflow/contrib/lite/testing/tflite_driver.cc
+++ b/tensorflow/contrib/lite/testing/tflite_driver.cc
@@ -143,6 +143,7 @@ void TfLiteDriver::AllocateTensors() {
Invalidate("Failed to allocate tensors");
return;
}
+ ResetLSTMStateTensors();
must_allocate_tensors_ = false;
}
}
@@ -281,5 +282,24 @@ bool TfLiteDriver::CheckResults() {
return success;
}
+void TfLiteDriver::ResetLSTMStateTensors() {
+ // This is a workaround for initializing state tensors for LSTM.
+ // TODO(ycling): Refactoring and find a better way to initialize state
+ // tensors. Maybe write the reset instructions into the test data.
+ for (auto node_index : interpreter_->execution_plan()) {
+ const auto& node_and_reg = interpreter_->node_and_registration(node_index);
+ const auto& node = node_and_reg->first;
+ const auto& registration = node_and_reg->second;
+ if (registration.builtin_code == tflite::BuiltinOperator_LSTM &&
+ node.outputs->size >= 2) {
+ // The first 2 outputs of LSTM are state tensors.
+ for (int i = 0; i < 2; ++i) {
+ int node_index = node.outputs->data[i];
+ ResetTensor(node_index);
+ }
+ }
+ }
+}
+
} // namespace testing
} // namespace tflite
diff --git a/tensorflow/contrib/lite/testing/tflite_driver.h b/tensorflow/contrib/lite/testing/tflite_driver.h
index 02b7de1534..5493ba3631 100644
--- a/tensorflow/contrib/lite/testing/tflite_driver.h
+++ b/tensorflow/contrib/lite/testing/tflite_driver.h
@@ -48,6 +48,8 @@ class TfLiteDriver : public TestRunner {
string ReadOutput(int id) override { return "no-op"; }
private:
+ void ResetLSTMStateTensors();
+
class Expectation;
bool use_nnapi_ = false;