diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/tflite/operator.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/tflite/operator.cc | 39 |
1 files changed, 39 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index ed37535fe0..e08a61d357 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -741,6 +741,42 @@ class Lstm : public BuiltinOperator<LstmCellOperator, ::tflite::LSTMOptions, } }; +class UnidirectionalSequenceLstm + : public BuiltinOperator< + UnidirectionalSequenceLstmOperator, + ::tflite::UnidirectionalSequenceLSTMOptions, + ::tflite::BuiltinOptions_UnidirectionalSequenceLSTMOptions> { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset<TfLiteOptions> WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + // Current toco converter only supports tanh, no clip. + return ::tflite::CreateUnidirectionalSequenceLSTMOptions( + *builder, /*fused_activation_function=*/ + ::tflite::ActivationFunctionType_TANH, + /*cell_clip=*/0.0, + /*proj_clip=*/0.0); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + // Only support tanh activation, so check that tflite type is tanh. + DCHECK(options.fused_activation_function() == + ::tflite::ActivationFunctionType_TANH); + } + + int GetVersion(const Operator& op) const override { return 1; } + + std::vector<bool> GetMutatingInputVariables( + const Operator& op) const override { + std::vector<bool> mutating_input_variables(op.inputs.size(), false); + mutating_input_variables[kInputActivationStateTensor] = true; + mutating_input_variables[kInputCellStateTensor] = true; + return mutating_input_variables; + } +}; + class Mean : public BuiltinOperator<MeanOperator, ::tflite::ReducerOptions, ::tflite::BuiltinOptions_ReducerOptions> { public: @@ -1435,6 +1471,9 @@ std::vector<std::unique_ptr<BaseOperator>> BuildOperatorList( OperatorType::kFakeQuant)); ops.push_back( MakeUnique<Pack>(::tflite::BuiltinOperator_PACK, OperatorType::kPack)); + ops.emplace_back(MakeUnique<UnidirectionalSequenceLstm>( + ::tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM, + OperatorType::kUnidirectionalSequenceLstm)); ops.push_back(MakeUnique<OneHot>(::tflite::BuiltinOperator_ONE_HOT, OperatorType::kOneHot)); ops.push_back(MakeUnique<Unpack>(::tflite::BuiltinOperator_UNPACK, |