diff options
Diffstat (limited to 'tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc')
-rw-r--r-- | tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc | 15 |
1 files changed, 14 insertions, 1 deletions
diff --git a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc index eac7db9a88..b092e5ee54 100644 --- a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc +++ b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc @@ -371,7 +371,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast<void*>(params); break; } - case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: case BuiltinOperator_LSTM: { auto params = allocator->AllocatePOD<TfLiteLSTMParams>(); if (auto* lstm_params = op->builtin_options_as_LSTMOptions()) { @@ -391,6 +390,20 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast<void*>(params); break; } + case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: { + auto* params = + allocator->AllocatePOD<TfLiteUnidirectionalSequenceLSTMParams>(); + if (auto* seq_lstm_params = + op->builtin_options_as_UnidirectionalSequenceLSTMOptions()) { + params->activation = + parse_activation(seq_lstm_params->fused_activation_function()); + params->cell_clip = seq_lstm_params->cell_clip(); + params->proj_clip = seq_lstm_params->proj_clip(); + } + *builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM: { auto params = allocator->AllocatePOD<TfLiteBidirectionalSequenceLSTMParams>(); |