diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-10-03 13:25:22 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-03 13:32:42 -0700 |
commit | c2c8cfe22492cf7fab804d32283b623632270035 (patch) | |
tree | 6003bf547117f97cd65ed598c4cec39cba7d5510 /tensorflow/contrib/lite/core | |
parent | 7566f3d5ad690c71c36e78611b1ae5913ec3e845 (diff) |
Add the option of merging bidirectional RNN and LSTM outputs into a single output tensor.
This is useful if the output of both directions will be passed to the next layer as a single output, as it avoids adding a concatenation op, which can be expensive on mobile devices where memory movement is relatively expensive.
PiperOrigin-RevId: 215616140
Diffstat (limited to 'tensorflow/contrib/lite/core')
-rw-r--r-- | tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc | 34 |
1 files changed, 29 insertions, 5 deletions
diff --git a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc index e6900e0950..eac7db9a88 100644 --- a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc +++ b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc @@ -224,10 +224,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast<void*>(params); break; } - case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN: case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: { - TfLiteSequenceRNNParams* params = - allocator->AllocatePOD<TfLiteSequenceRNNParams>(); + auto params = allocator->AllocatePOD<TfLiteSequenceRNNParams>(); if (auto* sequence_rnn_params = op->builtin_options_as_SequenceRNNOptions()) { params->activation = @@ -237,6 +235,19 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast<void*>(params); break; } + case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN: { + auto params = + allocator->AllocatePOD<TfLiteBidirectionalSequenceRNNParams>(); + if (auto* bidi_sequence_rnn_params = + op->builtin_options_as_BidirectionalSequenceRNNOptions()) { + params->activation = parse_activation( + bidi_sequence_rnn_params->fused_activation_function()); + params->time_major = bidi_sequence_rnn_params->time_major(); + params->merge_outputs = bidi_sequence_rnn_params->merge_outputs(); + } + *builtin_data = reinterpret_cast<void*>(params); + break; + } case BuiltinOperator_RNN: { TfLiteRNNParams* params = allocator->AllocatePOD<TfLiteRNNParams>(); if (auto* rnn_params = op->builtin_options_as_RNNOptions()) { @@ -360,10 +371,9 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast<void*>(params); break; } - case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM: case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: case BuiltinOperator_LSTM: { - TfLiteLSTMParams* params = allocator->AllocatePOD<TfLiteLSTMParams>(); + auto params = allocator->AllocatePOD<TfLiteLSTMParams>(); if (auto* lstm_params = op->builtin_options_as_LSTMOptions()) { params->activation = parse_activation(lstm_params->fused_activation_function()); @@ -381,6 +391,20 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast<void*>(params); break; } + case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM: { + auto params = + allocator->AllocatePOD<TfLiteBidirectionalSequenceLSTMParams>(); + if (auto* bidi_lstm_params = + op->builtin_options_as_BidirectionalSequenceLSTMOptions()) { + params->activation = + parse_activation(bidi_lstm_params->fused_activation_function()); + params->cell_clip = bidi_lstm_params->cell_clip(); + params->proj_clip = bidi_lstm_params->proj_clip(); + params->merge_outputs = bidi_lstm_params->merge_outputs(); + } + *builtin_data = reinterpret_cast<void*>(params); + break; + } case BuiltinOperator_RESIZE_BILINEAR: { auto* params = allocator->AllocatePOD<TfLiteResizeBilinearParams>(); if (auto* schema_params = |