diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-10-03 13:25:22 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-03 13:32:42 -0700 |
commit | c2c8cfe22492cf7fab804d32283b623632270035 (patch) | |
tree | 6003bf547117f97cd65ed598c4cec39cba7d5510 /tensorflow/contrib | |
parent | 7566f3d5ad690c71c36e78611b1ae5913ec3e845 (diff) |
Add the option of merging bidirectional RNN and LSTM outputs into a single output tensor.
This is useful if the output of both directions will be passed to the next layer as a single output, as it avoids adding a concatenation op, which can be expensive on mobile devices where memory movement is relatively expensive.
PiperOrigin-RevId: 215616140
Diffstat (limited to 'tensorflow/contrib')
9 files changed, 640 insertions, 110 deletions
diff --git a/tensorflow/contrib/lite/c/builtin_op_data.h b/tensorflow/contrib/lite/c/builtin_op_data.h index be9d551ee4..44daf7adaa 100644 --- a/tensorflow/contrib/lite/c/builtin_op_data.h +++ b/tensorflow/contrib/lite/c/builtin_op_data.h @@ -99,6 +99,12 @@ typedef struct { TfLiteFusedActivation activation; } TfLiteSequenceRNNParams; +typedef struct { + bool time_major; + TfLiteFusedActivation activation; + bool merge_outputs; +} TfLiteBidirectionalSequenceRNNParams; + typedef enum { kTfLiteFullyConnectedWeightsFormatDefault = 0, kTfLiteFullyConnectedWeightsFormatShuffled4x16Int8 = 1, @@ -181,6 +187,16 @@ typedef struct { } TfLiteLSTMParams; typedef struct { + // Parameters for the LSTM kernel. + TfLiteFusedActivation activation; + float cell_clip; + float proj_clip; + + // If true, store the outputs of both directions in the first output. + bool merge_outputs; +} TfLiteBidirectionalSequenceLSTMParams; + +typedef struct { bool align_corners; } TfLiteResizeBilinearParams; diff --git a/tensorflow/contrib/lite/c/builtin_op_data_test.cc b/tensorflow/contrib/lite/c/builtin_op_data_test.cc index 4d0ba75e68..ba458b4252 100644 --- a/tensorflow/contrib/lite/c/builtin_op_data_test.cc +++ b/tensorflow/contrib/lite/c/builtin_op_data_test.cc @@ -73,6 +73,8 @@ TEST(IntArray, CanCompileStructs) { TfLiteFakeQuantParams fake_quant_params; TfLitePackParams pack_params; TfLiteOneHotParams one_hot_params; + TfLiteBidirectionalSequenceRNNParams bidi_sequence_rnn_params; + TfLiteBidirectionalSequenceLSTMParams bidi_sequence_lstm_params; } } // namespace tflite diff --git a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc index e6900e0950..eac7db9a88 100644 --- a/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc +++ b/tensorflow/contrib/lite/core/api/flatbuffer_conversions.cc @@ -224,10 +224,8 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast<void*>(params); break; } - case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN: case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: { - TfLiteSequenceRNNParams* params = - allocator->AllocatePOD<TfLiteSequenceRNNParams>(); + auto params = allocator->AllocatePOD<TfLiteSequenceRNNParams>(); if (auto* sequence_rnn_params = op->builtin_options_as_SequenceRNNOptions()) { params->activation = @@ -237,6 +235,19 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast<void*>(params); break; } + case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN: { + auto params = + allocator->AllocatePOD<TfLiteBidirectionalSequenceRNNParams>(); + if (auto* bidi_sequence_rnn_params = + op->builtin_options_as_BidirectionalSequenceRNNOptions()) { + params->activation = parse_activation( + bidi_sequence_rnn_params->fused_activation_function()); + params->time_major = bidi_sequence_rnn_params->time_major(); + params->merge_outputs = bidi_sequence_rnn_params->merge_outputs(); + } + *builtin_data = reinterpret_cast<void*>(params); + break; + } case BuiltinOperator_RNN: { TfLiteRNNParams* params = allocator->AllocatePOD<TfLiteRNNParams>(); if (auto* rnn_params = op->builtin_options_as_RNNOptions()) { @@ -360,10 +371,9 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast<void*>(params); break; } - case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM: case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: case BuiltinOperator_LSTM: { - TfLiteLSTMParams* params = allocator->AllocatePOD<TfLiteLSTMParams>(); + auto params = allocator->AllocatePOD<TfLiteLSTMParams>(); if (auto* lstm_params = op->builtin_options_as_LSTMOptions()) { params->activation = parse_activation(lstm_params->fused_activation_function()); @@ -381,6 +391,20 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, *builtin_data = reinterpret_cast<void*>(params); break; } + case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM: { + auto params = + allocator->AllocatePOD<TfLiteBidirectionalSequenceLSTMParams>(); + if (auto* bidi_lstm_params = + op->builtin_options_as_BidirectionalSequenceLSTMOptions()) { + params->activation = + parse_activation(bidi_lstm_params->fused_activation_function()); + params->cell_clip = bidi_lstm_params->cell_clip(); + params->proj_clip = bidi_lstm_params->proj_clip(); + params->merge_outputs = bidi_lstm_params->merge_outputs(); + } + *builtin_data = reinterpret_cast<void*>(params); + break; + } case BuiltinOperator_RESIZE_BILINEAR: { auto* params = allocator->AllocatePOD<TfLiteResizeBilinearParams>(); if (auto* schema_params = diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc index 66b947771c..0532528f52 100644 --- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc +++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm.cc @@ -119,7 +119,7 @@ constexpr int kBwAuxInputToOutputWeightsTensor = 47; // Optional // Output tensors. constexpr int kFwOutputTensor = 0; -constexpr int kBwOutputTensor = 1; +constexpr int kBwOutputTensor = 1; // Ignored if merge_outputs is set. // Temporary tensors. enum TemporaryTensor { @@ -162,7 +162,8 @@ TfLiteStatus CheckLstmTensorDimensions( int input_gate_bias_tensor, int forget_gate_bias_tensor, int cell_gate_bias_tensor, int output_gate_bias_tensor, int projection_weights_tensor, int projection_bias_tensor) { - const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data); + const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceLSTMParams*>( + node->builtin_data); // Making sure clipping parameters have valid values. // == 0 means no clipping @@ -347,10 +348,13 @@ TfLiteStatus CheckInputTensorDimensions(TfLiteContext* context, // tensors. Also check that the size of the input tensors match each other. TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data); + const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceLSTMParams*>( + node->builtin_data); // Check we have all the inputs and outputs we need. TF_LITE_ENSURE_EQ(context, node->inputs->size, 48); - TF_LITE_ENSURE_EQ(context, node->outputs->size, 2); + TF_LITE_ENSURE_EQ(context, node->outputs->size, + params->merge_outputs ? 1 : 2); // Inferring batch size, number of outputs and sequence length and // number of cells from the input tensors. @@ -368,6 +372,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_EQ(context, fw_input_to_output_weights->dims->data[1], n_input); + const TfLiteTensor* bw_input_to_output_weights = + GetInput(context, node, kBwInputToOutputWeightsTensor); + const int n_bw_cell = bw_input_to_output_weights->dims->data[0]; + TF_LITE_ENSURE_EQ(context, bw_input_to_output_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, bw_input_to_output_weights->dims->data[1], + n_input); + const TfLiteTensor* fw_recurrent_to_output_weights = GetInput(context, node, kFwRecurrentToOutputWeightsTensor); TF_LITE_ENSURE_EQ(context, fw_recurrent_to_output_weights->dims->size, 2); @@ -375,6 +386,13 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { n_fw_cell); const int n_fw_output = fw_recurrent_to_output_weights->dims->data[1]; + const TfLiteTensor* bw_recurrent_to_output_weights = + GetInput(context, node, kBwRecurrentToOutputWeightsTensor); + TF_LITE_ENSURE_EQ(context, bw_recurrent_to_output_weights->dims->size, 2); + TF_LITE_ENSURE_EQ(context, bw_recurrent_to_output_weights->dims->data[0], + n_bw_cell); + const int n_bw_output = bw_recurrent_to_output_weights->dims->data[1]; + // Check that input tensor dimensions matches with each other. TF_LITE_ENSURE_OK( context, CheckInputTensorDimensions(context, node, n_input, n_fw_output, @@ -440,7 +458,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TfLiteIntArray* fw_output_size = TfLiteIntArrayCreate(3); fw_output_size->data[0] = max_time; fw_output_size->data[1] = n_batch; - fw_output_size->data[2] = n_fw_output; + fw_output_size->data[2] = + params->merge_outputs ? n_bw_output + n_fw_output : n_fw_output; TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_output, fw_output_size)); @@ -479,39 +498,28 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, fw_scratch_buffer, fw_scratch_buffer_size)); // Same for the backward cell. - const TfLiteTensor* bw_input_to_output_weights = - GetInput(context, node, kBwInputToOutputWeightsTensor); - const int n_bw_cell = bw_input_to_output_weights->dims->data[0]; - TF_LITE_ENSURE_EQ(context, bw_input_to_output_weights->dims->size, 2); - TF_LITE_ENSURE_EQ(context, bw_input_to_output_weights->dims->data[1], - n_input); - - const TfLiteTensor* bw_recurrent_to_output_weights = - GetInput(context, node, kBwRecurrentToOutputWeightsTensor); - TF_LITE_ENSURE_EQ(context, bw_recurrent_to_output_weights->dims->size, 2); - TF_LITE_ENSURE_EQ(context, bw_recurrent_to_output_weights->dims->data[0], - n_bw_cell); - const int n_bw_output = bw_recurrent_to_output_weights->dims->data[1]; // Check that input tensor dimensions matches with each other. TF_LITE_ENSURE_OK( context, CheckInputTensorDimensions(context, node, n_input, n_bw_output, n_bw_cell)); - // Get the pointer to output, activation_state and cell_state buffer tensors. - TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor); + // Get the pointer to activation_state and cell_state buffer tensors. TfLiteTensor* bw_activation_state = GetVariableInput(context, node, kBwInputActivationStateTensor); TfLiteTensor* bw_cell_state = GetVariableInput(context, node, kBwInputCellStateTensor); // Resize the output tensors. - TfLiteIntArray* bw_output_size = TfLiteIntArrayCreate(3); - bw_output_size->data[0] = max_time; - bw_output_size->data[1] = n_batch; - bw_output_size->data[2] = n_bw_output; - TF_LITE_ENSURE_OK(context, - context->ResizeTensor(context, bw_output, bw_output_size)); + if (!params->merge_outputs) { + TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor); + TfLiteIntArray* bw_output_size = TfLiteIntArrayCreate(3); + bw_output_size->data[0] = max_time; + bw_output_size->data[1] = n_batch; + bw_output_size->data[2] = n_bw_output; + TF_LITE_ENSURE_OK( + context, context->ResizeTensor(context, bw_output, bw_output_size)); + } // Check the shape of input state tensors. // These tensor may be 1D or 2D. It's fine as long as the total size is @@ -705,7 +713,7 @@ TfLiteStatus EvalFloat( const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias, const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias, const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias, - const TfLiteLSTMParams* params, bool forward_sequence, + const TfLiteLSTMParams* params, bool forward_sequence, int output_offset, TfLiteTensor* scratch_buffer, TfLiteTensor* activation_state, TfLiteTensor* cell_state, TfLiteTensor* output) { const int max_time = input->dims->data[0]; @@ -771,12 +779,13 @@ TfLiteStatus EvalFloat( // Loop through the sequence. const int input_step = n_batch * n_input; - const int output_step = n_batch * n_output; + const int output_step = n_batch * output->dims->data[2]; for (int t = 0; t < max_time; t++) { // If this is the forward_sequence, step forward, otherwise step backwards. const int t_rel = forward_sequence ? t : max_time - t - 1; const float* input_ptr = input->data.f + t_rel * input_step; - float* output_ptr_time = output->data.f + t_rel * output_step; + float* output_ptr_time = + output->data.f + t_rel * output_step + output_offset; kernel_utils::LstmStepWithAuxInput( input_ptr, input_to_input_weights_ptr, input_to_forget_weights->data.f, @@ -816,7 +825,7 @@ TfLiteStatus EvalHybrid( const TfLiteTensor* input_gate_bias, const TfLiteTensor* forget_gate_bias, const TfLiteTensor* cell_bias, const TfLiteTensor* output_gate_bias, const TfLiteTensor* projection_weights, const TfLiteTensor* projection_bias, - const TfLiteLSTMParams* params, bool forward_sequence, + const TfLiteLSTMParams* params, bool forward_sequence, int output_offset, TfLiteTensor* scratch_buffer, TfLiteTensor* scaling_factors, TfLiteTensor* prod_scaling_factors, TfLiteTensor* recovered_cell_weights, TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized, @@ -972,12 +981,12 @@ TfLiteStatus EvalHybrid( // Feed the sequence into the LSTM step-by-step. const int input_step = n_batch * n_input; - const int output_step = n_batch * n_output; + const int output_step = n_batch * output->dims->data[2]; for (int t = 0; t < max_time; t++) { // If this is the forward_sequence, step forward, otherwise step backwards. const int t_rel = forward_sequence ? t : max_time - t - 1; const float* input_ptr = input->data.f + t_rel * input_step; - float* output_ptr = output->data.f + t_rel * output_step; + float* output_ptr = output->data.f + t_rel * output_step + output_offset; kernel_utils::LstmStepWithAuxInput( input_ptr, input_to_input_weights_ptr, input_to_input_weights_scale, @@ -1011,7 +1020,8 @@ TfLiteStatus EvalHybrid( // The LSTM Op engine. TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - const auto* params = reinterpret_cast<TfLiteLSTMParams*>(node->builtin_data); + const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceLSTMParams*>( + node->builtin_data); // Input tensor. const TfLiteTensor* input = GetInput(context, node, kInputTensor); @@ -1107,7 +1117,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { GetVariableInput(context, node, kBwInputActivationStateTensor); TfLiteTensor* bw_cell_state = GetVariableInput(context, node, kBwInputCellStateTensor); - TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor); + TfLiteTensor* bw_output = params->merge_outputs + ? nullptr + : GetOutput(context, node, kBwOutputTensor); // Temporary tensors. TfLiteTensor* fw_scratch_buffer = @@ -1135,6 +1147,14 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { const TfLiteTensor* bw_aux_input_to_output_weights = GetOptionalInputTensor(context, node, kBwAuxInputToOutputWeightsTensor); + // Populate a TfLiteLSTMParams struct for the evaluation functions. + TfLiteLSTMParams lstm_params = {params->activation, params->cell_clip, + params->proj_clip, kTfLiteLSTMFullKernel}; + + const int bw_output_offset = + params->merge_outputs ? fw_recurrent_to_output_weights->dims->data[1] : 0; + const auto actual_bw_output = params->merge_outputs ? fw_output : bw_output; + switch (fw_input_to_output_weights->type) { case kTfLiteFloat32: { TfLiteStatus fw_pass_status = EvalFloat( @@ -1147,9 +1167,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { fw_aux_input_to_forget_weights, fw_aux_input_to_cell_weights, fw_aux_input_to_output_weights, fw_input_gate_bias, fw_forget_gate_bias, fw_cell_bias, fw_output_gate_bias, - fw_projection_weights, fw_projection_bias, params, - /*forward_sequence=*/true, fw_scratch_buffer, fw_activation_state, - fw_cell_state, fw_output); + fw_projection_weights, fw_projection_bias, &lstm_params, + /*forward_sequence=*/true, /*output_offset=*/0, fw_scratch_buffer, + fw_activation_state, fw_cell_state, fw_output); TF_LITE_ENSURE_OK(context, fw_pass_status); TfLiteStatus bw_pass_status = EvalFloat( @@ -1162,9 +1182,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { bw_aux_input_to_forget_weights, bw_aux_input_to_cell_weights, bw_aux_input_to_output_weights, bw_input_gate_bias, bw_forget_gate_bias, bw_cell_bias, bw_output_gate_bias, - bw_projection_weights, bw_projection_bias, params, - /*forward_sequence=*/false, bw_scratch_buffer, bw_activation_state, - bw_cell_state, bw_output); + bw_projection_weights, bw_projection_bias, &lstm_params, + /*forward_sequence=*/false, bw_output_offset, bw_scratch_buffer, + bw_activation_state, bw_cell_state, actual_bw_output); TF_LITE_ENSURE_OK(context, bw_pass_status); return kTfLiteOk; } @@ -1198,10 +1218,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { fw_aux_input_to_forget_weights, fw_aux_input_to_cell_weights, fw_aux_input_to_output_weights, fw_input_gate_bias, fw_forget_gate_bias, fw_cell_bias, fw_output_gate_bias, - fw_projection_weights, fw_projection_bias, params, - /*forward_sequence=*/true, fw_scratch_buffer, scaling_factors, - prod_scaling_factors, recovered_cell_weights, input_quantized, - aux_input_quantized, fw_activation_state_quantized, + fw_projection_weights, fw_projection_bias, &lstm_params, + /*forward_sequence=*/true, /*output_offset=*/0, fw_scratch_buffer, + scaling_factors, prod_scaling_factors, recovered_cell_weights, + input_quantized, aux_input_quantized, fw_activation_state_quantized, fw_cell_state_quantized, fw_activation_state, fw_cell_state, fw_output); TF_LITE_ENSURE_OK(context, fw_pass_status); @@ -1216,12 +1236,12 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { fw_aux_input_to_forget_weights, fw_aux_input_to_cell_weights, fw_aux_input_to_output_weights, bw_input_gate_bias, bw_forget_gate_bias, bw_cell_bias, bw_output_gate_bias, - bw_projection_weights, bw_projection_bias, params, - /*forward_sequence=*/false, bw_scratch_buffer, scaling_factors, - prod_scaling_factors, recovered_cell_weights, input_quantized, - aux_input_quantized, bw_activation_state_quantized, + bw_projection_weights, bw_projection_bias, &lstm_params, + /*forward_sequence=*/false, bw_output_offset, bw_scratch_buffer, + scaling_factors, prod_scaling_factors, recovered_cell_weights, + input_quantized, aux_input_quantized, bw_activation_state_quantized, bw_cell_state_quantized, bw_activation_state, bw_cell_state, - bw_output); + actual_bw_output); TF_LITE_ENSURE_OK(context, bw_pass_status); return kTfLiteOk; } diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc index 74ba8021c2..9cc04907e1 100644 --- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc +++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_lstm_test.cc @@ -35,8 +35,8 @@ class BidirectionalLSTMOpModel : public SingleOpModel { BidirectionalLSTMOpModel(int n_batch, int n_input, int n_cell, int n_output, int sequence_length, bool use_cifg, bool use_peephole, bool use_projection_weights, - bool use_projection_bias, float cell_clip, - float proj_clip, + bool use_projection_bias, bool merge_outputs, + float cell_clip, float proj_clip, const std::vector<std::vector<int>>& input_shapes) : n_batch_(n_batch), n_input_(n_input), @@ -175,7 +175,9 @@ class BidirectionalLSTMOpModel : public SingleOpModel { fw_output_ = AddOutput(TensorType_FLOAT32); - bw_output_ = AddOutput(TensorType_FLOAT32); + if (!merge_outputs) { + bw_output_ = AddOutput(TensorType_FLOAT32); + } aux_input_ = AddNullInput(); fw_aux_input_to_input_weights_ = AddNullInput(); @@ -188,9 +190,10 @@ class BidirectionalLSTMOpModel : public SingleOpModel { bw_aux_input_to_output_weights_ = AddNullInput(); SetBuiltinOp(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM, - BuiltinOptions_LSTMOptions, - CreateLSTMOptions(builder_, ActivationFunctionType_TANH, - cell_clip, proj_clip) + BuiltinOptions_BidirectionalSequenceLSTMOptions, + CreateBidirectionalSequenceLSTMOptions( + builder_, ActivationFunctionType_TANH, cell_clip, + proj_clip, merge_outputs) .Union()); BuildInterpreter(input_shapes); } @@ -380,7 +383,8 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) { BidirectionalLSTMOpModel lstm( n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false, /*use_peephole=*/false, /*use_projection_weights=*/false, - /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0, + /*use_projection_bias=*/false, /*merge_outputs=*/false, /*cell_clip=*/0.0, + /*proj_clip=*/0.0, { {sequence_length, n_batch, n_input}, // input tensor @@ -526,6 +530,162 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClipping) { ElementsAreArray(ArrayFloatNear(bw_expected))); } +// Same as the previous test, yet with a single merged output tensor. +TEST(LSTMOpTest, BlackBoxTestMergedOutput) { + const int n_batch = 1; + const int n_input = 2; + // n_cell and n_output have the same size when there is no projection. + const int n_cell = 4; + const int n_output = 4; + const int sequence_length = 3; + + BidirectionalLSTMOpModel lstm( + n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false, + /*use_peephole=*/false, /*use_projection_weights=*/false, + /*use_projection_bias=*/false, /*merge_outputs=*/true, /*cell_clip=*/0.0, + /*proj_clip=*/0.0, + { + {sequence_length, n_batch, n_input}, // input tensor + + // Forward cell + {n_cell, n_input}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {n_cell, n_output}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {0}, // cell_to_input_weight tensor + {0}, // cell_to_forget_weight tensor + {0}, // cell_to_output_weight tensor + + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {0, 0}, // projection_weight tensor + {0}, // projection_bias tensor + + // Backward cell + {n_cell, n_input}, // input_to_input_weight tensor + {n_cell, n_input}, // input_to_forget_weight tensor + {n_cell, n_input}, // input_to_cell_weight tensor + {n_cell, n_input}, // input_to_output_weight tensor + + {n_cell, n_output}, // recurrent_to_input_weight tensor + {n_cell, n_output}, // recurrent_to_forget_weight tensor + {n_cell, n_output}, // recurrent_to_cell_weight tensor + {n_cell, n_output}, // recurrent_to_output_weight tensor + + {0}, // cell_to_input_weight tensor + {0}, // cell_to_forget_weight tensor + {0}, // cell_to_output_weight tensor + + {n_cell}, // input_gate_bias tensor + {n_cell}, // forget_gate_bias tensor + {n_cell}, // cell_bias tensor + {n_cell}, // output_gate_bias tensor + + {0, 0}, // projection_weight tensor + {0}, // projection_bias tensor + + {n_batch, n_output}, // activation_state tensor + {n_batch, n_cell}, // cell_state tensor + + {n_batch, n_output}, // activation_state tensor + {n_batch, n_cell}, // cell_state tensor + + {n_batch, sequence_length, 0}, // aux_input tensor + {n_cell, 0}, // aux_fw_input_to_input tensor + {n_cell, 0}, // aux_fw_input_to_forget tensor + {n_cell, 0}, // aux_fw_input_to_cell tensor + {n_cell, 0}, // aux_fw_input_to_output tensor + {n_cell, 0}, // aux_bw_input_to_input tensor + {n_cell, 0}, // aux_bw_input_to_forget tensor + {n_cell, 0}, // aux_bw_input_to_cell tensor + {n_cell, 0}, // aux_bw_input_to_output tensor + }); + + lstm.SetInputToInputWeights({-0.45018822, -0.02338299, -0.0870589, + -0.34550029, 0.04266912, -0.15680569, + -0.34856534, 0.43890524}); + + lstm.SetInputToCellWeights({-0.50013041, 0.1370284, 0.11810488, 0.2013163, + -0.20583314, 0.44344562, 0.22077113, + -0.29909778}); + + lstm.SetInputToForgetWeights({0.09701663, 0.20334584, -0.50592935, + -0.31343272, -0.40032279, 0.44781327, + 0.01387155, -0.35593212}); + + lstm.SetInputToOutputWeights({-0.25065863, -0.28290087, 0.04613829, + 0.40525138, 0.44272184, 0.03897077, -0.1556896, + 0.19487578}); + + lstm.SetInputGateBias({0., 0., 0., 0.}); + + lstm.SetCellBias({0., 0., 0., 0.}); + + lstm.SetForgetGateBias({1., 1., 1., 1.}); + + lstm.SetOutputGateBias({0., 0., 0., 0.}); + + lstm.SetRecurrentToInputWeights( + {-0.0063535, -0.2042388, 0.31454784, -0.35746509, 0.28902304, 0.08183324, + -0.16555229, 0.02286911, -0.13566875, 0.03034258, 0.48091322, + -0.12528998, 0.24077177, -0.51332325, -0.33502164, 0.10629296}); + + lstm.SetRecurrentToCellWeights( + {-0.3407414, 0.24443203, -0.2078532, 0.26320225, 0.05695659, -0.00123841, + -0.4744786, -0.35869038, -0.06418842, -0.13502428, -0.501764, 0.22830659, + -0.46367589, 0.26016325, -0.03894562, -0.16368064}); + + lstm.SetRecurrentToForgetWeights( + {-0.48684245, -0.06655136, 0.42224967, 0.2112639, 0.27654213, 0.20864892, + -0.07646349, 0.45877004, 0.00141793, -0.14609534, 0.36447752, 0.09196436, + 0.28053468, 0.01560611, -0.20127171, -0.01140004}); + + lstm.SetRecurrentToOutputWeights( + {0.43385774, -0.17194885, 0.2718237, 0.09215671, 0.24107647, -0.39835793, + 0.18212086, 0.01301402, 0.48572797, -0.50656658, 0.20047462, -0.20607421, + -0.51818722, -0.15390486, 0.0468148, 0.39922136}); + + // Input should have n_input * sequence_length many values. + static float lstm_input[] = {2., 3., 3., 4., 1., 1.}; + static float lstm_fw_golden_output[] = { + -0.02973187, 0.1229473, 0.20885126, -0.15358765, + -0.03716109, 0.12507336, 0.41193449, -0.20860538, + -0.15053082, 0.09120187, 0.24278517, -0.12222792}; + static float lstm_bw_golden_output[] = { + -0.0806187, 0.139077, 0.400476, -0.197842, -0.0332076, 0.123838, + 0.309777, -0.17621, -0.0490733, 0.0739237, 0.067706, -0.0208124}; + + float* batch0_start = lstm_input; + float* batch0_end = batch0_start + lstm.num_inputs() * lstm.sequence_length(); + + lstm.SetInput(0, batch0_start, batch0_end); + + lstm.Invoke(); + + std::vector<float> merged_expected; + for (int k = 0; k < lstm.sequence_length(); k++) { + merged_expected.insert( + merged_expected.end(), + lstm_fw_golden_output + k * lstm.num_fw_outputs(), + lstm_fw_golden_output + (k + 1) * lstm.num_fw_outputs()); + merged_expected.insert( + merged_expected.end(), + lstm_bw_golden_output + k * lstm.num_bw_outputs(), + lstm_bw_golden_output + (k + 1) * lstm.num_bw_outputs()); + } + EXPECT_THAT(lstm.GetFwOutput(), + ElementsAreArray(ArrayFloatNear(merged_expected))); +} + TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClippingReverse) { const int n_batch = 1; const int n_input = 2; @@ -537,7 +697,8 @@ TEST(LSTMOpTest, BlackBoxTestNoCifgNoPeepholeNoProjectionNoClippingReverse) { BidirectionalLSTMOpModel lstm( n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false, /*use_peephole=*/false, /*use_projection_weights=*/false, - /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0, + /*use_projection_bias=*/false, /*merge_outputs=*/false, /*cell_clip=*/0.0, + /*proj_clip=*/0.0, { {sequence_length, n_batch, n_input}, // input tensor @@ -696,7 +857,8 @@ TEST(LSTMOpTest, BlackBoxTestWithCifgWithPeepholeNoProjectionNoClipping) { BidirectionalLSTMOpModel lstm( n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/true, /*use_peephole=*/true, /*use_projection_weights=*/false, - /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0, + /*use_projection_bias=*/false, /*merge_outputs=*/false, /*cell_clip=*/0.0, + /*proj_clip=*/0.0, { {sequence_length, n_batch, n_input}, // input tensor @@ -845,7 +1007,8 @@ TEST(LSTMOpTest, BidirectionalLSTMOpModel lstm( n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/true, /*use_peephole=*/true, /*use_projection_weights=*/false, - /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0, + /*use_projection_bias=*/false, /*merge_outputs=*/false, /*cell_clip=*/0.0, + /*proj_clip=*/0.0, { {sequence_length, n_batch, n_input}, // input tensor @@ -994,7 +1157,8 @@ TEST(LSTMOpTest, BlackBoxTestWithPeepholeWithProjectionNoClipping) { BidirectionalLSTMOpModel lstm( n_batch, n_input, n_cell, n_output, sequence_length, /*use_cifg=*/false, /*use_peephole=*/true, /*use_projection_weights=*/true, - /*use_projection_bias=*/false, /*cell_clip=*/0.0, /*proj_clip=*/0.0, + /*use_projection_bias=*/false, /*merge_outputs=*/false, /*cell_clip=*/0.0, + /*proj_clip=*/0.0, { {sequence_length, n_batch, n_input}, // input tensor diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc index 2f896c5289..9f62ac3f2c 100644 --- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc +++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc @@ -47,7 +47,7 @@ constexpr int kFwAuxWeightsTensor = 10; // Optional. constexpr int kBwAuxWeightsTensor = 11; // Optional. // Output tensors. constexpr int kFwOutputTensor = 0; -constexpr int kBwOutputTensor = 1; +constexpr int kBwOutputTensor = 1; // Only if merge_outputs is false. // Temporary tensors. enum TemporaryTensor { @@ -70,9 +70,13 @@ void Free(TfLiteContext* context, void* buffer) { } TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { + const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceRNNParams*>( + node->builtin_data); + // Check we have all the inputs and outputs we need. TF_LITE_ENSURE_EQ(context, node->inputs->size, 12); - TF_LITE_ENSURE_EQ(context, node->outputs->size, 2); + TF_LITE_ENSURE_EQ(context, node->outputs->size, + params->merge_outputs ? 1 : 2); const TfLiteTensor* input = GetInput(context, node, kInputTensor); const TfLiteTensor* fw_input_weights = @@ -142,9 +146,6 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { bw_aux_input_weights->dims->data[1]); } - TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor); - TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor); - const bool is_hybrid_op = (fw_input_weights->type == kTfLiteUInt8 && input->type == kTfLiteFloat32); @@ -233,18 +234,23 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { } // Resize outputs. + TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor); TfLiteIntArray* fw_output_size_array = TfLiteIntArrayCreate(3); fw_output_size_array->data[0] = batch_size; fw_output_size_array->data[1] = max_time; - fw_output_size_array->data[2] = fw_num_units; + fw_output_size_array->data[2] = + params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units; TF_LITE_ENSURE_OK( context, context->ResizeTensor(context, fw_output, fw_output_size_array)); - TfLiteIntArray* bw_output_size_array = TfLiteIntArrayCreate(3); - bw_output_size_array->data[0] = batch_size; - bw_output_size_array->data[1] = max_time; - bw_output_size_array->data[2] = bw_num_units; - TF_LITE_ENSURE_OK( - context, context->ResizeTensor(context, bw_output, bw_output_size_array)); + if (!params->merge_outputs) { + TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor); + TfLiteIntArray* bw_output_size_array = TfLiteIntArrayCreate(3); + bw_output_size_array->data[0] = batch_size; + bw_output_size_array->data[1] = max_time; + bw_output_size_array->data[2] = bw_num_units; + TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, bw_output, + bw_output_size_array)); + } return kTfLiteOk; } @@ -256,9 +262,9 @@ TfLiteStatus EvalFloat( const TfLiteTensor* bw_recurrent_weights, const TfLiteTensor* bw_bias, const TfLiteTensor* aux_input, const TfLiteTensor* fw_aux_input_weights, const TfLiteTensor* bw_aux_input_weights, - const TfLiteSequenceRNNParams* params, TfLiteTensor* fw_hidden_state, - TfLiteTensor* fw_output, TfLiteTensor* bw_hidden_state, - TfLiteTensor* bw_output) { + const TfLiteBidirectionalSequenceRNNParams* params, + TfLiteTensor* fw_hidden_state, TfLiteTensor* fw_output, + TfLiteTensor* bw_hidden_state, TfLiteTensor* bw_output) { const int batch_size = input->dims->data[0]; const int max_time = input->dims->data[1]; const int input_size = input->dims->data[2]; @@ -281,10 +287,15 @@ TfLiteStatus EvalFloat( ? bw_aux_input_weights->data.f : nullptr; + const int fw_output_step = + params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units; + const int bw_output_step = + params->merge_outputs ? fw_num_units + bw_num_units : bw_num_units; for (int b = 0; b < batch_size; b++) { // Forward cell. float* fw_hidden_state_ptr_batch = fw_hidden_state->data.f + b * fw_num_units; + float* fw_output_offset = fw_output->data.f + b * fw_output_step * max_time; for (int s = 0; s < max_time; s++) { const float* input_ptr_batch = input->data.f + b * input_size * max_time + s * input_size; @@ -292,8 +303,7 @@ TfLiteStatus EvalFloat( (aux_input != nullptr) ? aux_input->data.f + b * input_size * max_time + s * input_size : nullptr; - float* output_ptr_batch = - fw_output->data.f + b * fw_num_units * max_time + s * fw_num_units; + float* output_ptr_batch = fw_output_offset + s * fw_output_step; kernel_utils::RnnBatchStep( input_ptr_batch, fw_input_weights_ptr, aux_input_ptr_batch, @@ -304,6 +314,10 @@ TfLiteStatus EvalFloat( // Backward cell. float* bw_hidden_state_ptr_batch = bw_hidden_state->data.f + b * bw_num_units; + float* bw_output_offset = + params->merge_outputs + ? fw_output->data.f + b * bw_output_step * max_time + fw_num_units + : bw_output->data.f + b * bw_output_step * max_time; for (int s = max_time - 1; s >= 0; s--) { const float* input_ptr_batch = input->data.f + b * input_size * max_time + s * input_size; @@ -311,8 +325,7 @@ TfLiteStatus EvalFloat( (aux_input != nullptr) ? aux_input->data.f + b * input_size * max_time + s * input_size : nullptr; - float* output_ptr_batch = - bw_output->data.f + b * bw_num_units * max_time + s * bw_num_units; + float* output_ptr_batch = bw_output_offset + s * bw_output_step; kernel_utils::RnnBatchStep( input_ptr_batch, bw_input_weights_ptr, aux_input_ptr_batch, @@ -331,11 +344,12 @@ TfLiteStatus EvalHybrid( const TfLiteTensor* bw_recurrent_weights, const TfLiteTensor* bw_bias, const TfLiteTensor* aux_input, const TfLiteTensor* aux_fw_input_weights, const TfLiteTensor* aux_bw_input_weights, - const TfLiteSequenceRNNParams* params, TfLiteTensor* scaling_factors, - TfLiteTensor* input_quantized, TfLiteTensor* aux_input_quantized, - TfLiteTensor* fw_hidden_state_quantized, TfLiteTensor* fw_hidden_state, - TfLiteTensor* fw_output, TfLiteTensor* bw_hidden_state_quantized, - TfLiteTensor* bw_hidden_state, TfLiteTensor* bw_output) { + const TfLiteBidirectionalSequenceRNNParams* params, + TfLiteTensor* scaling_factors, TfLiteTensor* input_quantized, + TfLiteTensor* aux_input_quantized, TfLiteTensor* fw_hidden_state_quantized, + TfLiteTensor* fw_hidden_state, TfLiteTensor* fw_output, + TfLiteTensor* bw_hidden_state_quantized, TfLiteTensor* bw_hidden_state, + TfLiteTensor* bw_output) { const int batch_size = input->dims->data[0]; const int max_time = input->dims->data[1]; const int input_size = input->dims->data[2]; @@ -384,10 +398,15 @@ TfLiteStatus EvalHybrid( reinterpret_cast<int8_t*>(bw_hidden_state_quantized->data.uint8); float* scaling_factors_ptr = scaling_factors->data.f; + const int fw_output_step = + params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units; + const int bw_output_step = + params->merge_outputs ? fw_num_units + bw_num_units : bw_num_units; for (int b = 0; b < batch_size; b++) { // Forward cell. float* fw_hidden_state_ptr_batch = fw_hidden_state->data.f + b * fw_num_units; + float* fw_output_offset = fw_output->data.f + b * fw_output_step * max_time; for (int s = 0; s < max_time; s++) { const float* input_ptr_batch = input->data.f + b * input_size * max_time + s * input_size; @@ -395,8 +414,7 @@ TfLiteStatus EvalHybrid( (aux_input != nullptr) ? aux_input->data.f + b * input_size * max_time + s * input_size : nullptr; - float* output_ptr_batch = - fw_output->data.f + b * fw_num_units * max_time + s * fw_num_units; + float* output_ptr_batch = fw_output_offset + s * fw_output_step; kernel_utils::RnnBatchStep( input_ptr_batch, fw_input_weights_ptr, fw_input_weights_scale, @@ -411,6 +429,10 @@ TfLiteStatus EvalHybrid( // Backward cell. float* bw_hidden_state_ptr_batch = bw_hidden_state->data.f + b * bw_num_units; + float* bw_output_offset = + params->merge_outputs + ? fw_output->data.f + b * bw_output_step * max_time + : bw_output->data.f + b * bw_output_step * max_time; for (int s = max_time - 1; s >= 0; s--) { const float* input_ptr_batch = input->data.f + b * input_size * max_time + s * input_size; @@ -418,8 +440,7 @@ TfLiteStatus EvalHybrid( (aux_input != nullptr) ? aux_input->data.f + b * input_size * max_time + s * input_size : nullptr; - float* output_ptr_batch = - bw_output->data.f + b * bw_num_units * max_time + s * bw_num_units; + float* output_ptr_batch = bw_output_offset + s * bw_output_step; kernel_utils::RnnBatchStep( input_ptr_batch, bw_input_weights_ptr, bw_input_weights_scale, @@ -436,8 +457,8 @@ TfLiteStatus EvalHybrid( } TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { - const auto* params = - reinterpret_cast<TfLiteSequenceRNNParams*>(node->builtin_data); + const auto* params = reinterpret_cast<TfLiteBidirectionalSequenceRNNParams*>( + node->builtin_data); const TfLiteTensor* input = GetInput(context, node, kInputTensor); const TfLiteTensor* fw_input_weights = @@ -465,7 +486,9 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { GetVariableInput(context, node, kBwHiddenStateTensor); TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor); - TfLiteTensor* bw_output = GetOutput(context, node, kBwOutputTensor); + TfLiteTensor* bw_output = params->merge_outputs + ? nullptr + : GetOutput(context, node, kBwOutputTensor); switch (fw_input_weights->type) { case kTfLiteFloat32: diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc index 3e34ba6196..f555c472f5 100644 --- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc +++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc @@ -654,7 +654,7 @@ const std::initializer_list<float> recurrent_weights = { class BidirectionalRNNOpModel : public SingleOpModel { public: BidirectionalRNNOpModel(int batches, int sequence_len, int fw_units, - int bw_units, int input_size) + int bw_units, int input_size, bool merge_outputs) : batches_(batches), sequence_len_(sequence_len), fw_units_(fw_units), @@ -675,12 +675,15 @@ class BidirectionalRNNOpModel : public SingleOpModel { aux_bw_weights_ = AddNullInput(); fw_output_ = AddOutput(TensorType_FLOAT32); - bw_output_ = AddOutput(TensorType_FLOAT32); + if (!merge_outputs) { + bw_output_ = AddOutput(TensorType_FLOAT32); + } SetBuiltinOp(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN, - BuiltinOptions_SequenceRNNOptions, - CreateSequenceRNNOptions(builder_, /*time_major=*/false, - ActivationFunctionType_RELU) + BuiltinOptions_BidirectionalSequenceRNNOptions, + CreateBidirectionalSequenceRNNOptions( + builder_, /*time_major=*/false, + ActivationFunctionType_RELU, merge_outputs) .Union()); BuildInterpreter({ {batches_, sequence_len_, input_size_}, // input @@ -767,7 +770,7 @@ class BidirectionalRNNOpModel : public SingleOpModel { TEST(BidirectionalRNNOpTest, BlackBoxTest) { BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, /*fw_units=*/16, /*bw_units=*/16, - /*input_size=*/8); + /*input_size=*/8, /*merge_outputs=*/false); rnn.SetFwWeights(weights); rnn.SetBwWeights(weights); rnn.SetFwBias(biases); @@ -800,12 +803,49 @@ TEST(BidirectionalRNNOpTest, BlackBoxTest) { EXPECT_THAT(rnn.GetBwOutput(), ElementsAreArray(ArrayFloatNear(bw_expected))); } +// Same as the previous test, yet with merged outputs. +TEST(BidirectionalRNNOpTest, BlackBoxTestMergeOutputs) { + BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, + /*fw_units=*/16, /*bw_units=*/16, + /*input_size=*/8, /*merge_outputs=*/true); + rnn.SetFwWeights(weights); + rnn.SetBwWeights(weights); + rnn.SetFwBias(biases); + rnn.SetBwBias(biases); + rnn.SetFwRecurrentWeights(recurrent_weights); + rnn.SetBwRecurrentWeights(recurrent_weights); + + const int input_sequence_size = rnn.input_size() * rnn.sequence_len(); + float* batch_start = rnn_input; + float* batch_end = batch_start + input_sequence_size; + rnn.SetInput(0, batch_start, batch_end); + rnn.SetInput(input_sequence_size, batch_start, batch_end); + + rnn.Invoke(); + + std::vector<float> merged_expected; + for (int bid = 0; bid < rnn.num_batches(); bid++) { + for (int step = 0; step < rnn.sequence_len(); step++) { + merged_expected.insert( + merged_expected.end(), + rnn_golden_fw_output + rnn.num_fw_units() * step, + rnn_golden_fw_output + rnn.num_fw_units() * (step + 1)); + merged_expected.insert( + merged_expected.end(), + rnn_golden_bw_output + rnn.num_bw_units() * step, + rnn_golden_bw_output + rnn.num_bw_units() * (step + 1)); + } + } + EXPECT_THAT(rnn.GetFwOutput(), + ElementsAreArray(ArrayFloatNear(merged_expected))); +} + // Check that if the input sequence is reversed the outputs are the same just // forward and backward are swapped (and reversed). TEST(BidirectionalRNNOpTest, BlackBoxTestReverseInputs) { BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, /*fw_units=*/16, /*bw_units=*/16, - /*input_size=*/8); + /*input_size=*/8, /*merge_outputs=*/false); rnn.SetFwWeights(weights); rnn.SetBwWeights(weights); rnn.SetFwBias(biases); @@ -851,7 +891,7 @@ TEST(BidirectionalRNNOpTest, BlackBoxTestReverseInputs) { TEST(BidirectionalRNNOpTest, EndToEndTest) { BidirectionalRNNOpModel rnn(/*batches=*/1, /*sequence_len=*/4, /*fw_units=*/16, /*bw_units=*/16, - /*input_size=*/8); + /*input_size=*/8, /*merge_outputs=*/false); const int output_size = 4; float dnn_weights[] = { -0.5782342, -0.052212059, 0.73036242, -0.81216097, -0.80088139, diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs index 3da3188c3a..ff8430827c 100644 --- a/tensorflow/contrib/lite/schema/schema.fbs +++ b/tensorflow/contrib/lite/schema/schema.fbs @@ -248,6 +248,8 @@ union BuiltinOptions { SquareOptions, ZerosLikeOptions, FillOptions, + BidirectionalSequenceLSTMOptions, + BidirectionalSequenceRNNOptions, } enum Padding : byte { SAME, VALID } @@ -327,6 +329,7 @@ table SequenceRNNOptions { table BidirectionalSequenceRNNOptions { time_major:bool; fused_activation_function:ActivationFunctionType; + merge_outputs: bool; } enum FullyConnectedOptionsWeightsFormat: byte { @@ -391,6 +394,15 @@ table LSTMOptions { kernel_type: LSTMKernelType = FULL; } +table BidirectionalSequenceLSTMOptions { + fused_activation_function:ActivationFunctionType; + cell_clip: float; // Optional, 0.0 means no clipping + proj_clip: float; // Optional, 0.0 means no clipping + + // If true, store the outputs of both directions into the first output. + merge_outputs: bool; +} + table ResizeBilinearOptions { new_height: int (deprecated); new_width: int (deprecated); diff --git a/tensorflow/contrib/lite/schema/schema_generated.h b/tensorflow/contrib/lite/schema/schema_generated.h index 23ac8484de..f3cb113c9c 100755 --- a/tensorflow/contrib/lite/schema/schema_generated.h +++ b/tensorflow/contrib/lite/schema/schema_generated.h @@ -79,6 +79,9 @@ struct LocalResponseNormalizationOptionsT; struct LSTMOptions; struct LSTMOptionsT; +struct BidirectionalSequenceLSTMOptions; +struct BidirectionalSequenceLSTMOptionsT; + struct ResizeBilinearOptions; struct ResizeBilinearOptionsT; @@ -676,11 +679,13 @@ enum BuiltinOptions { BuiltinOptions_SquareOptions = 66, BuiltinOptions_ZerosLikeOptions = 67, BuiltinOptions_FillOptions = 68, + BuiltinOptions_BidirectionalSequenceLSTMOptions = 69, + BuiltinOptions_BidirectionalSequenceRNNOptions = 70, BuiltinOptions_MIN = BuiltinOptions_NONE, - BuiltinOptions_MAX = BuiltinOptions_FillOptions + BuiltinOptions_MAX = BuiltinOptions_BidirectionalSequenceRNNOptions }; -inline const BuiltinOptions (&EnumValuesBuiltinOptions())[69] { +inline const BuiltinOptions (&EnumValuesBuiltinOptions())[71] { static const BuiltinOptions values[] = { BuiltinOptions_NONE, BuiltinOptions_Conv2DOptions, @@ -750,7 +755,9 @@ inline const BuiltinOptions (&EnumValuesBuiltinOptions())[69] { BuiltinOptions_FloorDivOptions, BuiltinOptions_SquareOptions, BuiltinOptions_ZerosLikeOptions, - BuiltinOptions_FillOptions + BuiltinOptions_FillOptions, + BuiltinOptions_BidirectionalSequenceLSTMOptions, + BuiltinOptions_BidirectionalSequenceRNNOptions }; return values; } @@ -826,6 +833,8 @@ inline const char * const *EnumNamesBuiltinOptions() { "SquareOptions", "ZerosLikeOptions", "FillOptions", + "BidirectionalSequenceLSTMOptions", + "BidirectionalSequenceRNNOptions", nullptr }; return names; @@ -1112,6 +1121,14 @@ template<> struct BuiltinOptionsTraits<FillOptions> { static const BuiltinOptions enum_value = BuiltinOptions_FillOptions; }; +template<> struct BuiltinOptionsTraits<BidirectionalSequenceLSTMOptions> { + static const BuiltinOptions enum_value = BuiltinOptions_BidirectionalSequenceLSTMOptions; +}; + +template<> struct BuiltinOptionsTraits<BidirectionalSequenceRNNOptions> { + static const BuiltinOptions enum_value = BuiltinOptions_BidirectionalSequenceRNNOptions; +}; + struct BuiltinOptionsUnion { BuiltinOptions type; void *value; @@ -1687,6 +1704,22 @@ struct BuiltinOptionsUnion { return type == BuiltinOptions_FillOptions ? reinterpret_cast<const FillOptionsT *>(value) : nullptr; } + BidirectionalSequenceLSTMOptionsT *AsBidirectionalSequenceLSTMOptions() { + return type == BuiltinOptions_BidirectionalSequenceLSTMOptions ? + reinterpret_cast<BidirectionalSequenceLSTMOptionsT *>(value) : nullptr; + } + const BidirectionalSequenceLSTMOptionsT *AsBidirectionalSequenceLSTMOptions() const { + return type == BuiltinOptions_BidirectionalSequenceLSTMOptions ? + reinterpret_cast<const BidirectionalSequenceLSTMOptionsT *>(value) : nullptr; + } + BidirectionalSequenceRNNOptionsT *AsBidirectionalSequenceRNNOptions() { + return type == BuiltinOptions_BidirectionalSequenceRNNOptions ? + reinterpret_cast<BidirectionalSequenceRNNOptionsT *>(value) : nullptr; + } + const BidirectionalSequenceRNNOptionsT *AsBidirectionalSequenceRNNOptions() const { + return type == BuiltinOptions_BidirectionalSequenceRNNOptions ? + reinterpret_cast<const BidirectionalSequenceRNNOptionsT *>(value) : nullptr; + } }; bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *obj, BuiltinOptions type); @@ -2834,9 +2867,11 @@ struct BidirectionalSequenceRNNOptionsT : public flatbuffers::NativeTable { typedef BidirectionalSequenceRNNOptions TableType; bool time_major; ActivationFunctionType fused_activation_function; + bool merge_outputs; BidirectionalSequenceRNNOptionsT() : time_major(false), - fused_activation_function(ActivationFunctionType_NONE) { + fused_activation_function(ActivationFunctionType_NONE), + merge_outputs(false) { } }; @@ -2844,7 +2879,8 @@ struct BidirectionalSequenceRNNOptions FLATBUFFERS_FINAL_CLASS : private flatbuf typedef BidirectionalSequenceRNNOptionsT NativeTableType; enum { VT_TIME_MAJOR = 4, - VT_FUSED_ACTIVATION_FUNCTION = 6 + VT_FUSED_ACTIVATION_FUNCTION = 6, + VT_MERGE_OUTPUTS = 8 }; bool time_major() const { return GetField<uint8_t>(VT_TIME_MAJOR, 0) != 0; @@ -2852,10 +2888,14 @@ struct BidirectionalSequenceRNNOptions FLATBUFFERS_FINAL_CLASS : private flatbuf ActivationFunctionType fused_activation_function() const { return static_cast<ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0)); } + bool merge_outputs() const { + return GetField<uint8_t>(VT_MERGE_OUTPUTS, 0) != 0; + } bool Verify(flatbuffers::Verifier &verifier) const { return VerifyTableStart(verifier) && VerifyField<uint8_t>(verifier, VT_TIME_MAJOR) && VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) && + VerifyField<uint8_t>(verifier, VT_MERGE_OUTPUTS) && verifier.EndTable(); } BidirectionalSequenceRNNOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; @@ -2872,6 +2912,9 @@ struct BidirectionalSequenceRNNOptionsBuilder { void add_fused_activation_function(ActivationFunctionType fused_activation_function) { fbb_.AddElement<int8_t>(BidirectionalSequenceRNNOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast<int8_t>(fused_activation_function), 0); } + void add_merge_outputs(bool merge_outputs) { + fbb_.AddElement<uint8_t>(BidirectionalSequenceRNNOptions::VT_MERGE_OUTPUTS, static_cast<uint8_t>(merge_outputs), 0); + } explicit BidirectionalSequenceRNNOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) : fbb_(_fbb) { start_ = fbb_.StartTable(); @@ -2887,8 +2930,10 @@ struct BidirectionalSequenceRNNOptionsBuilder { inline flatbuffers::Offset<BidirectionalSequenceRNNOptions> CreateBidirectionalSequenceRNNOptions( flatbuffers::FlatBufferBuilder &_fbb, bool time_major = false, - ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE) { + ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE, + bool merge_outputs = false) { BidirectionalSequenceRNNOptionsBuilder builder_(_fbb); + builder_.add_merge_outputs(merge_outputs); builder_.add_fused_activation_function(fused_activation_function); builder_.add_time_major(time_major); return builder_.Finish(); @@ -3424,6 +3469,96 @@ inline flatbuffers::Offset<LSTMOptions> CreateLSTMOptions( flatbuffers::Offset<LSTMOptions> CreateLSTMOptions(flatbuffers::FlatBufferBuilder &_fbb, const LSTMOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +struct BidirectionalSequenceLSTMOptionsT : public flatbuffers::NativeTable { + typedef BidirectionalSequenceLSTMOptions TableType; + ActivationFunctionType fused_activation_function; + float cell_clip; + float proj_clip; + bool merge_outputs; + BidirectionalSequenceLSTMOptionsT() + : fused_activation_function(ActivationFunctionType_NONE), + cell_clip(0.0f), + proj_clip(0.0f), + merge_outputs(false) { + } +}; + +struct BidirectionalSequenceLSTMOptions FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { + typedef BidirectionalSequenceLSTMOptionsT NativeTableType; + enum { + VT_FUSED_ACTIVATION_FUNCTION = 4, + VT_CELL_CLIP = 6, + VT_PROJ_CLIP = 8, + VT_MERGE_OUTPUTS = 10 + }; + ActivationFunctionType fused_activation_function() const { + return static_cast<ActivationFunctionType>(GetField<int8_t>(VT_FUSED_ACTIVATION_FUNCTION, 0)); + } + float cell_clip() const { + return GetField<float>(VT_CELL_CLIP, 0.0f); + } + float proj_clip() const { + return GetField<float>(VT_PROJ_CLIP, 0.0f); + } + bool merge_outputs() const { + return GetField<uint8_t>(VT_MERGE_OUTPUTS, 0) != 0; + } + bool Verify(flatbuffers::Verifier &verifier) const { + return VerifyTableStart(verifier) && + VerifyField<int8_t>(verifier, VT_FUSED_ACTIVATION_FUNCTION) && + VerifyField<float>(verifier, VT_CELL_CLIP) && + VerifyField<float>(verifier, VT_PROJ_CLIP) && + VerifyField<uint8_t>(verifier, VT_MERGE_OUTPUTS) && + verifier.EndTable(); + } + BidirectionalSequenceLSTMOptionsT *UnPack(const flatbuffers::resolver_function_t *_resolver = nullptr) const; + void UnPackTo(BidirectionalSequenceLSTMOptionsT *_o, const flatbuffers::resolver_function_t *_resolver = nullptr) const; + static flatbuffers::Offset<BidirectionalSequenceLSTMOptions> Pack(flatbuffers::FlatBufferBuilder &_fbb, const BidirectionalSequenceLSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); +}; + +struct BidirectionalSequenceLSTMOptionsBuilder { + flatbuffers::FlatBufferBuilder &fbb_; + flatbuffers::uoffset_t start_; + void add_fused_activation_function(ActivationFunctionType fused_activation_function) { + fbb_.AddElement<int8_t>(BidirectionalSequenceLSTMOptions::VT_FUSED_ACTIVATION_FUNCTION, static_cast<int8_t>(fused_activation_function), 0); + } + void add_cell_clip(float cell_clip) { + fbb_.AddElement<float>(BidirectionalSequenceLSTMOptions::VT_CELL_CLIP, cell_clip, 0.0f); + } + void add_proj_clip(float proj_clip) { + fbb_.AddElement<float>(BidirectionalSequenceLSTMOptions::VT_PROJ_CLIP, proj_clip, 0.0f); + } + void add_merge_outputs(bool merge_outputs) { + fbb_.AddElement<uint8_t>(BidirectionalSequenceLSTMOptions::VT_MERGE_OUTPUTS, static_cast<uint8_t>(merge_outputs), 0); + } + explicit BidirectionalSequenceLSTMOptionsBuilder(flatbuffers::FlatBufferBuilder &_fbb) + : fbb_(_fbb) { + start_ = fbb_.StartTable(); + } + BidirectionalSequenceLSTMOptionsBuilder &operator=(const BidirectionalSequenceLSTMOptionsBuilder &); + flatbuffers::Offset<BidirectionalSequenceLSTMOptions> Finish() { + const auto end = fbb_.EndTable(start_); + auto o = flatbuffers::Offset<BidirectionalSequenceLSTMOptions>(end); + return o; + } +}; + +inline flatbuffers::Offset<BidirectionalSequenceLSTMOptions> CreateBidirectionalSequenceLSTMOptions( + flatbuffers::FlatBufferBuilder &_fbb, + ActivationFunctionType fused_activation_function = ActivationFunctionType_NONE, + float cell_clip = 0.0f, + float proj_clip = 0.0f, + bool merge_outputs = false) { + BidirectionalSequenceLSTMOptionsBuilder builder_(_fbb); + builder_.add_proj_clip(proj_clip); + builder_.add_cell_clip(cell_clip); + builder_.add_merge_outputs(merge_outputs); + builder_.add_fused_activation_function(fused_activation_function); + return builder_.Finish(); +} + +flatbuffers::Offset<BidirectionalSequenceLSTMOptions> CreateBidirectionalSequenceLSTMOptions(flatbuffers::FlatBufferBuilder &_fbb, const BidirectionalSequenceLSTMOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher = nullptr); + struct ResizeBilinearOptionsT : public flatbuffers::NativeTable { typedef ResizeBilinearOptions TableType; bool align_corners; @@ -6347,6 +6482,12 @@ struct Operator FLATBUFFERS_FINAL_CLASS : private flatbuffers::Table { const FillOptions *builtin_options_as_FillOptions() const { return builtin_options_type() == BuiltinOptions_FillOptions ? static_cast<const FillOptions *>(builtin_options()) : nullptr; } + const BidirectionalSequenceLSTMOptions *builtin_options_as_BidirectionalSequenceLSTMOptions() const { + return builtin_options_type() == BuiltinOptions_BidirectionalSequenceLSTMOptions ? static_cast<const BidirectionalSequenceLSTMOptions *>(builtin_options()) : nullptr; + } + const BidirectionalSequenceRNNOptions *builtin_options_as_BidirectionalSequenceRNNOptions() const { + return builtin_options_type() == BuiltinOptions_BidirectionalSequenceRNNOptions ? static_cast<const BidirectionalSequenceRNNOptions *>(builtin_options()) : nullptr; + } const flatbuffers::Vector<uint8_t> *custom_options() const { return GetPointer<const flatbuffers::Vector<uint8_t> *>(VT_CUSTOM_OPTIONS); } @@ -6650,6 +6791,14 @@ template<> inline const FillOptions *Operator::builtin_options_as<FillOptions>() return builtin_options_as_FillOptions(); } +template<> inline const BidirectionalSequenceLSTMOptions *Operator::builtin_options_as<BidirectionalSequenceLSTMOptions>() const { + return builtin_options_as_BidirectionalSequenceLSTMOptions(); +} + +template<> inline const BidirectionalSequenceRNNOptions *Operator::builtin_options_as<BidirectionalSequenceRNNOptions>() const { + return builtin_options_as_BidirectionalSequenceRNNOptions(); +} + struct OperatorBuilder { flatbuffers::FlatBufferBuilder &fbb_; flatbuffers::uoffset_t start_; @@ -7407,6 +7556,7 @@ inline void BidirectionalSequenceRNNOptions::UnPackTo(BidirectionalSequenceRNNOp (void)_resolver; { auto _e = time_major(); _o->time_major = _e; }; { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }; + { auto _e = merge_outputs(); _o->merge_outputs = _e; }; } inline flatbuffers::Offset<BidirectionalSequenceRNNOptions> BidirectionalSequenceRNNOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const BidirectionalSequenceRNNOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { @@ -7419,10 +7569,12 @@ inline flatbuffers::Offset<BidirectionalSequenceRNNOptions> CreateBidirectionalS struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const BidirectionalSequenceRNNOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; auto _time_major = _o->time_major; auto _fused_activation_function = _o->fused_activation_function; + auto _merge_outputs = _o->merge_outputs; return tflite::CreateBidirectionalSequenceRNNOptions( _fbb, _time_major, - _fused_activation_function); + _fused_activation_function, + _merge_outputs); } inline FullyConnectedOptionsT *FullyConnectedOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { @@ -7657,6 +7809,41 @@ inline flatbuffers::Offset<LSTMOptions> CreateLSTMOptions(flatbuffers::FlatBuffe _kernel_type); } +inline BidirectionalSequenceLSTMOptionsT *BidirectionalSequenceLSTMOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { + auto _o = new BidirectionalSequenceLSTMOptionsT(); + UnPackTo(_o, _resolver); + return _o; +} + +inline void BidirectionalSequenceLSTMOptions::UnPackTo(BidirectionalSequenceLSTMOptionsT *_o, const flatbuffers::resolver_function_t *_resolver) const { + (void)_o; + (void)_resolver; + { auto _e = fused_activation_function(); _o->fused_activation_function = _e; }; + { auto _e = cell_clip(); _o->cell_clip = _e; }; + { auto _e = proj_clip(); _o->proj_clip = _e; }; + { auto _e = merge_outputs(); _o->merge_outputs = _e; }; +} + +inline flatbuffers::Offset<BidirectionalSequenceLSTMOptions> BidirectionalSequenceLSTMOptions::Pack(flatbuffers::FlatBufferBuilder &_fbb, const BidirectionalSequenceLSTMOptionsT* _o, const flatbuffers::rehasher_function_t *_rehasher) { + return CreateBidirectionalSequenceLSTMOptions(_fbb, _o, _rehasher); +} + +inline flatbuffers::Offset<BidirectionalSequenceLSTMOptions> CreateBidirectionalSequenceLSTMOptions(flatbuffers::FlatBufferBuilder &_fbb, const BidirectionalSequenceLSTMOptionsT *_o, const flatbuffers::rehasher_function_t *_rehasher) { + (void)_rehasher; + (void)_o; + struct _VectorArgs { flatbuffers::FlatBufferBuilder *__fbb; const BidirectionalSequenceLSTMOptionsT* __o; const flatbuffers::rehasher_function_t *__rehasher; } _va = { &_fbb, _o, _rehasher}; (void)_va; + auto _fused_activation_function = _o->fused_activation_function; + auto _cell_clip = _o->cell_clip; + auto _proj_clip = _o->proj_clip; + auto _merge_outputs = _o->merge_outputs; + return tflite::CreateBidirectionalSequenceLSTMOptions( + _fbb, + _fused_activation_function, + _cell_clip, + _proj_clip, + _merge_outputs); +} + inline ResizeBilinearOptionsT *ResizeBilinearOptions::UnPack(const flatbuffers::resolver_function_t *_resolver) const { auto _o = new ResizeBilinearOptionsT(); UnPackTo(_o, _resolver); @@ -9425,6 +9612,14 @@ inline bool VerifyBuiltinOptions(flatbuffers::Verifier &verifier, const void *ob auto ptr = reinterpret_cast<const FillOptions *>(obj); return verifier.VerifyTable(ptr); } + case BuiltinOptions_BidirectionalSequenceLSTMOptions: { + auto ptr = reinterpret_cast<const BidirectionalSequenceLSTMOptions *>(obj); + return verifier.VerifyTable(ptr); + } + case BuiltinOptions_BidirectionalSequenceRNNOptions: { + auto ptr = reinterpret_cast<const BidirectionalSequenceRNNOptions *>(obj); + return verifier.VerifyTable(ptr); + } default: return false; } } @@ -9715,6 +9910,14 @@ inline void *BuiltinOptionsUnion::UnPack(const void *obj, BuiltinOptions type, c auto ptr = reinterpret_cast<const FillOptions *>(obj); return ptr->UnPack(resolver); } + case BuiltinOptions_BidirectionalSequenceLSTMOptions: { + auto ptr = reinterpret_cast<const BidirectionalSequenceLSTMOptions *>(obj); + return ptr->UnPack(resolver); + } + case BuiltinOptions_BidirectionalSequenceRNNOptions: { + auto ptr = reinterpret_cast<const BidirectionalSequenceRNNOptions *>(obj); + return ptr->UnPack(resolver); + } default: return nullptr; } } @@ -9993,6 +10196,14 @@ inline flatbuffers::Offset<void> BuiltinOptionsUnion::Pack(flatbuffers::FlatBuff auto ptr = reinterpret_cast<const FillOptionsT *>(value); return CreateFillOptions(_fbb, ptr, _rehasher).Union(); } + case BuiltinOptions_BidirectionalSequenceLSTMOptions: { + auto ptr = reinterpret_cast<const BidirectionalSequenceLSTMOptionsT *>(value); + return CreateBidirectionalSequenceLSTMOptions(_fbb, ptr, _rehasher).Union(); + } + case BuiltinOptions_BidirectionalSequenceRNNOptions: { + auto ptr = reinterpret_cast<const BidirectionalSequenceRNNOptionsT *>(value); + return CreateBidirectionalSequenceRNNOptions(_fbb, ptr, _rehasher).Union(); + } default: return 0; } } @@ -10271,6 +10482,14 @@ inline BuiltinOptionsUnion::BuiltinOptionsUnion(const BuiltinOptionsUnion &u) FL value = new FillOptionsT(*reinterpret_cast<FillOptionsT *>(u.value)); break; } + case BuiltinOptions_BidirectionalSequenceLSTMOptions: { + value = new BidirectionalSequenceLSTMOptionsT(*reinterpret_cast<BidirectionalSequenceLSTMOptionsT *>(u.value)); + break; + } + case BuiltinOptions_BidirectionalSequenceRNNOptions: { + value = new BidirectionalSequenceRNNOptionsT(*reinterpret_cast<BidirectionalSequenceRNNOptionsT *>(u.value)); + break; + } default: break; } @@ -10618,6 +10837,16 @@ inline void BuiltinOptionsUnion::Reset() { delete ptr; break; } + case BuiltinOptions_BidirectionalSequenceLSTMOptions: { + auto ptr = reinterpret_cast<BidirectionalSequenceLSTMOptionsT *>(value); + delete ptr; + break; + } + case BuiltinOptions_BidirectionalSequenceRNNOptions: { + auto ptr = reinterpret_cast<BidirectionalSequenceRNNOptionsT *>(value); + delete ptr; + break; + } default: break; } value = nullptr; |