diff options
Diffstat (limited to 'tensorflow/contrib/lite/toco/tflite/operator.cc')
-rw-r--r-- | tensorflow/contrib/lite/toco/tflite/operator.cc | 31 |
1 files changed, 29 insertions, 2 deletions
diff --git a/tensorflow/contrib/lite/toco/tflite/operator.cc b/tensorflow/contrib/lite/toco/tflite/operator.cc index 84a5410839..a8518adefc 100644 --- a/tensorflow/contrib/lite/toco/tflite/operator.cc +++ b/tensorflow/contrib/lite/toco/tflite/operator.cc @@ -626,11 +626,21 @@ class Lstm : public BuiltinOperator<LstmCellOperator, ::tflite::LSTMOptions, flatbuffers::Offset<TfLiteOptions> WriteOptions( const TocoOperator& op, flatbuffers::FlatBufferBuilder* builder) const override { + ::tflite::LSTMKernelType kernel_type; + switch (op.kernel_type) { + case LstmCellOperator::KERNEL_BASIC: + kernel_type = ::tflite::LSTMKernelType_BASIC; + break; + case LstmCellOperator::KERNEL_FULL: + kernel_type = ::tflite::LSTMKernelType_FULL; + break; + } + // Current toco converter only supports tanh, no clip. return ::tflite::CreateLSTMOptions(*builder, /*fused_activation_function=*/ ::tflite::ActivationFunctionType_TANH, /*cell_clip=*/0.0, - /*proj_clip=*/0.0); + /*proj_clip=*/0.0, kernel_type); } void ReadOptions(const TfLiteOptions& options, @@ -638,9 +648,26 @@ class Lstm : public BuiltinOperator<LstmCellOperator, ::tflite::LSTMOptions, // Only support tanh activation, so check that tflite type is tanh. CHECK(options.fused_activation_function() == ::tflite::ActivationFunctionType_TANH); + + switch (options.kernel_type()) { + case ::tflite::LSTMKernelType_BASIC: + op->kernel_type = LstmCellOperator::KERNEL_BASIC; + break; + case ::tflite::LSTMKernelType_FULL: + op->kernel_type = LstmCellOperator::KERNEL_FULL; + break; + } } - int GetVersion(const Operator& op) const override { return 1; } + int GetVersion(const Operator& op) const override { + const auto& lstm_op = static_cast<const LstmCellOperator&>(op); + switch (lstm_op.kernel_type) { + case LstmCellOperator::KERNEL_FULL: + return 1; + case LstmCellOperator::KERNEL_BASIC: + return 2; + } + } }; class Mean : public BuiltinOperator<MeanOperator, ::tflite::MeanOptions, |