diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-10-06 21:00:57 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-06 21:06:15 -0700 |
commit | 7fa6a6b42bc9d562e2b1cc765ca78d281b51f734 (patch) | |
tree | 19398616e50223b0876e6e9a987cc9908d404592 /tensorflow/contrib/lite/core | |
parent | e93a18954689b6d522560f5273f6d3320d545b2e (diff) |
Add SequenceLSTMOptions to schema to decouple the sequential Op from the LSTM.
PiperOrigin-RevId: 216066634
Diffstat (limited to 'tensorflow/contrib/lite/core')
-rw-r--r-- | tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc | 15 |
1 files changed, 14 insertions, 1 deletions
diff --git a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc index eac7db9a88..b092e5ee54 100644 --- a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc +++ b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc @@ -371,7 +371,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast<void*>(params); break; } - case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: case BuiltinOperator_LSTM: { auto params = allocator->AllocatePOD<TfLiteLSTMParams>(); if (auto* lstm_params = op->builtin_options_as_LSTMOptions()) { @@ -391,6 +390,20 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast<void*>(params); break; } + case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: { + auto* params = + allocator->AllocatePOD<TfLiteUnidirectionalSequenceLSTMParams>(); + if (auto* seq_lstm_params = + op->builtin_options_as_UnidirectionalSequenceLSTMOptions()) { + params->activation = + parse_activation(seq_lstm_params->fused_activation_function()); + params->cell_clip = seq_lstm_params->cell_clip(); + params->proj_clip = seq_lstm_params->proj_clip(); + } + *builtin_data = reinterpret_cast<void*>(params); + break; + } + case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM: { auto params = allocator->AllocatePOD<TfLiteBidirectionalSequenceLSTMParams>(); |