aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc')
-rw-r--r--tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc15
1 files changed, 14 insertions, 1 deletions
diff --git a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc
index eac7db9a88..b092e5ee54 100644
--- a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc
+++ b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc
@@ -371,7 +371,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
*builtin_data = reinterpret_cast<void*>(params);
break;
}
- case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
case BuiltinOperator_LSTM: {
auto params = allocator->AllocatePOD<TfLiteLSTMParams>();
if (auto* lstm_params = op->builtin_options_as_LSTMOptions()) {
@@ -391,6 +390,20 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
*builtin_data = reinterpret_cast<void*>(params);
break;
}
+ case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: {
+ auto* params =
+ allocator->AllocatePOD<TfLiteUnidirectionalSequenceLSTMParams>();
+ if (auto* seq_lstm_params =
+ op->builtin_options_as_UnidirectionalSequenceLSTMOptions()) {
+ params->activation =
+ parse_activation(seq_lstm_params->fused_activation_function());
+ params->cell_clip = seq_lstm_params->cell_clip();
+ params->proj_clip = seq_lstm_params->proj_clip();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+
case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM: {
auto params =
allocator->AllocatePOD<TfLiteBidirectionalSequenceLSTMParams>();