author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-10-03 13:25:22 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-10-03 13:32:42 -0700
commit    c2c8cfe22492cf7fab804d32283b623632270035 (patch)
tree      6003bf547117f97cd65ed598c4cec39cba7d5510 /tensorflow/contrib/lite
parent    7566f3d5ad690c71c36e78611b1ae5913ec3e845 (diff)
Add the option of merging bidirectional RNN and LSTM outputs into a single output tensor.
This is useful when the outputs of both directions are passed to the next layer as a single tensor, since it avoids an explicit concatenation op, which can be costly on mobile devices where memory movement is relatively expensive.

PiperOrigin-RevId: 215616140
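The merged layout is the one the updated tests reconstruct when building their expected values: for each timestep and batch entry, the forward outputs come first and the backward outputs follow at offset n_fw_output. The sketch below is not part of the patch (the helper name and signature are illustrative); it only spells out that layout for the time-major LSTM case, where n_fw_output is also the output_offset the kernel passes to the backward EvalFloat()/EvalHybrid() call.

// Illustrative sketch (not part of the patch): memory layout of the merged
// output tensor, shape [max_time, n_batch, n_fw_output + n_bw_output].
#include <algorithm>
#include <vector>

std::vector<float> MergeBidiOutputs(
    const std::vector<float>& fw,  // [max_time * n_batch * n_fw_output]
    const std::vector<float>& bw,  // [max_time * n_batch * n_bw_output]
    int max_time, int n_batch, int n_fw_output, int n_bw_output) {
  const int merged_step = n_fw_output + n_bw_output;
  std::vector<float> merged(max_time * n_batch * merged_step);
  for (int t = 0; t < max_time; ++t) {
    for (int b = 0; b < n_batch; ++b) {
      // Forward results occupy the first n_fw_output values of each timestep;
      // backward results are written starting at offset n_fw_output.
      float* out = merged.data() + (t * n_batch + b) * merged_step;
      const float* fw_ptr = fw.data() + (t * n_batch + b) * n_fw_output;
      const float* bw_ptr = bw.data() + (t * n_batch + b) * n_bw_output;
      std::copy(fw_ptr, fw_ptr + n_fw_output, out);
      std::copy(bw_ptr, bw_ptr + n_bw_output, out + n_fw_output);
    }
  }
  return merged;
}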
Diffstat (limited to 'tensorflow/contrib/lite')
-rw-r--r--  tensorflow/contrib/lite/c/builtin_op_data.h | 16
-rw-r--r--  tensorflow/contrib/lite/c/builtin_op_data_test.cc | 2
-rw-r--r--  tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc | 34
-rw-r--r--  tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc | 116
-rw-r--r--  tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc | 186
-rw-r--r--  tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc | 85
-rw-r--r--  tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc | 56
-rw-r--r--  tensorflow/contrib/lite/schema/schema.fbs | 12
-rwxr-xr-x  tensorflow/contrib/lite/schema/schema_generated.h | 243
9 files changed, 640 insertions, 110 deletions
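For reference only (not part of the diff below), this is how a model builder would request merged outputs through the regenerated flatbuffer API from this change. The wrapper functions are hypothetical; the Create* signatures and enum values are the ones added to schema_generated.h and exercised by the updated unit tests via SetBuiltinOp().

// Hypothetical helpers, shown only to illustrate the new options tables.
#include "flatbuffers/flatbuffers.h"
#include "tensorflow/contrib/lite/schema/schema_generated.h"

namespace tflite {

// Builds BidirectionalSequenceLSTMOptions with merge_outputs enabled, the same
// way bidirectional_sequence_lstm_test.cc now constructs its builtin options.
flatbuffers::Offset<void> BuildMergedBidiLstmOptions(
    flatbuffers::FlatBufferBuilder* builder) {
  return CreateBidirectionalSequenceLSTMOptions(
             *builder, ActivationFunctionType_TANH,
             /*cell_clip=*/0.0f, /*proj_clip=*/0.0f, /*merge_outputs=*/true)
      .Union();
}

// Same for the RNN variant, matching CreateBidirectionalSequenceRNNOptions.
flatbuffers::Offset<void> BuildMergedBidiRnnOptions(
    flatbuffers::FlatBufferBuilder* builder) {
  return CreateBidirectionalSequenceRNNOptions(
             *builder, /*time_major=*/false, ActivationFunctionType_RELU,
             /*merge_outputs=*/true)
      .Union();
}

}  // namespace tflite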
diff --git a/tensorflow/contrib/lite/c/builtin_op_data.h b/tensorflow/contrib/lite/c/builtin_op_data.h
index be9d551ee4..44daf7adaa 100644
--- a/tensorflow/contrib/lite/c/builtin_op_data.h
+++ b/tensorflow/contrib/lite/c/builtin_op_data.h
@@ -99,6 +99,12 @@ typedef struct {
TfLiteFusedActivation activation;
} TfLiteSequenceRNNParams;
+typedef struct {
+ bool time_major;
+ TfLiteFusedActivation activation;
+ bool merge_outputs;
+} TfLiteBidirectionalSequenceRNNParams;
+
typedef enum {
kTfLiteFullyConnectedWeightsFormatDefault = 0,
kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8 = 1,
@@ -181,6 +187,16 @@ typedef struct {
} TfLiteLSTMParams;
typedef struct {
+ // Parameters for the LSTM kernel.
+ TfLiteFusedActivation activation;
+ float cell_clip;
+ float proj_clip;
+
+ // If true, store the outputs of both directions in the first output.
+ bool merge_outputs;
+} TfLiteBidirectionalSequenceLSTMParams;
+
+typedef struct {
bool align_corners;
} TfLiteResizeBilinearParams;
diff --git a/tensorflow/contrib/lite/c/builtin_op_data_test.cc b/tensorflow/contrib/lite/c/builtin_op_data_test.cc
index 4d0ba75e68..ba458b4252 100644
--- a/tensorflow/contrib/lite/c/builtin_op_data_test.cc
+++ b/tensorflow/contrib/lite/c/builtin_op_data_test.cc
@@ -73,6 +73,8 @@ TEST(IntArray, CanCompileStructs) {
TfLiteFakeQuantParams fake_quant_params;
TfLitePackParams pack_params;
TfLiteOneHotParams one_hot_params;
+ TfLiteBidirectionalSequenceRNNParams bidi_sequence_rnn_params;
+ TfLiteBidirectionalSequenceLSTMParams bidi_sequence_lstm_params;
}
} // namespace tflite
diff --git a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc
index e6900e0950..eac7db9a88 100644
--- a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc
+++ b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc
@@ -224,10 +224,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
*builtin_data = reinterpret_cast<void*>(params);
break;
}
- case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN:
case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: {
- TfLiteSequenceRNNParams* params =
- allocator->AllocatePOD<TfLiteSequenceRNNParams>();
+ auto params = allocator->AllocatePOD<TfLiteSequenceRNNParams>();
if (auto* sequence_rnn_params =
op->builtin_options_as_SequenceRNNOptions()) {
params->activation =
@@ -237,6 +235,19 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
*builtin_data = reinterpret_cast<void*>(params);
break;
}
+ case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN: {
+ auto params =
+ allocator->AllocatePOD<TfLiteBidirectionalSequenceRNNParams>();
+ if (auto* bidi_sequence_rnn_params =
+ op->builtin_options_as_BidirectionalSequenceRNNOptions()) {
+ params->activation = parse_activation(
+ bidi_sequence_rnn_params->fused_activation_function());
+ params->time_major = bidi_sequence_rnn_params->time_major();
+ params->merge_outputs = bidi_sequence_rnn_params->merge_outputs();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
case BuiltinOperator_RNN: {
TfLiteRNNParams* params = allocator->AllocatePOD<TfLiteRNNParams>();
if (auto* rnn_params = op->builtin_options_as_RNNOptions()) {
@@ -360,10 +371,9 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
*builtin_data = reinterpret_cast<void*>(params);
break;
}
- case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM:
case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM:
case BuiltinOperator_LSTM: {
- TfLiteLSTMParams* params = allocator->AllocatePOD<TfLiteLSTMParams>();
+ auto params = allocator->AllocatePOD<TfLiteLSTMParams>();
if (auto* lstm_params = op->builtin_options_as_LSTMOptions()) {
params->activation =
parse_activation(lstm_params->fused_activation_function());
@@ -381,6 +391,20 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
*builtin_data = reinterpret_cast<void*>(params);
break;
}
+ case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM: {
+ auto params =
+ allocator->AllocatePOD<TfLiteBidirectionalSequenceLSTMParams>();
+ if (auto* bidi_lstm_params =
+ op->builtin_options_as_BidirectionalSequenceLSTMOptions()) {
+ params->activation =
+ parse_activation(bidi_lstm_params->fused_activation_function());
+ params->cell_clip = bidi_lstm_params->cell_clip();
+ params->proj_clip = bidi_lstm_params->proj_clip();
+ params->merge_outputs = bidi_lstm_params->merge_outputs();
+ }
+ *builtin_data = reinterpret_cast<void*>(params);
+ break;
+ }
case BuiltinOperator_RESIZE_BILINEAR: {
auto* params = allocator->AllocatePOD<TfLiteResizeBilinearParams>();
if (auto* schema_params =
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
index 66b947771c..0532528f52 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc
@@ -119,7 +119,7 @@ constexpr int kBwAuxInputToOutputWeightsTensor = 47; // Optional
// Output tensors.
constexpr int kFwOutputTensor = 0;
-constexpr int kBwOutputTensor = 1;
+constexpr int kBwOutputTensor = 1; // Ignored if merge_outputs is set.
// Temporary tensors.
enum TemporaryTensor {
@@ -162,7 +162,8 @@ TfLiteStatus CheckLstmTensorDimensions(
int input_gate_bias_tensor, int forget_gate_bias_tensor,
int cell_gate_bias_tensor, int output_gate_bias_tensor,
int projection_weights_tensor, int projection_bias_tensor) {
- const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
+ const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceLSTMParams*>(
+ node->builtin_data);
// Making sure clipping parameters have valid values.
// == 0 means no clipping
@@ -347,10 +348,13 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context,
// tensors. Also check that the size of the input tensors match each other.
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);
+ const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceLSTMParams*>(
+ node->builtin_data);
// Check we have all the inputs and outputs we need.
TF_LITE_ENSURE_EQ(context, node->inputs->size, 48);
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 2);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size,
+ params->merge_outputs ? 1 : 2);
// Inferring batch size, number of outputs and sequence length and
// number of cells from the input tensors.
@@ -368,6 +372,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, fw_input_to_output_weights->dims->data[1],
n_input);
+ const TfLiteTensor* bw_input_to_output_weights =
+ GetInput(context, node, kBwInputToOutputWeightsTensor);
+ const int n_bw_cell = bw_input_to_output_weights->dims->data[0];
+ TF_LITE_ENSURE_EQ(context, bw_input_to_output_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, bw_input_to_output_weights->dims->data[1],
+ n_input);
+
const TfLiteTensor* fw_recurrent_to_output_weights =
GetInput(context, node, kFwRecurrentToOutputWeightsTensor);
TF_LITE_ENSURE_EQ(context, fw_recurrent_to_output_weights->dims->size, 2);
@@ -375,6 +386,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
n_fw_cell);
const int n_fw_output = fw_recurrent_to_output_weights->dims->data[1];
+ const TfLiteTensor* bw_recurrent_to_output_weights =
+ GetInput(context, node, kBwRecurrentToOutputWeightsTensor);
+ TF_LITE_ENSURE_EQ(context, bw_recurrent_to_output_weights->dims->size, 2);
+ TF_LITE_ENSURE_EQ(context, bw_recurrent_to_output_weights->dims->data[0],
+ n_bw_cell);
+ const int n_bw_output = bw_recurrent_to_output_weights->dims->data[1];
+
// Check that input tensor dimensions matches with each other.
TF_LITE_ENSURE_OK(
context, CheckInputTensorDimensions(context, node, n_input, n_fw_output,
@@ -440,7 +458,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TfLiteIntArray* fw_output_size = TfLiteIntArrayCreate(3);
fw_output_size->data[0] = max_time;
fw_output_size->data[1] = n_batch;
- fw_output_size->data[2] = n_fw_output;
+ fw_output_size->data[2] =
+ params->merge_outputs ? n_bw_output + n_fw_output : n_fw_output;
TF_LITE_ENSURE_OK(context,
context->ResizeTensor(context, fw_output, fw_output_size));
@@ -479,39 +498,28 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_scratch_buffer,
fw_scratch_buffer_size));
// Same for the backward cell.
- const TfLiteTensor* bw_input_to_output_weights =
- GetInput(context, node, kBwInputToOutputWeightsTensor);
- const int n_bw_cell = bw_input_to_output_weights->dims->data[0];
- TF_LITE_ENSURE_EQ(context, bw_input_to_output_weights->dims->size, 2);
- TF_LITE_ENSURE_EQ(context, bw_input_to_output_weights->dims->data[1],
- n_input);
-
- const TfLiteTensor* bw_recurrent_to_output_weights =
- GetInput(context, node, kBwRecurrentToOutputWeightsTensor);
- TF_LITE_ENSURE_EQ(context, bw_recurrent_to_output_weights->dims->size, 2);
- TF_LITE_ENSURE_EQ(context, bw_recurrent_to_output_weights->dims->data[0],
- n_bw_cell);
- const int n_bw_output = bw_recurrent_to_output_weights->dims->data[1];
// Check that input tensor dimensions matches with each other.
TF_LITE_ENSURE_OK(
context, CheckInputTensorDimensions(context, node, n_input, n_bw_output,
n_bw_cell));
- // Get the pointer to output, activation_state and cell_state buffer tensors.
- TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
+ // Get the pointer to activation_state and cell_state buffer tensors.
TfLiteTensor* bw_activation_state =
GetVariableInput(context, node, kBwInputActivationStateTensor);
TfLiteTensor* bw_cell_state =
GetVariableInput(context, node, kBwInputCellStateTensor);
// Resize the output tensors.
- TfLiteIntArray* bw_output_size = TfLiteIntArrayCreate(3);
- bw_output_size->data[0] = max_time;
- bw_output_size->data[1] = n_batch;
- bw_output_size->data[2] = n_bw_output;
- TF_LITE_ENSURE_OK(context,
- context->ResizeTensor(context, bw_output, bw_output_size));
+ if (!params->merge_outputs) {
+ TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
+ TfLiteIntArray* bw_output_size = TfLiteIntArrayCreate(3);
+ bw_output_size->data[0] = max_time;
+ bw_output_size->data[1] = n_batch;
+ bw_output_size->data[2] = n_bw_output;
+ TF_LITE_ENSURE_OK(
+ context, context->ResizeTensor(context, bw_output, bw_output_size));
+ }
// Check the shape of input state tensors.
// These tensor may be 1D or 2D. It's fine as long as the total size is
@@ -705,7 +713,7 @@ TfLiteStatus EvalFloat(
const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
- const TfLiteLSTMParams* params, bool forward_sequence,
+ const TfLiteLSTMParams* params, bool forward_sequence, int output_offset,
TfLiteTensor* scratch_buffer, TfLiteTensor* activation_state,
TfLiteTensor* cell_state, TfLiteTensor* output) {
const int max_time = input->dims->data[0];
@@ -771,12 +779,13 @@ TfLiteStatus EvalFloat(
// Loop through the sequence.
const int input_step = n_batch * n_input;
- const int output_step = n_batch * n_output;
+ const int output_step = n_batch * output->dims->data[2];
for (int t = 0; t < max_time; t++) {
// If this is the forward_sequence, step forward, otherwise step backwards.
const int t_rel = forward_sequence ? t : max_time - t - 1;
const float* input_ptr = input->data.f + t_rel * input_step;
- float* output_ptr_time = output->data.f + t_rel * output_step;
+ float* output_ptr_time =
+ output->data.f + t_rel * output_step + output_offset;
kernel_utils::LstmStepWithAuxInput(
input_ptr, input_to_input_weights_ptr, input_to_forget_weights->data.f,
@@ -816,7 +825,7 @@ TfLiteStatus EvalHybrid(
const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias,
const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias,
const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias,
- const TfLiteLSTMParams* params, bool forward_sequence,
+ const TfLiteLSTMParams* params, bool forward_sequence, int output_offset,
TfLiteTensor* scratch_buffer, TfLiteTensor* scaling_factors,
TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights,
TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,
@@ -972,12 +981,12 @@ TfLiteStatus EvalHybrid(
// Feed the sequence into the LSTM step-by-step.
const int input_step = n_batch * n_input;
- const int output_step = n_batch * n_output;
+ const int output_step = n_batch * output->dims->data[2];
for (int t = 0; t < max_time; t++) {
// If this is the forward_sequence, step forward, otherwise step backwards.
const int t_rel = forward_sequence ? t : max_time - t - 1;
const float* input_ptr = input->data.f + t_rel * input_step;
- float* output_ptr = output->data.f + t_rel * output_step;
+ float* output_ptr = output->data.f + t_rel * output_step + output_offset;
kernel_utils::LstmStepWithAuxInput(
input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale,
@@ -1011,7 +1020,8 @@ TfLiteStatus EvalHybrid(
// The LSTM Op engine.
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data);
+ const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceLSTMParams*>(
+ node->builtin_data);
// Input tensor.
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
@@ -1107,7 +1117,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetVariableInput(context, node, kBwInputActivationStateTensor);
TfLiteTensor* bw_cell_state =
GetVariableInput(context, node, kBwInputCellStateTensor);
- TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
+ TfLiteTensor* bw_output = params->merge_outputs
+ ? nullptr
+ : GetOutput(context, node, kBwOutputTensor);
// Temporary tensors.
TfLiteTensor* fw_scratch_buffer =
@@ -1135,6 +1147,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* bw_aux_input_to_output_weights =
GetOptionalInputTensor(context, node, kBwAuxInputToOutputWeightsTensor);
+ // Populate a TfLiteLSTMParams struct for the evaluation functions.
+ TfLiteLSTMParams lstm_params = {params->activation, params->cell_clip,
+ params->proj_clip, kTfLiteLSTMFullKernel};
+
+ const int bw_output_offset =
+ params->merge_outputs ? fw_recurrent_to_output_weights->dims->data[1] : 0;
+ const auto actual_bw_output = params->merge_outputs ? fw_output : bw_output;
+
switch (fw_input_to_output_weights->type) {
case kTfLiteFloat32: {
TfLiteStatus fw_pass_status = EvalFloat(
@@ -1147,9 +1167,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
fw_aux_input_to_forget_weights, fw_aux_input_to_cell_weights,
fw_aux_input_to_output_weights, fw_input_gate_bias,
fw_forget_gate_bias, fw_cell_bias, fw_output_gate_bias,
- fw_projection_weights, fw_projection_bias, params,
- /*forward_sequence=*/true, fw_scratch_buffer, fw_activation_state,
- fw_cell_state, fw_output);
+ fw_projection_weights, fw_projection_bias, &lstm_params,
+ /*forward_sequence=*/true, /*output_offset=*/0, fw_scratch_buffer,
+ fw_activation_state, fw_cell_state, fw_output);
TF_LITE_ENSURE_OK(context, fw_pass_status);
TfLiteStatus bw_pass_status = EvalFloat(
@@ -1162,9 +1182,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
bw_aux_input_to_forget_weights, bw_aux_input_to_cell_weights,
bw_aux_input_to_output_weights, bw_input_gate_bias,
bw_forget_gate_bias, bw_cell_bias, bw_output_gate_bias,
- bw_projection_weights, bw_projection_bias, params,
- /*forward_sequence=*/false, bw_scratch_buffer, bw_activation_state,
- bw_cell_state, bw_output);
+ bw_projection_weights, bw_projection_bias, &lstm_params,
+ /*forward_sequence=*/false, bw_output_offset, bw_scratch_buffer,
+ bw_activation_state, bw_cell_state, actual_bw_output);
TF_LITE_ENSURE_OK(context, bw_pass_status);
return kTfLiteOk;
}
@@ -1198,10 +1218,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
fw_aux_input_to_forget_weights, fw_aux_input_to_cell_weights,
fw_aux_input_to_output_weights, fw_input_gate_bias,
fw_forget_gate_bias, fw_cell_bias, fw_output_gate_bias,
- fw_projection_weights, fw_projection_bias, params,
- /*forward_sequence=*/true, fw_scratch_buffer, scaling_factors,
- prod_scaling_factors, recovered_cell_weights, input_quantized,
- aux_input_quantized, fw_activation_state_quantized,
+ fw_projection_weights, fw_projection_bias, &lstm_params,
+ /*forward_sequence=*/true, /*output_offset=*/0, fw_scratch_buffer,
+ scaling_factors, prod_scaling_factors, recovered_cell_weights,
+ input_quantized, aux_input_quantized, fw_activation_state_quantized,
fw_cell_state_quantized, fw_activation_state, fw_cell_state,
fw_output);
TF_LITE_ENSURE_OK(context, fw_pass_status);
@@ -1216,12 +1236,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
fw_aux_input_to_forget_weights, fw_aux_input_to_cell_weights,
fw_aux_input_to_output_weights, bw_input_gate_bias,
bw_forget_gate_bias, bw_cell_bias, bw_output_gate_bias,
- bw_projection_weights, bw_projection_bias, params,
- /*forward_sequence=*/false, bw_scratch_buffer, scaling_factors,
- prod_scaling_factors, recovered_cell_weights, input_quantized,
- aux_input_quantized, bw_activation_state_quantized,
+ bw_projection_weights, bw_projection_bias, &lstm_params,
+ /*forward_sequence=*/false, bw_output_offset, bw_scratch_buffer,
+ scaling_factors, prod_scaling_factors, recovered_cell_weights,
+ input_quantized, aux_input_quantized, bw_activation_state_quantized,
bw_cell_state_quantized, bw_activation_state, bw_cell_state,
- bw_output);
+ actual_bw_output);
TF_LITE_ENSURE_OK(context, bw_pass_status);
return kTfLiteOk;
}
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc
index 74ba8021c2..9cc04907e1 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc
@@ -35,8 +35,8 @@ class BidirectionalLSTMOpModel : public SingleOpModel {
BidirectionalLSTMOpModel(int n_batch, int n_input, int n_cell, int n_output,
int sequence_length, bool use_cifg,
bool use_peephole, bool use_projection_weights,
- bool use_projection_bias, float cell_clip,
- float proj_clip,
+ bool use_projection_bias, bool merge_outputs,
+ float cell_clip, float proj_clip,
const std::vector<std::vector<int>>& input_shapes)
: n_batch_(n_batch),
n_input_(n_input),
@@ -175,7 +175,9 @@ class BidirectionalLSTMOpModel : public SingleOpModel {
fw_output_ = AddOutput(TensorType_FLOAT32);
- bw_output_ = AddOutput(TensorType_FLOAT32);
+ if (!merge_outputs) {
+ bw_output_ = AddOutput(TensorType_FLOAT32);
+ }
aux_input_ = AddNullInput();
fw_aux_input_to_input_weights_ = AddNullInput();
@@ -188,9 +190,10 @@ class BidirectionalLSTMOpModel : public SingleOpModel {
bw_aux_input_to_output_weights_ = AddNullInput();
SetBuiltinOp(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM,
- BuiltinOptions_LSTMOptions,
- CreateLSTMOptions(builder_, ActivationFunctionType_TANH,
- cell_clip, proj_clip)
+ BuiltinOptions_BidirectionalSequenceLSTMOptions,
+ CreateBidirectionalSequenceLSTMOptions(
+ builder_, ActivationFunctionType_TANH, cell_clip,
+ proj_clip, merge_outputs)
.Union());
BuildInterpreter(input_shapes);
}
@@ -380,7 +383,8 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
BidirectionalLSTMOpModel lstm(
n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
/*use_peephole=*/false, /*use_projection_weights=*/false,
- /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ /*use_projection_bias=*/false, /*merge_outputs=*/false, /*cell_clip=*/0.0,
+ /*proj_clip=*/0.0,
{
{sequence_length, n_batch, n_input}, // input tensor
@@ -526,6 +530,162 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) {
ElementsAreArray(ArrayFloatNear(bw_expected)));
}
+// Same as the previous test, yet with a single merged output tensor.
+TEST(LSTMOpTest, BlackBoxTestMergedOutput) {
+ const int n_batch = 1;
+ const int n_input = 2;
+ // n_cell and n_output have the same size when there is no projection.
+ const int n_cell = 4;
+ const int n_output = 4;
+ const int sequence_length = 3;
+
+ BidirectionalLSTMOpModel lstm(
+ n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
+ /*use_peephole=*/false, /*use_projection_weights=*/false,
+ /*use_projection_bias=*/false, /*merge_outputs=*/true, /*cell_clip=*/0.0,
+ /*proj_clip=*/0.0,
+ {
+ {sequence_length, n_batch, n_input}, // input tensor
+
+ // Forward cell
+ {n_cell, n_input}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
+
+ {n_cell, n_output}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
+
+ {0}, // cell_to_input_weight tensor
+ {0}, // cell_to_forget_weight tensor
+ {0}, // cell_to_output_weight tensor
+
+ {n_cell}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
+
+ {0, 0}, // projection_weight tensor
+ {0}, // projection_bias tensor
+
+ // Backward cell
+ {n_cell, n_input}, // input_to_input_weight tensor
+ {n_cell, n_input}, // input_to_forget_weight tensor
+ {n_cell, n_input}, // input_to_cell_weight tensor
+ {n_cell, n_input}, // input_to_output_weight tensor
+
+ {n_cell, n_output}, // recurrent_to_input_weight tensor
+ {n_cell, n_output}, // recurrent_to_forget_weight tensor
+ {n_cell, n_output}, // recurrent_to_cell_weight tensor
+ {n_cell, n_output}, // recurrent_to_output_weight tensor
+
+ {0}, // cell_to_input_weight tensor
+ {0}, // cell_to_forget_weight tensor
+ {0}, // cell_to_output_weight tensor
+
+ {n_cell}, // input_gate_bias tensor
+ {n_cell}, // forget_gate_bias tensor
+ {n_cell}, // cell_bias tensor
+ {n_cell}, // output_gate_bias tensor
+
+ {0, 0}, // projection_weight tensor
+ {0}, // projection_bias tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
+
+ {n_batch, n_output}, // activation_state tensor
+ {n_batch, n_cell}, // cell_state tensor
+
+ {n_batch, sequence_length, 0}, // aux_input tensor
+ {n_cell, 0}, // aux_fw_input_to_input tensor
+ {n_cell, 0}, // aux_fw_input_to_forget tensor
+ {n_cell, 0}, // aux_fw_input_to_cell tensor
+ {n_cell, 0}, // aux_fw_input_to_output tensor
+ {n_cell, 0}, // aux_bw_input_to_input tensor
+ {n_cell, 0}, // aux_bw_input_to_forget tensor
+ {n_cell, 0}, // aux_bw_input_to_cell tensor
+ {n_cell, 0}, // aux_bw_input_to_output tensor
+ });
+
+ lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589,
+ -0.34550029, 0.04266912, -0.15680569,
+ -0.34856534, 0.43890524});
+
+ lstm.SetInputToCellWeights({-0.50013041, 0.1370284, 0.11810488, 0.2013163,
+ -0.20583314, 0.44344562, 0.22077113,
+ -0.29909778});
+
+ lstm.SetInputToForgetWeights({0.09701663, 0.20334584, -0.50592935,
+ -0.31343272, -0.40032279, 0.44781327,
+ 0.01387155, -0.35593212});
+
+ lstm.SetInputToOutputWeights({-0.25065863, -0.28290087, 0.04613829,
+ 0.40525138, 0.44272184, 0.03897077, -0.1556896,
+ 0.19487578});
+
+ lstm.SetInputGateBias({0., 0., 0., 0.});
+
+ lstm.SetCellBias({0., 0., 0., 0.});
+
+ lstm.SetForgetGateBias({1., 1., 1., 1.});
+
+ lstm.SetOutputGateBias({0., 0., 0., 0.});
+
+ lstm.SetRecurrentToInputWeights(
+ {-0.0063535, -0.2042388, 0.31454784, -0.35746509, 0.28902304, 0.08183324,
+ -0.16555229, 0.02286911, -0.13566875, 0.03034258, 0.48091322,
+ -0.12528998, 0.24077177, -0.51332325, -0.33502164, 0.10629296});
+
+ lstm.SetRecurrentToCellWeights(
+ {-0.3407414, 0.24443203, -0.2078532, 0.26320225, 0.05695659, -0.00123841,
+ -0.4744786, -0.35869038, -0.06418842, -0.13502428, -0.501764, 0.22830659,
+ -0.46367589, 0.26016325, -0.03894562, -0.16368064});
+
+ lstm.SetRecurrentToForgetWeights(
+ {-0.48684245, -0.06655136, 0.42224967, 0.2112639, 0.27654213, 0.20864892,
+ -0.07646349, 0.45877004, 0.00141793, -0.14609534, 0.36447752, 0.09196436,
+ 0.28053468, 0.01560611, -0.20127171, -0.01140004});
+
+ lstm.SetRecurrentToOutputWeights(
+ {0.43385774, -0.17194885, 0.2718237, 0.09215671, 0.24107647, -0.39835793,
+ 0.18212086, 0.01301402, 0.48572797, -0.50656658, 0.20047462, -0.20607421,
+ -0.51818722, -0.15390486, 0.0468148, 0.39922136});
+
+ // Input should have n_input * sequence_length many values.
+ static float lstm_input[] = {2., 3., 3., 4., 1., 1.};
+ static float lstm_fw_golden_output[] = {
+ -0.02973187, 0.1229473, 0.20885126, -0.15358765,
+ -0.03716109, 0.12507336, 0.41193449, -0.20860538,
+ -0.15053082, 0.09120187, 0.24278517, -0.12222792};
+ static float lstm_bw_golden_output[] = {
+ -0.0806187, 0.139077, 0.400476, -0.197842, -0.0332076, 0.123838,
+ 0.309777, -0.17621, -0.0490733, 0.0739237, 0.067706, -0.0208124};
+
+ float* batch0_start = lstm_input;
+ float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length();
+
+ lstm.SetInput(0, batch0_start, batch0_end);
+
+ lstm.Invoke();
+
+ std::vector<float> merged_expected;
+ for (int k = 0; k < lstm.sequence_length(); k++) {
+ merged_expected.insert(
+ merged_expected.end(),
+ lstm_fw_golden_output + k * lstm.num_fw_outputs(),
+ lstm_fw_golden_output + (k + 1) * lstm.num_fw_outputs());
+ merged_expected.insert(
+ merged_expected.end(),
+ lstm_bw_golden_output + k * lstm.num_bw_outputs(),
+ lstm_bw_golden_output + (k + 1) * lstm.num_bw_outputs());
+ }
+ EXPECT_THAT(lstm.GetFwOutput(),
+ ElementsAreArray(ArrayFloatNear(merged_expected)));
+}
+
TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClippingReverse) {
const int n_batch = 1;
const int n_input = 2;
@@ -537,7 +697,8 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClippingReverse) {
BidirectionalLSTMOpModel lstm(
n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
/*use_peephole=*/false, /*use_projection_weights=*/false,
- /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ /*use_projection_bias=*/false, /*merge_outputs=*/false, /*cell_clip=*/0.0,
+ /*proj_clip=*/0.0,
{
{sequence_length, n_batch, n_input}, // input tensor
@@ -696,7 +857,8 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) {
BidirectionalLSTMOpModel lstm(
n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/true,
/*use_peephole=*/true, /*use_projection_weights=*/false,
- /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ /*use_projection_bias=*/false, /*merge_outputs=*/false, /*cell_clip=*/0.0,
+ /*proj_clip=*/0.0,
{
{sequence_length, n_batch, n_input}, // input tensor
@@ -845,7 +1007,8 @@ TEST(LSTMOpTest,
BidirectionalLSTMOpModel lstm(
n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/true,
/*use_peephole=*/true, /*use_projection_weights=*/false,
- /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ /*use_projection_bias=*/false, /*merge_outputs=*/false, /*cell_clip=*/0.0,
+ /*proj_clip=*/0.0,
{
{sequence_length, n_batch, n_input}, // input tensor
@@ -994,7 +1157,8 @@ TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) {
BidirectionalLSTMOpModel lstm(
n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false,
/*use_peephole=*/true, /*use_projection_weights=*/true,
- /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0,
+ /*use_projection_bias=*/false, /*merge_outputs=*/false, /*cell_clip=*/0.0,
+ /*proj_clip=*/0.0,
{
{sequence_length, n_batch, n_input}, // input tensor
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
index 2f896c5289..9f62ac3f2c 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
@@ -47,7 +47,7 @@ constexpr int kFwAuxWeightsTensor = 10; // Optional.
constexpr int kBwAuxWeightsTensor = 11; // Optional.
// Output tensors.
constexpr int kFwOutputTensor = 0;
-constexpr int kBwOutputTensor = 1;
+constexpr int kBwOutputTensor = 1; // Only if merge_outputs is false.
// Temporary tensors.
enum TemporaryTensor {
@@ -70,9 +70,13 @@ void Free(TfLiteContext* context, void* buffer) {
}
TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
+ const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceRNNParams*>(
+ node->builtin_data);
+
// Check we have all the inputs and outputs we need.
TF_LITE_ENSURE_EQ(context, node->inputs->size, 12);
- TF_LITE_ENSURE_EQ(context, node->outputs->size, 2);
+ TF_LITE_ENSURE_EQ(context, node->outputs->size,
+ params->merge_outputs ? 1 : 2);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* fw_input_weights =
@@ -142,9 +146,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
bw_aux_input_weights->dims->data[1]);
}
- TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
- TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
-
const bool is_hybrid_op =
(fw_input_weights->type == kTfLiteUInt8 && input->type == kTfLiteFloat32);
@@ -233,18 +234,23 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
}
// Resize outputs.
+ TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
TfLiteIntArray* fw_output_size_array = TfLiteIntArrayCreate(3);
fw_output_size_array->data[0] = batch_size;
fw_output_size_array->data[1] = max_time;
- fw_output_size_array->data[2] = fw_num_units;
+ fw_output_size_array->data[2] =
+ params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units;
TF_LITE_ENSURE_OK(
context, context->ResizeTensor(context, fw_output, fw_output_size_array));
- TfLiteIntArray* bw_output_size_array = TfLiteIntArrayCreate(3);
- bw_output_size_array->data[0] = batch_size;
- bw_output_size_array->data[1] = max_time;
- bw_output_size_array->data[2] = bw_num_units;
- TF_LITE_ENSURE_OK(
- context, context->ResizeTensor(context, bw_output, bw_output_size_array));
+ if (!params->merge_outputs) {
+ TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
+ TfLiteIntArray* bw_output_size_array = TfLiteIntArrayCreate(3);
+ bw_output_size_array->data[0] = batch_size;
+ bw_output_size_array->data[1] = max_time;
+ bw_output_size_array->data[2] = bw_num_units;
+ TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_output,
+ bw_output_size_array));
+ }
return kTfLiteOk;
}
@@ -256,9 +262,9 @@ TfLiteStatus EvalFloat(
const TfLiteTensor* bw_recurrent_weights, const TfLiteTensor* bw_bias,
const TfLiteTensor* aux_input, const TfLiteTensor* fw_aux_input_weights,
const TfLiteTensor* bw_aux_input_weights,
- const TfLiteSequenceRNNParams* params, TfLiteTensor* fw_hidden_state,
- TfLiteTensor* fw_output, TfLiteTensor* bw_hidden_state,
- TfLiteTensor* bw_output) {
+ const TfLiteBidirectionalSequenceRNNParams* params,
+ TfLiteTensor* fw_hidden_state, TfLiteTensor* fw_output,
+ TfLiteTensor* bw_hidden_state, TfLiteTensor* bw_output) {
const int batch_size = input->dims->data[0];
const int max_time = input->dims->data[1];
const int input_size = input->dims->data[2];
@@ -281,10 +287,15 @@ TfLiteStatus EvalFloat(
? bw_aux_input_weights->data.f
: nullptr;
+ const int fw_output_step =
+ params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units;
+ const int bw_output_step =
+ params->merge_outputs ? fw_num_units + bw_num_units : bw_num_units;
for (int b = 0; b < batch_size; b++) {
// Forward cell.
float* fw_hidden_state_ptr_batch =
fw_hidden_state->data.f + b * fw_num_units;
+ float* fw_output_offset = fw_output->data.f + b * fw_output_step * max_time;
for (int s = 0; s < max_time; s++) {
const float* input_ptr_batch =
input->data.f + b * input_size * max_time + s * input_size;
@@ -292,8 +303,7 @@ TfLiteStatus EvalFloat(
(aux_input != nullptr)
? aux_input->data.f + b * input_size * max_time + s * input_size
: nullptr;
- float* output_ptr_batch =
- fw_output->data.f + b * fw_num_units * max_time + s * fw_num_units;
+ float* output_ptr_batch = fw_output_offset + s * fw_output_step;
kernel_utils::RnnBatchStep(
input_ptr_batch, fw_input_weights_ptr, aux_input_ptr_batch,
@@ -304,6 +314,10 @@ TfLiteStatus EvalFloat(
// Backward cell.
float* bw_hidden_state_ptr_batch =
bw_hidden_state->data.f + b * bw_num_units;
+ float* bw_output_offset =
+ params->merge_outputs
+ ? fw_output->data.f + b * bw_output_step * max_time + fw_num_units
+ : bw_output->data.f + b * bw_output_step * max_time;
for (int s = max_time - 1; s >= 0; s--) {
const float* input_ptr_batch =
input->data.f + b * input_size * max_time + s * input_size;
@@ -311,8 +325,7 @@ TfLiteStatus EvalFloat(
(aux_input != nullptr)
? aux_input->data.f + b * input_size * max_time + s * input_size
: nullptr;
- float* output_ptr_batch =
- bw_output->data.f + b * bw_num_units * max_time + s * bw_num_units;
+ float* output_ptr_batch = bw_output_offset + s * bw_output_step;
kernel_utils::RnnBatchStep(
input_ptr_batch, bw_input_weights_ptr, aux_input_ptr_batch,
@@ -331,11 +344,12 @@ TfLiteStatus EvalHybrid(
const TfLiteTensor* bw_recurrent_weights, const TfLiteTensor* bw_bias,
const TfLiteTensor* aux_input, const TfLiteTensor* aux_fw_input_weights,
const TfLiteTensor* aux_bw_input_weights,
- const TfLiteSequenceRNNParams* params, TfLiteTensor* scaling_factors,
- TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized,
- TfLiteTensor* fw_hidden_state_quantized, TfLiteTensor* fw_hidden_state,
- TfLiteTensor* fw_output, TfLiteTensor* bw_hidden_state_quantized,
- TfLiteTensor* bw_hidden_state, TfLiteTensor* bw_output) {
+ const TfLiteBidirectionalSequenceRNNParams* params,
+ TfLiteTensor* scaling_factors, TfLiteTensor* input_quantized,
+ TfLiteTensor* aux_input_quantized, TfLiteTensor* fw_hidden_state_quantized,
+ TfLiteTensor* fw_hidden_state, TfLiteTensor* fw_output,
+ TfLiteTensor* bw_hidden_state_quantized, TfLiteTensor* bw_hidden_state,
+ TfLiteTensor* bw_output) {
const int batch_size = input->dims->data[0];
const int max_time = input->dims->data[1];
const int input_size = input->dims->data[2];
@@ -384,10 +398,15 @@ TfLiteStatus EvalHybrid(
reinterpret_cast<int8_t*>(bw_hidden_state_quantized->data.uint8);
float* scaling_factors_ptr = scaling_factors->data.f;
+ const int fw_output_step =
+ params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units;
+ const int bw_output_step =
+ params->merge_outputs ? fw_num_units + bw_num_units : bw_num_units;
for (int b = 0; b < batch_size; b++) {
// Forward cell.
float* fw_hidden_state_ptr_batch =
fw_hidden_state->data.f + b * fw_num_units;
+ float* fw_output_offset = fw_output->data.f + b * fw_output_step * max_time;
for (int s = 0; s < max_time; s++) {
const float* input_ptr_batch =
input->data.f + b * input_size * max_time + s * input_size;
@@ -395,8 +414,7 @@ TfLiteStatus EvalHybrid(
(aux_input != nullptr)
? aux_input->data.f + b * input_size * max_time + s * input_size
: nullptr;
- float* output_ptr_batch =
- fw_output->data.f + b * fw_num_units * max_time + s * fw_num_units;
+ float* output_ptr_batch = fw_output_offset + s * fw_output_step;
kernel_utils::RnnBatchStep(
input_ptr_batch, fw_input_weights_ptr, fw_input_weights_scale,
@@ -411,6 +429,10 @@ TfLiteStatus EvalHybrid(
// Backward cell.
float* bw_hidden_state_ptr_batch =
bw_hidden_state->data.f + b * bw_num_units;
+ float* bw_output_offset =
+ params->merge_outputs
+ ? fw_output->data.f + b * bw_output_step * max_time
+ : bw_output->data.f + b * bw_output_step * max_time;
for (int s = max_time - 1; s >= 0; s--) {
const float* input_ptr_batch =
input->data.f + b * input_size * max_time + s * input_size;
@@ -418,8 +440,7 @@ TfLiteStatus EvalHybrid(
(aux_input != nullptr)
? aux_input->data.f + b * input_size * max_time + s * input_size
: nullptr;
- float* output_ptr_batch =
- bw_output->data.f + b * bw_num_units * max_time + s * bw_num_units;
+ float* output_ptr_batch = bw_output_offset + s * bw_output_step;
kernel_utils::RnnBatchStep(
input_ptr_batch, bw_input_weights_ptr, bw_input_weights_scale,
@@ -436,8 +457,8 @@ TfLiteStatus EvalHybrid(
}
TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
- const auto* params =
- reinterpret_cast<TfLiteSequenceRNNParams*>(node->builtin_data);
+ const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceRNNParams*>(
+ node->builtin_data);
const TfLiteTensor* input = GetInput(context, node, kInputTensor);
const TfLiteTensor* fw_input_weights =
@@ -465,7 +486,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
GetVariableInput(context, node, kBwHiddenStateTensor);
TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
- TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor);
+ TfLiteTensor* bw_output = params->merge_outputs
+ ? nullptr
+ : GetOutput(context, node, kBwOutputTensor);
switch (fw_input_weights->type) {
case kTfLiteFloat32:
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc
index 3e34ba6196..f555c472f5 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc
@@ -654,7 +654,7 @@ const std::initializer_list<float> recurrent_weights = {
class BidirectionalRNNOpModel : public SingleOpModel {
public:
BidirectionalRNNOpModel(int batches, int sequence_len, int fw_units,
- int bw_units, int input_size)
+ int bw_units, int input_size, bool merge_outputs)
: batches_(batches),
sequence_len_(sequence_len),
fw_units_(fw_units),
@@ -675,12 +675,15 @@ class BidirectionalRNNOpModel : public SingleOpModel {
aux_bw_weights_ = AddNullInput();
fw_output_ = AddOutput(TensorType_FLOAT32);
- bw_output_ = AddOutput(TensorType_FLOAT32);
+ if (!merge_outputs) {
+ bw_output_ = AddOutput(TensorType_FLOAT32);
+ }
SetBuiltinOp(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN,
- BuiltinOptions_SequenceRNNOptions,
- CreateSequenceRNNOptions(builder_, /*time_major=*/false,
- ActivationFunctionType_RELU)
+ BuiltinOptions_BidirectionalSequenceRNNOptions,
+ CreateBidirectionalSequenceRNNOptions(
+ builder_, /*time_major=*/false,
+ ActivationFunctionType_RELU, merge_outputs)
.Union());
BuildInterpreter({
{batches_, sequence_len_, input_size_}, // input
@@ -767,7 +770,7 @@ class BidirectionalRNNOpModel : public SingleOpModel {
TEST(BidirectionalRNNOpTest, BlackBoxTest) {
BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
/*fw_units=*/16, /*bw_units=*/16,
- /*input_size=*/8);
+ /*input_size=*/8, /*merge_outputs=*/false);
rnn.SetFwWeights(weights);
rnn.SetBwWeights(weights);
rnn.SetFwBias(biases);
@@ -800,12 +803,49 @@ TEST(BidirectionalRNNOpTest, BlackBoxTest) {
EXPECT_THAT(rnn.GetBwOutput(), ElementsAreArray(ArrayFloatNear(bw_expected)));
}
+// Same as the previous test, yet with merged outputs.
+TEST(BidirectionalRNNOpTest, BlackBoxTestMergeOutputs) {
+ BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
+ /*fw_units=*/16, /*bw_units=*/16,
+ /*input_size=*/8, /*merge_outputs=*/true);
+ rnn.SetFwWeights(weights);
+ rnn.SetBwWeights(weights);
+ rnn.SetFwBias(biases);
+ rnn.SetBwBias(biases);
+ rnn.SetFwRecurrentWeights(recurrent_weights);
+ rnn.SetBwRecurrentWeights(recurrent_weights);
+
+ const int input_sequence_size = rnn.input_size() * rnn.sequence_len();
+ float* batch_start = rnn_input;
+ float* batch_end = batch_start + input_sequence_size;
+ rnn.SetInput(0, batch_start, batch_end);
+ rnn.SetInput(input_sequence_size, batch_start, batch_end);
+
+ rnn.Invoke();
+
+ std::vector<float> merged_expected;
+ for (int bid = 0; bid < rnn.num_batches(); bid++) {
+ for (int step = 0; step < rnn.sequence_len(); step++) {
+ merged_expected.insert(
+ merged_expected.end(),
+ rnn_golden_fw_output + rnn.num_fw_units() * step,
+ rnn_golden_fw_output + rnn.num_fw_units() * (step + 1));
+ merged_expected.insert(
+ merged_expected.end(),
+ rnn_golden_bw_output + rnn.num_bw_units() * step,
+ rnn_golden_bw_output + rnn.num_bw_units() * (step + 1));
+ }
+ }
+ EXPECT_THAT(rnn.GetFwOutput(),
+ ElementsAreArray(ArrayFloatNear(merged_expected)));
+}
+
// Check that if the input sequence is reversed the outputs are the same just
// forward and backward are swapped (and reversed).
TEST(BidirectionalRNNOpTest, BlackBoxTestReverseInputs) {
BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
/*fw_units=*/16, /*bw_units=*/16,
- /*input_size=*/8);
+ /*input_size=*/8, /*merge_outputs=*/false);
rnn.SetFwWeights(weights);
rnn.SetBwWeights(weights);
rnn.SetFwBias(biases);
@@ -851,7 +891,7 @@ TEST(BidirectionalRNNOpTest, BlackBoxTestReverseInputs) {
TEST(BidirectionalRNNOpTest, EndToEndTest) {
BidirectionalRNNOpModel rnn(/*batches=*/1, /*sequence_len=*/4,
/*fw_units=*/16, /*bw_units=*/16,
- /*input_size=*/8);
+ /*input_size=*/8, /*merge_outputs=*/false);
const int output_size = 4;
float dnn_weights[] = {
-0.5782342, -0.052212059, 0.73036242, -0.81216097, -0.80088139,
diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs
index 3da3188c3a..ff8430827c 100644
--- a/tensorflow/contrib/lite/schema/schema.fbs
+++ b/tensorflow/contrib/lite/schema/schema.fbs
@@ -248,6 +248,8 @@ union BuiltinOptions {
SquareOptions,
ZerosLikeOptions,
FillOptions,
+ BidirectionalSequenceLSTMOptions,
+ BidirectionalSequenceRNNOptions,
}
enum Padding : byte { SAME, VALID }
@@ -327,6 +329,7 @@ table SequenceRNNOptions {
table BidirectionalSequenceRNNOptions {
time_major:bool;
fused_activation_function:ActivationFunctionType;
+ merge_outputs: bool;
}
enum FullyConnectedOptionsWeightsFormat: byte {
@@ -391,6 +394,15 @@ table LSTMOptions {
kernel_type: LSTMKernelType = FULL;
}
+table BidirectionalSequenceLSTMOptions {
+ fused_activation_function:ActivationFunctionType;
+ cell_clip: float; // Optional, 0.0 means no clipping
+ proj_clip: float; // Optional, 0.0 means no clipping
+
+ // If true, store the outputs of both directions into the first output.
+ merge_outputs: bool;
+}
+
table ResizeBilinearOptions {
new_height: int (deprecated);
new_width: int (deprecated);
diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h
index 23ac8484de..f3cb113c9c 100755
--- a/tensorflow/contrib/lite/schema/schema_generated.h
+++ b/tensorflow/contrib/lite/schema/schema_generated.h
@@ -79,6 +79,9 @@ struct LocalResponseNormalizationOptionsT;
struct LSTMOptions;
struct LSTMOptionsT;
+struct BidirectionalSequenceLSTMOptions;
+struct BidirectionalSequenceLSTMOptionsT;
+
struct ResizeBilinearOptions;
struct ResizeBilinearOptionsT;
@@ -676,11 +679,13 @@ enum BuiltinOptions {
BuiltinOptions_SquareOptions = 66,
BuiltinOptions_ZerosLikeOptions = 67,
BuiltinOptions_FillOptions = 68,
+ BuiltinOptions_BidirectionalSequenceLSTMOptions = 69,
+ BuiltinOptions_BidirectionalSequenceRNNOptions = 70,
BuiltinOptions_MIN = BuiltinOptions_NONE,
- BuiltinOptions_MAX = BuiltinOptions_FillOptions
+ BuiltinOptions_MAX = BuiltinOptions_BidirectionalSequenceRNNOptions
};
-inline const BuiltinOptions (&EnumValuesBuiltinOptions())[69] {
+inline const BuiltinOptions (&EnumValuesBuiltinOptions())[71] {
static const BuiltinOptions values[] = {
BuiltinOptions_NONE,
BuiltinOptions_Conv2DOptions,
@@ -750,7 +755,9 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[69] {
BuiltinOptions_FloorDivOptions,
BuiltinOptions_SquareOptions,
BuiltinOptions_ZerosLikeOptions,
- BuiltinOptions_FillOptions
+ BuiltinOptions_FillOptions,
+ BuiltinOptions_BidirectionalSequenceLSTMOptions,
+ BuiltinOptions_BidirectionalSequenceRNNOptions
};
return values;
}
@@ -826,6 +833,8 @@ inline const char * const *EnumNamesBuiltinOptions() {
"SquareOptions",
"ZerosLikeOptions",
"FillOptions",
+ "BidirectionalSequenceLSTMOptions",
+ "BidirectionalSequenceRNNOptions",
nullptr
};
return names;
@@ -1112,6 +1121,14 @@ template<> struct BuiltinOptionsTraits<FillOptions> {
static const BuiltinOptions enum_value = BuiltinOptions_FillOptions;
};
+template<> struct BuiltinOptionsTraits<BidirectionalSequenceLSTMOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_BidirectionalSequenceLSTMOptions;
+};
+
+template<> struct BuiltinOptionsTraits<BidirectionalSequenceRNNOptions> {
+ static const BuiltinOptions enum_value = BuiltinOptions_BidirectionalSequenceRNNOptions;
+};
+
struct BuiltinOptionsUnion {
BuiltinOptions type;
void *value;
@@ -1687,6 +1704,22 @@ struct BuiltinOptionsUnion {
return type == BuiltinOptions_FillOptions ?
reinterpret_cast<const FillOptionsT *>(value) : nullptr;
}
+ BidirectionalSequenceLSTMOptionsT *AsBidirectionalSequenceLSTMOptions() {
+ return type == BuiltinOptions_BidirectionalSequenceLSTMOptions ?
+ reinterpret_cast<BidirectionalSequenceLSTMOptionsT *>(value) : nullptr;
+ }
+ const BidirectionalSequenceLSTMOptionsT *AsBidirectionalSequenceLSTMOptions() const {
+ return type == BuiltinOptions_BidirectionalSequenceLSTMOptions ?
+ reinterpret_cast<const BidirectionalSequenceLSTMOptionsT *>(value) : nullptr;
+ }
+ BidirectionalSequenceRNNOptionsT *AsBidirectionalSequenceRNNOptions() {
+ return type == BuiltinOptions_BidirectionalSequenceRNNOptions ?
+ reinterpret_cast<BidirectionalSequenceRNNOptionsT *>(value) : nullptr;
+ }
+ const BidirectionalSequenceRNNOptionsT *AsBidirectionalSequenceRNNOptions() const {
+ return type == BuiltinOptions_BidirectionalSequenceRNNOptions ?
+ reinterpret_cast<const BidirectionalSequenceRNNOptionsT *>(value) : nullptr;
+ }
};
bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type);
@@ -2834,9 +2867,11 @@ struct BidirectionalSequenceRNNOptionsT : public flatbuffers::NativeTable {
typedef BidirectionalSequenceRNNOptions TableType;
bool time_major;
ActivationFunctionType fused_activation_function;
+ bool merge_outputs;
BidirectionalSequenceRNNOptionsT()
: time_major(false),
- fused_activation_function(ActivationFunctionType_NONE) {
+ fused_activation_function(ActivationFunctionType_NONE),
+ merge_outputs(false) {
}
};
@@ -2844,7 +2879,8 @@ struct BidirectionalSequenceRNNOptions FLATBUFFERS_FINAL_CLASS : private flatbuf
typedef BidirectionalSequenceRNNOptionsT NativeTableType;
enum {
VT_TIME_MAJOR = 4,
- VT_FUSED_ACTIVATION_FUNCTION = 6
+ VT_FUSED_ACTIVATION_FUNCTION = 6,
+ VT_MERGE_OUTPUTS = 8
};
bool time_major() const {
return GetField<uint8_t>(VT_TIME_MAJOR, 0) != 0;
@@ -2852,10 +2888,14 @@ struct BidirectionalSequenceRNNOptions FLATBUFFERS_FINAL_CLASS : private flatbuf
ActivationFunctionType fused_activation_function() const {
return static_cast<ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
}
+ bool merge_outputs() const {
+ return GetField<uint8_t>(VT_MERGE_OUTPUTS, 0) != 0;
+ }
bool Verify(flatbuffers::Verifier &verifier) const {
return VerifyTableStart(verifier) &&
VerifyField<uint8_t>(verifier, VT_TIME_MAJOR) &&
VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
+ VerifyField<uint8_t>(verifier, VT_MERGE_OUTPUTS) &&
verifier.EndTable();
}
BidirectionalSequenceRNNOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
@@ -2872,6 +2912,9 @@ struct BidirectionalSequenceRNNOptionsBuilder {
void add_fused_activation_function(ActivationFunctionType fused_activation_function) {
fbb_.AddElement<int8_t>(BidirectionalSequenceRNNOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast<int8_t>(fused_activation_function), 0);
}
+ void add_merge_outputs(bool merge_outputs) {
+ fbb_.AddElement<uint8_t>(BidirectionalSequenceRNNOptions::VT_MERGE_OUTPUTS, static_cast<uint8_t>(merge_outputs), 0);
+ }
explicit BidirectionalSequenceRNNOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
: fbb_(_fbb) {
start_ = fbb_.StartTable();
@@ -2887,8 +2930,10 @@ struct BidirectionalSequenceRNNOptionsBuilder {
inline flatbuffers::Offset<BidirectionalSequenceRNNOptions> CreateBidirectionalSequenceRNNOptions(
flatbuffers::FlatBufferBuilder &_fbb,
bool time_major = false,
- ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) {
+ ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE,
+ bool merge_outputs = false) {
BidirectionalSequenceRNNOptionsBuilder builder_(_fbb);
+ builder_.add_merge_outputs(merge_outputs);
builder_.add_fused_activation_function(fused_activation_function);
builder_.add_time_major(time_major);
return builder_.Finish();
@@ -3424,6 +3469,96 @@ inline flatbuffers::Offset<LSTMOptions> CreateLSTMOptions(
flatbuffers::Offset<LSTMOptions> CreateLSTMOptions(flatbuffers::FlatBufferBuilder &_fbb, const LSTMOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+struct BidirectionalSequenceLSTMOptionsT : public flatbuffers::NativeTable {
+ typedef BidirectionalSequenceLSTMOptions TableType;
+ ActivationFunctionType fused_activation_function;
+ float cell_clip;
+ float proj_clip;
+ bool merge_outputs;
+ BidirectionalSequenceLSTMOptionsT()
+ : fused_activation_function(ActivationFunctionType_NONE),
+ cell_clip(0.0f),
+ proj_clip(0.0f),
+ merge_outputs(false) {
+ }
+};
+
+struct BidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
+ typedef BidirectionalSequenceLSTMOptionsT NativeTableType;
+ enum {
+ VT_FUSED_ACTIVATION_FUNCTION = 4,
+ VT_CELL_CLIP = 6,
+ VT_PROJ_CLIP = 8,
+ VT_MERGE_OUTPUTS = 10
+ };
+ ActivationFunctionType fused_activation_function() const {
+ return static_cast<ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0));
+ }
+ float cell_clip() const {
+ return GetField<float>(VT_CELL_CLIP, 0.0f);
+ }
+ float proj_clip() const {
+ return GetField<float>(VT_PROJ_CLIP, 0.0f);
+ }
+ bool merge_outputs() const {
+ return GetField<uint8_t>(VT_MERGE_OUTPUTS, 0) != 0;
+ }
+ bool Verify(flatbuffers::Verifier &verifier) const {
+ return VerifyTableStart(verifier) &&
+ VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) &&
+ VerifyField<float>(verifier, VT_CELL_CLIP) &&
+ VerifyField<float>(verifier, VT_PROJ_CLIP) &&
+ VerifyField<uint8_t>(verifier, VT_MERGE_OUTPUTS) &&
+ verifier.EndTable();
+ }
+ BidirectionalSequenceLSTMOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ void UnPackTo(BidirectionalSequenceLSTMOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const;
+ static flatbuffers::Offset<BidirectionalSequenceLSTMOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const BidirectionalSequenceLSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+};
+
+struct BidirectionalSequenceLSTMOptionsBuilder {
+ flatbuffers::FlatBufferBuilder &fbb_;
+ flatbuffers::uoffset_t start_;
+ void add_fused_activation_function(ActivationFunctionType fused_activation_function) {
+ fbb_.AddElement<int8_t>(BidirectionalSequenceLSTMOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast<int8_t>(fused_activation_function), 0);
+ }
+ void add_cell_clip(float cell_clip) {
+ fbb_.AddElement<float>(BidirectionalSequenceLSTMOptions::VT_CELL_CLIP, cell_clip, 0.0f);
+ }
+ void add_proj_clip(float proj_clip) {
+ fbb_.AddElement<float>(BidirectionalSequenceLSTMOptions::VT_PROJ_CLIP, proj_clip, 0.0f);
+ }
+ void add_merge_outputs(bool merge_outputs) {
+ fbb_.AddElement<uint8_t>(BidirectionalSequenceLSTMOptions::VT_MERGE_OUTPUTS, static_cast<uint8_t>(merge_outputs), 0);
+ }
+ explicit BidirectionalSequenceLSTMOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb)
+ : fbb_(_fbb) {
+ start_ = fbb_.StartTable();
+ }
+ BidirectionalSequenceLSTMOptionsBuilder &operator=(const BidirectionalSequenceLSTMOptionsBuilder &);
+ flatbuffers::Offset<BidirectionalSequenceLSTMOptions> Finish() {
+ const auto end = fbb_.EndTable(start_);
+ auto o = flatbuffers::Offset<BidirectionalSequenceLSTMOptions>(end);
+ return o;
+ }
+};
+
+inline flatbuffers::Offset<BidirectionalSequenceLSTMOptions> CreateBidirectionalSequenceLSTMOptions(
+ flatbuffers::FlatBufferBuilder &_fbb,
+ ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE,
+ float cell_clip = 0.0f,
+ float proj_clip = 0.0f,
+ bool merge_outputs = false) {
+ BidirectionalSequenceLSTMOptionsBuilder builder_(_fbb);
+ builder_.add_proj_clip(proj_clip);
+ builder_.add_cell_clip(cell_clip);
+ builder_.add_merge_outputs(merge_outputs);
+ builder_.add_fused_activation_function(fused_activation_function);
+ return builder_.Finish();
+}
+
+flatbuffers::Offset<BidirectionalSequenceLSTMOptions> CreateBidirectionalSequenceLSTMOptions(flatbuffers::FlatBufferBuilder &_fbb, const BidirectionalSequenceLSTMOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr);
+
struct ResizeBilinearOptionsT : public flatbuffers::NativeTable {
typedef ResizeBilinearOptions TableType;
bool align_corners;
@@ -6347,6 +6482,12 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table {
const FillOptions *builtin_options_as_FillOptions() const {
return builtin_options_type() == BuiltinOptions_FillOptions ? static_cast<const FillOptions *>(builtin_options()) : nullptr;
}
+ const BidirectionalSequenceLSTMOptions *builtin_options_as_BidirectionalSequenceLSTMOptions() const {
+ return builtin_options_type() == BuiltinOptions_BidirectionalSequenceLSTMOptions ? static_cast<const BidirectionalSequenceLSTMOptions *>(builtin_options()) : nullptr;
+ }
+ const BidirectionalSequenceRNNOptions *builtin_options_as_BidirectionalSequenceRNNOptions() const {
+ return builtin_options_type() == BuiltinOptions_BidirectionalSequenceRNNOptions ? static_cast<const BidirectionalSequenceRNNOptions *>(builtin_options()) : nullptr;
+ }
const flatbuffers::Vector<uint8_t> *custom_options() const {
return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS);
}
@@ -6650,6 +6791,14 @@ template<> inline const FillOptions *Operator::builtin_options_as<FillOptions>()
return builtin_options_as_FillOptions();
}
+template<> inline const BidirectionalSequenceLSTMOptions *Operator::builtin_options_as<BidirectionalSequenceLSTMOptions>() const {
+ return builtin_options_as_BidirectionalSequenceLSTMOptions();
+}
+
+template<> inline const BidirectionalSequenceRNNOptions *Operator::builtin_options_as<BidirectionalSequenceRNNOptions>() const {
+ return builtin_options_as_BidirectionalSequenceRNNOptions();
+}
+
struct OperatorBuilder {
flatbuffers::FlatBufferBuilder &fbb_;
flatbuffers::uoffset_t start_;
@@ -7407,6 +7556,7 @@ inline void BidirectionalSequenceRNNOptions::UnPackTo(BidirectionalSequenceRNNOp
(void)_resolver;
{ auto _e = time_major(); _o->time_major = _e; };
{ auto _e = fused_activation_function(); _o->fused_activation_function = _e; };
+ { auto _e = merge_outputs(); _o->merge_outputs = _e; };
}
inline flatbuffers::Offset<BidirectionalSequenceRNNOptions> BidirectionalSequenceRNNOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const BidirectionalSequenceRNNOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
@@ -7419,10 +7569,12 @@ inline flatbuffers::Offset<BidirectionalSequenceRNNOptions> CreateBidirectionalS
struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const BidirectionalSequenceRNNOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
auto _time_major = _o->time_major;
auto _fused_activation_function = _o->fused_activation_function;
+ auto _merge_outputs = _o->merge_outputs;
return tflite::CreateBidirectionalSequenceRNNOptions(
_fbb,
_time_major,
- _fused_activation_function);
+ _fused_activation_function,
+ _merge_outputs);
}
inline FullyConnectedOptionsT *FullyConnectedOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
@@ -7657,6 +7809,41 @@ inline flatbuffers::Offset<LSTMOptions> CreateLSTMOptions(flatbuffers::FlatBuffe
_kernel_type);
}
+inline BidirectionalSequenceLSTMOptionsT *BidirectionalSequenceLSTMOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
+ auto _o = new BidirectionalSequenceLSTMOptionsT();
+ UnPackTo(_o, _resolver);
+ return _o;
+}
+
+inline void BidirectionalSequenceLSTMOptions::UnPackTo(BidirectionalSequenceLSTMOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const {
+ (void)_o;
+ (void)_resolver;
+ { auto _e = fused_activation_function(); _o->fused_activation_function = _e; };
+ { auto _e = cell_clip(); _o->cell_clip = _e; };
+ { auto _e = proj_clip(); _o->proj_clip = _e; };
+ { auto _e = merge_outputs(); _o->merge_outputs = _e; };
+}
+
+inline flatbuffers::Offset<BidirectionalSequenceLSTMOptions> BidirectionalSequenceLSTMOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const BidirectionalSequenceLSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) {
+ return CreateBidirectionalSequenceLSTMOptions(_fbb, _o, _rehasher);
+}
+
+inline flatbuffers::Offset<BidirectionalSequenceLSTMOptions> CreateBidirectionalSequenceLSTMOptions(flatbuffers::FlatBufferBuilder &_fbb, const BidirectionalSequenceLSTMOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) {
+ (void)_rehasher;
+ (void)_o;
+ struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const BidirectionalSequenceLSTMOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va;
+ auto _fused_activation_function = _o->fused_activation_function;
+ auto _cell_clip = _o->cell_clip;
+ auto _proj_clip = _o->proj_clip;
+ auto _merge_outputs = _o->merge_outputs;
+ return tflite::CreateBidirectionalSequenceLSTMOptions(
+ _fbb,
+ _fused_activation_function,
+ _cell_clip,
+ _proj_clip,
+ _merge_outputs);
+}
+
inline ResizeBilinearOptionsT *ResizeBilinearOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const {
auto _o = new ResizeBilinearOptionsT();
UnPackTo(_o, _resolver);
@@ -9425,6 +9612,14 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob
auto ptr = reinterpret_cast<const FillOptions *>(obj);
return verifier.VerifyTable(ptr);
}
+ case BuiltinOptions_BidirectionalSequenceLSTMOptions: {
+ auto ptr = reinterpret_cast<const BidirectionalSequenceLSTMOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
+ case BuiltinOptions_BidirectionalSequenceRNNOptions: {
+ auto ptr = reinterpret_cast<const BidirectionalSequenceRNNOptions *>(obj);
+ return verifier.VerifyTable(ptr);
+ }
default: return false;
}
}
@@ -9715,6 +9910,14 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c
auto ptr = reinterpret_cast<const FillOptions *>(obj);
return ptr->UnPack(resolver);
}
+ case BuiltinOptions_BidirectionalSequenceLSTMOptions: {
+ auto ptr = reinterpret_cast<const BidirectionalSequenceLSTMOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
+ case BuiltinOptions_BidirectionalSequenceRNNOptions: {
+ auto ptr = reinterpret_cast<const BidirectionalSequenceRNNOptions *>(obj);
+ return ptr->UnPack(resolver);
+ }
default: return nullptr;
}
}
@@ -9993,6 +10196,14 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff
auto ptr = reinterpret_cast<const FillOptionsT *>(value);
return CreateFillOptions(_fbb, ptr, _rehasher).Union();
}
+ case BuiltinOptions_BidirectionalSequenceLSTMOptions: {
+ auto ptr = reinterpret_cast<const BidirectionalSequenceLSTMOptionsT *>(value);
+ return CreateBidirectionalSequenceLSTMOptions(_fbb, ptr, _rehasher).Union();
+ }
+ case BuiltinOptions_BidirectionalSequenceRNNOptions: {
+ auto ptr = reinterpret_cast<const BidirectionalSequenceRNNOptionsT *>(value);
+ return CreateBidirectionalSequenceRNNOptions(_fbb, ptr, _rehasher).Union();
+ }
default: return 0;
}
}
@@ -10271,6 +10482,14 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL
value = new FillOptionsT(*reinterpret_cast<FillOptionsT *>(u.value));
break;
}
+ case BuiltinOptions_BidirectionalSequenceLSTMOptions: {
+ value = new BidirectionalSequenceLSTMOptionsT(*reinterpret_cast<BidirectionalSequenceLSTMOptionsT *>(u.value));
+ break;
+ }
+ case BuiltinOptions_BidirectionalSequenceRNNOptions: {
+ value = new BidirectionalSequenceRNNOptionsT(*reinterpret_cast<BidirectionalSequenceRNNOptionsT *>(u.value));
+ break;
+ }
default:
break;
}
@@ -10618,6 +10837,16 @@ inline void BuiltinOptionsUnion::Reset() {
delete ptr;
break;
}
+ case BuiltinOptions_BidirectionalSequenceLSTMOptions: {
+ auto ptr = reinterpret_cast<BidirectionalSequenceLSTMOptionsT *>(value);
+ delete ptr;
+ break;
+ }
+ case BuiltinOptions_BidirectionalSequenceRNNOptions: {
+ auto ptr = reinterpret_cast<BidirectionalSequenceRNNOptionsT *>(value);
+ delete ptr;
+ break;
+ }
default: break;
}
value = nullptr;