aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/core
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-10-06 21:00:57 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-06 21:06:15 -0700
commit7fa6a6b42bc9d562e2b1cc765ca78d281b51f734 (patch)
tree19398616e50223b0876e6e9a987cc9908d404592 /tensorflow/contrib/lite/core
parente93a18954689b6d522560f5273f6d3320d545b2e (diff)
Add SequenceLSTMOptions to schema to decouple the sequential Op from the LSTM.
PiperOrigin-RevId: 216066634
Diffstat (limited to 'tensorflow/contrib/lite/core')
-rw-r--r--tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc15
1 files changed, 14 insertions, 1 deletions
diff --git a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc
index eac7db9a88..b092e5ee54 100644
--- a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc
+++ b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc
@@ -371,7 +371,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
*builtin_data = reinterpret_cast<void*>(params);
break;
}
- case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
case BuiltinOperator_LSTM: {
auto params = allocator->AllocatePOD<TfLiteLSTMParams>();
if (auto* lstm_params = op->builtin_options_as_LSTMOptions()) {
@@ -391,6 +390,20 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
*builtin_data = reinterpret_cast<void*>(params);
break;
}
+ case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: {
+ auto* params =
+ allocator->AllocatePOD<TfLiteUnidirectionalSequenceLSTMParams>();
+ if (auto* seq_lstm_params =
+ op->builtin_options_as_UnidirectionalSequenceLSTMOptions()) {
+ params->activation =
+ parse_activation(seq_lstm_params->fused_activation_function());
+ params->cell_clip = seq_lstm_params->cell_clip();
+ params->proj_clip = seq_lstm_params->proj_clip();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
+
case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM: {
auto params =
allocator->AllocatePOD<TfLiteBidirectionalSequenceLSTMParams>();