diff options
author | Yu-Cheng Ling <ycling@google.com> | 2018-06-19 12:35:44 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-19 12:38:27 -0700 |
commit | 5fab6df2788937bee1cce3a4e8f5b9d1db7497ec (patch) | |
tree | ba18594841593a0b2a3eda55c076ca78c7bf0d4e /tensorflow/contrib/lite/toco/tflite | |
parent | 8f19772410ec20010e9930f9765dbd3aaeb06111 (diff) |
Support Variable Tensor API in LSTM Full kernel.
TFLite LSTM now supports 5 inputs, 18 inputs and 20 inputs.
PiperOrigin-RevId: 201222516
Diffstat (limited to 'tensorflow/contrib/lite/toco/tflite')
-rw-r--r-- | tensorflow/contrib/lite/toco/tflite/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/contrib/lite/toco/tflite/operator.cc | 17 |
2 files changed, 12 insertions, 6 deletions
diff --git a/tensorflow/contrib/lite/toco/tflite/BUILD b/tensorflow/contrib/lite/toco/tflite/BUILD index e1025c6664..a02f90988b 100644 --- a/tensorflow/contrib/lite/toco/tflite/BUILD +++ b/tensorflow/contrib/lite/toco/tflite/BUILD @@ -24,6 +24,7 @@ cc_library( deps = [ ":types", "//tensorflow/contrib/lite/schema:schema_fbs", + "//tensorflow/contrib/lite/toco:graph_transformations", "//tensorflow/contrib/lite/toco:model", "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/memory", diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index 669fb9fa08..c93c0a6b90 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -14,6 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/contrib/lite/toco/tflite/operator.h" +// TODO(ycling): Consider refactoring to extract the LSTM definition out of +// graph_transformation module. +#include "tensorflow/contrib/lite/toco/graph_transformations/lstm_utils.h" #include "tensorflow/contrib/lite/toco/tflite/builtin_operator.h" #include "tensorflow/contrib/lite/toco/tflite/custom_operator.h" #include "tensorflow/contrib/lite/toco/tflite/simple_operator.h" @@ -673,18 +676,20 @@ class Lstm : public BuiltinOperator<LstmCellOperator, ::tflite::LSTMOptions, const Operator& op) const override { const auto& lstm_op = static_cast<const LstmCellOperator&>(op); + std::vector<bool> mutating_input_variables(op.inputs.size(), false); switch (lstm_op.kernel_type) { - case LstmCellOperator::KERNEL_FULL: - // TODO(ycling): Change the full kernel to use the new variable tensor - // design. This requires moving the state tensors from output to input. - return std::vector<bool>(); + case LstmCellOperator::KERNEL_FULL: { + mutating_input_variables[kInputActivationStateTensor] = true; + mutating_input_variables[kInputCellStateTensor] = true; + break; + } case LstmCellOperator::KERNEL_BASIC: { - std::vector<bool> mutating_input_variables(op.inputs.size(), false); mutating_input_variables[LstmCellOperator::PREV_ACTIV_INPUT] = true; mutating_input_variables[LstmCellOperator::PREV_STATE_INPUT] = true; - return mutating_input_variables; + break; } } + return mutating_input_variables; } }; |