author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-10-03 13:25:22 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-10-03 13:32:42 -0700
commit    c2c8cfe22492cf7fab804d32283b623632270035 (patch)
tree      6003bf547117f97cd65ed598c4cec39cba7d5510 /tensorflow/contrib/lite/core
parent    7566f3d5ad690c71c36e78611b1ae5913ec3e845 (diff)
Add the option of merging bidirectional RNN and LSTM outputs into a single output tensor.

This is useful when the outputs of both directions are passed to the next layer as a single tensor, since it avoids a separate concatenation op, which can be costly on mobile devices where memory movement is relatively expensive.

PiperOrigin-RevId: 215616140
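
For intuition, merging means the kernel writes the forward and backward results into one output tensor whose last dimension is the sum of the two cell sizes, instead of producing two tensors that a later concat op would have to copy together. A minimal sketch of that write pattern (plain C++; all names here are illustrative, not the actual TFLite kernel):

    // Writes forward/backward results for one time step straight into a
    // single buffer of shape [batch, fw_units + bw_units], so no separate
    // concatenation op (and no extra copy of both outputs) is needed later.
    void WriteMergedStep(const float* fw_out, const float* bw_out, int batch,
                         int fw_units, int bw_units, float* merged) {
      const int merged_units = fw_units + bw_units;
      for (int b = 0; b < batch; ++b) {
        float* row = merged + b * merged_units;
        for (int i = 0; i < fw_units; ++i) {
          row[i] = fw_out[b * fw_units + i];
        }
        for (int i = 0; i < bw_units; ++i) {
          row[fw_units + i] = bw_out[b * bw_units + i];
        }
      }
    }
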
Diffstat (limited to 'tensorflow/contrib/lite/core')
-rw-r--r--  tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc | 34
1 file changed, 29 insertions(+), 5 deletions(-)
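
The new ParseOpData cases below fill dedicated param structs that a kernel later reads back from node->builtin_data. A rough sketch of that consumption, assuming a hypothetical Prepare function (only TfLiteBidirectionalSequenceLSTMParams and its field names come from the diff below):

    // Hypothetical kernel Prepare() consuming the struct populated by
    // ParseOpData; not the actual TFLite bidirectional LSTM kernel.
    TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
      const auto* params =
          reinterpret_cast<const TfLiteBidirectionalSequenceLSTMParams*>(
              node->builtin_data);
      if (params->merge_outputs) {
        // Expect a single output tensor; its last dimension is the sum of
        // the forward and backward cell sizes.
      } else {
        // Expect separate forward and backward output tensors.
      }
      return kTfLiteOk;
    }
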
diff --git a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc
index e6900e0950..eac7db9a88 100644
--- a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc
+++ b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc
@@ -224,10 +224,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
       *builtin_data = reinterpret_cast<void*>(params);
       break;
     }
-    case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN:
     case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: {
-      TfLiteSequenceRNNParams* params =
-          allocator->AllocatePOD<TfLiteSequenceRNNParams>();
+      auto params = allocator->AllocatePOD<TfLiteSequenceRNNParams>();
       if (auto* sequence_rnn_params =
               op->builtin_options_as_SequenceRNNOptions()) {
         params->activation =
@@ -237,6 +235,19 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
       *builtin_data = reinterpret_cast<void*>(params);
       break;
     }
+    case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN: {
+      auto params =
+          allocator->AllocatePOD<TfLiteBidirectionalSequenceRNNParams>();
+      if (auto* bidi_sequence_rnn_params =
+              op->builtin_options_as_BidirectionalSequenceRNNOptions()) {
+        params->activation = parse_activation(
+            bidi_sequence_rnn_params->fused_activation_function());
+        params->time_major = bidi_sequence_rnn_params->time_major();
+        params->merge_outputs = bidi_sequence_rnn_params->merge_outputs();
+      }
+      *builtin_data = reinterpret_cast<void*>(params);
+      break;
+    }
     case BuiltinOperator_RNN: {
       TfLiteRNNParams* params = allocator->AllocatePOD<TfLiteRNNParams>();
       if (auto* rnn_params = op->builtin_options_as_RNNOptions()) {
@@ -360,10 +371,9 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
       *builtin_data = reinterpret_cast<void*>(params);
       break;
     }
-    case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM:
     case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
     case BuiltinOperator_LSTM: {
-      TfLiteLSTMParams* params = allocator->AllocatePOD<TfLiteLSTMParams>();
+      auto params = allocator->AllocatePOD<TfLiteLSTMParams>();
       if (auto* lstm_params = op->builtin_options_as_LSTMOptions()) {
         params->activation =
             parse_activation(lstm_params->fused_activation_function());
@@ -381,6 +391,20 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
       *builtin_data = reinterpret_cast<void*>(params);
       break;
     }
+    case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM: {
+      auto params =
+          allocator->AllocatePOD<TfLiteBidirectionalSequenceLSTMParams>();
+      if (auto* bidi_lstm_params =
+              op->builtin_options_as_BidirectionalSequenceLSTMOptions()) {
+        params->activation =
+            parse_activation(bidi_lstm_params->fused_activation_function());
+        params->cell_clip = bidi_lstm_params->cell_clip();
+        params->proj_clip = bidi_lstm_params->proj_clip();
+        params->merge_outputs = bidi_lstm_params->merge_outputs();
+      }
+      *builtin_data = reinterpret_cast<void*>(params);
+      break;
+    }
     case BuiltinOperator_RESIZE_BILINEAR: {
       auto* params = allocator->AllocatePOD<TfLiteResizeBilinearParams>();
       if (auto* schema_params =