aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/tflite
diff options
context:
space:
mode:
authorGravatar Yu-Cheng Ling <ycling@google.com>2018-06-19 12:35:44 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-19 12:38:27 -0700
commit5fab6df2788937bee1cce3a4e8f5b9d1db7497ec (patch)
treeba18594841593a0b2a3eda55c076ca78c7bf0d4e /tensorflow/contrib/lite/toco/tflite
parent8f19772410ec20010e9930f9765dbd3aaeb06111 (diff)
Support Variable Tensor API in LSTM Full kernel.
TFLite LSTM now supports 5 inputs, 18 inputs and 20 inputs. PiperOrigin-RevId: 201222516
Diffstat (limited to 'tensorflow/contrib/lite/toco/tflite')
-rw-r--r--tensorflow/contrib/lite/toco/tflite/BUILD1
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator.cc17
2 files changed, 12 insertions, 6 deletions
diff --git a/tensorflow/contrib/lite/toco/tflite/BUILD b/tensorflow/contrib/lite/toco/tflite/BUILD
index e1025c6664..a02f90988b 100644
--- a/tensorflow/contrib/lite/toco/tflite/BUILD
+++ b/tensorflow/contrib/lite/toco/tflite/BUILD
@@ -24,6 +24,7 @@ cc_library(
deps = [
":types",
"//tensorflow/contrib/lite/schema:schema_fbs",
+ "//tensorflow/contrib/lite/toco:graph_transformations",
"//tensorflow/contrib/lite/toco:model",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/memory",
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index 669fb9fa08..c93c0a6b90 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -14,6 +14,9 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/toco/tflite/operator.h"
+// TODO(ycling): Consider refactoring to extract the LSTM definition out of
+// graph_transformation module.
+#include "tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h"
#include "tensorflow/contrib/lite/toco/tflite/builtin_operator.h"
#include "tensorflow/contrib/lite/toco/tflite/custom_operator.h"
#include "tensorflow/contrib/lite/toco/tflite/simple_operator.h"
@@ -673,18 +676,20 @@ class Lstm : public BuiltinOperator<LstmCellOperator, ::tflite::LSTMOptions,
const Operator& op) const override {
const auto& lstm_op = static_cast<const LstmCellOperator&>(op);
+ std::vector<bool> mutating_input_variables(op.inputs.size(), false);
switch (lstm_op.kernel_type) {
- case LstmCellOperator::KERNEL_FULL:
- // TODO(ycling): Change the full kernel to use the new variable tensor
- // design. This requires moving the state tensors from output to input.
- return std::vector<bool>();
+ case LstmCellOperator::KERNEL_FULL: {
+ mutating_input_variables[kInputActivationStateTensor] = true;
+ mutating_input_variables[kInputCellStateTensor] = true;
+ break;
+ }
case LstmCellOperator::KERNEL_BASIC: {
- std::vector<bool> mutating_input_variables(op.inputs.size(), false);
mutating_input_variables[LstmCellOperator::PREV_ACTIV_INPUT] = true;
mutating_input_variables[LstmCellOperator::PREV_STATE_INPUT] = true;
- return mutating_input_variables;
+ break;
}
}
+ return mutating_input_variables;
}
};