diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/tflite/operator.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/tflite/operator.cc | 18 |
1 files changed, 18 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index 7490ab960b..a0fbb58aca 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -668,6 +668,24 @@ class Lstm : public BuiltinOperator<LstmCellOperator, ::tflite::LSTMOptions, return 2; } } + + std::vector<bool> GetMutatingInputVariables( + const Operator& op) const override { + const auto& lstm_op = static_cast<const LstmCellOperator&>(op); + + switch (lstm_op.kernel_type) { + case LstmCellOperator::KERNEL_FULL: + // TODO(ycling): Change the full kernel to use the new variable tensor + // design. This requires moving the state tensors from output to input. + return std::vector<bool>(); + case LstmCellOperator::KERNEL_BASIC: { + std::vector<bool> mutating_input_variables(op.inputs.size(), false); + mutating_input_variables[LstmCellOperator::PREV_ACTIV_INPUT] = true; + mutating_input_variables[LstmCellOperator::PREV_STATE_INPUT] = true; + return mutating_input_variables; + } + } + } }; class Mean : public BuiltinOperator<MeanOperator, ::tflite::MeanOptions, |