From 854ae599743a1e92a31ad49cfe42c6454cefd3b9 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 9 Oct 2018 20:05:22 -0700 Subject: Use Ophints to support TfLite UnidirectionaSequenceLstm and add an e2e test. Support peephole and num_proj as well. PiperOrigin-RevId: 216467578 --- tensorflow/contrib/lite/toco/tflite/operator.cc | 39 +++++++++++++++++++++++++ 1 file changed, 39 insertions(+) (limited to 'tensorflow/contrib/lite/toco/tflite/operator.cc') diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index ed37535fe0..e08a61d357 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -741,6 +741,42 @@ class Lstm : public BuiltinOperator { + public: + using BuiltinOperator::BuiltinOperator; + flatbuffers::Offset WriteOptions( + const TocoOperator& op, + flatbuffers::FlatBufferBuilder* builder) const override { + // Current toco converter only supports tanh, no clip. + return ::tflite::CreateUnidirectionalSequenceLSTMOptions( + *builder, /*fused_activation_function=*/ + ::tflite::ActivationFunctionType_TANH, + /*cell_clip=*/0.0, + /*proj_clip=*/0.0); + } + + void ReadOptions(const TfLiteOptions& options, + TocoOperator* op) const override { + // Only support tanh activation, so check that tflite type is tanh. + DCHECK(options.fused_activation_function() == + ::tflite::ActivationFunctionType_TANH); + } + + int GetVersion(const Operator& op) const override { return 1; } + + std::vector GetMutatingInputVariables( + const Operator& op) const override { + std::vector mutating_input_variables(op.inputs.size(), false); + mutating_input_variables[kInputActivationStateTensor] = true; + mutating_input_variables[kInputCellStateTensor] = true; + return mutating_input_variables; + } +}; + class Mean : public BuiltinOperator { public: @@ -1435,6 +1471,9 @@ std::vector> BuildOperatorList( OperatorType::kFakeQuant)); ops.push_back( MakeUnique(::tflite::BuiltinOperator_PACK, OperatorType::kPack)); + ops.emplace_back(MakeUnique( + ::tflite::BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM, + OperatorType::kUnidirectionalSequenceLstm)); ops.push_back(MakeUnique(::tflite::BuiltinOperator_ONE_HOT, OperatorType::kOneHot)); ops.push_back(MakeUnique(::tflite::BuiltinOperator_UNPACK, -- cgit v1.2.3