aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/toco/tflite/operator.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/toco/tflite/operator.cc')
-rw-r--r--tensorflow/contrib/lite/toco/tflite/operator.cc18
1 files changed, 18 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc
index 7490ab960b..a0fbb58aca 100644
--- a/tensorflow/contrib/lite/toco/tflite/operator.cc
+++ b/tensorflow/contrib/lite/toco/tflite/operator.cc
@@ -668,6 +668,24 @@ class Lstm : public BuiltinOperator<LstmCellOperator, ::tflite::LSTMOptions,
return 2;
}
}
+
+ std::vector<bool> GetMutatingInputVariables(
+ const Operator& op) const override {
+ const auto& lstm_op = static_cast<const LstmCellOperator&>(op);
+
+ switch (lstm_op.kernel_type) {
+ case LstmCellOperator::KERNEL_FULL:
+ // TODO(ycling): Change the full kernel to use the new variable tensor
+ // design. This requires moving the state tensors from output to input.
+ return std::vector<bool>();
+ case LstmCellOperator::KERNEL_BASIC: {
+ std::vector<bool> mutating_input_variables(op.inputs.size(), false);
+ mutating_input_variables[LstmCellOperator::PREV_ACTIV_INPUT] = true;
+ mutating_input_variables[LstmCellOperator::PREV_STATE_INPUT] = true;
+ return mutating_input_variables;
+ }
+ }
+ }
};
class Mean : public BuiltinOperator<MeanOperator, ::tflite::MeanOptions,