From b145f46b735fe1e383be6629cafaa5269b07b7fb Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower" <gardener@tensorflow.org>
Date: Tue, 9 Oct 2018 14:12:25 -0700
Subject: Add support for time-major input in the bidirectional RNN Op.

PiperOrigin-RevId: 216419983
---
 .../lite/kernels/bidirectional_sequence_rnn.cc     | 249 +++++++++++++++------
 .../kernels/bidirectional_sequence_rnn_test.cc     |  93 ++++++--
 2 files changed, 244 insertions(+), 98 deletions(-)

diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
index c22a457a71..f544dd5ffa 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
@@ -114,8 +114,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
   TF_LITE_ENSURE_EQ(context, input->dims->size, 3);
 
-  const int batch_size = input->dims->data[0];
-  const int max_time = input->dims->data[1];
+  const bool time_major = params->time_major;
+  const int batch_size =
+      (time_major) ? input->dims->data[1] : input->dims->data[0];
+  const int max_time =
+      (time_major) ? input->dims->data[0] : input->dims->data[1];
   const int fw_num_units = fw_input_weights->dims->data[0];
   const int bw_num_units = bw_input_weights->dims->data[0];
   TF_LITE_ASSERT_EQ(input->dims->data[2], fw_input_weights->dims->data[1]);
@@ -237,8 +240,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
   // Resize outputs.
   TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
   TfLiteIntArray* fw_output_size_array = TfLiteIntArrayCreate(3);
-  fw_output_size_array->data[0] = batch_size;
-  fw_output_size_array->data[1] = max_time;
+  fw_output_size_array->data[0] = (time_major) ? max_time : batch_size;
+  fw_output_size_array->data[1] = (time_major) ? batch_size : max_time;
   fw_output_size_array->data[2] =
       params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units;
   TF_LITE_ENSURE_OK(
@@ -266,8 +269,11 @@ TfLiteStatus EvalFloat(
     const TfLiteBidirectionalSequenceRNNParams* params,
     TfLiteTensor* fw_hidden_state, TfLiteTensor* fw_output,
     TfLiteTensor* bw_hidden_state, TfLiteTensor* bw_output) {
-  const int batch_size = input->dims->data[0];
-  const int max_time = input->dims->data[1];
+  const bool time_major = params->time_major;
+  const int batch_size =
+      (time_major) ? input->dims->data[1] : input->dims->data[0];
+  const int max_time =
+      (time_major) ? input->dims->data[0] : input->dims->data[1];
   const int input_size = input->dims->data[2];
   const int aux_input_size = (aux_input) ? aux_input->dims->data[2] : 0;
 
@@ -292,48 +298,91 @@ TfLiteStatus EvalFloat(
       params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units;
   const int bw_output_step =
       params->merge_outputs ? fw_num_units + bw_num_units : bw_num_units;
-  for (int b = 0; b < batch_size; b++) {
+  if (time_major) {
+    // TODO(mirkov): add merge_outputs support for time_major inputs.
+    TF_LITE_ASSERT_EQ(params->merge_outputs, false);
+
     // Forward cell.
-    float* fw_hidden_state_ptr_batch =
-        fw_hidden_state->data.f + b * fw_num_units;
-    float* fw_output_offset = fw_output->data.f + b * fw_output_step * max_time;
+    float* fw_hidden_state_ptr_batch = fw_hidden_state->data.f;
     for (int s = 0; s < max_time; s++) {
       const float* input_ptr_batch =
-          input->data.f + b * input_size * max_time + s * input_size;
+          input->data.f + s * input_size * batch_size;
       const float* aux_input_ptr_batch =
           (aux_input != nullptr)
-              ? aux_input->data.f + b * input_size * max_time + s * input_size
+              ? aux_input->data.f + s * input_size * batch_size
              : nullptr;
-      float* output_ptr_batch = fw_output_offset + s * fw_output_step;
+      float* output_ptr_batch =
+          fw_output->data.f + s * fw_num_units * batch_size;
 
       kernel_utils::RnnBatchStep(
           input_ptr_batch, fw_input_weights_ptr, aux_input_ptr_batch,
           fw_aux_input_weights_ptr, fw_recurrent_weights_ptr, fw_bias_ptr,
-          input_size, aux_input_size, fw_num_units, /*batch_size=*/1,
+          input_size, aux_input_size, fw_num_units, batch_size,
           params->activation, fw_hidden_state_ptr_batch, output_ptr_batch);
     }
     // Backward cell.
-    float* bw_hidden_state_ptr_batch =
-        bw_hidden_state->data.f + b * bw_num_units;
-    float* bw_output_offset =
-        params->merge_outputs
-            ? fw_output->data.f + b * bw_output_step * max_time + fw_num_units
-            : bw_output->data.f + b * bw_output_step * max_time;
+    float* bw_hidden_state_ptr_batch = bw_hidden_state->data.f;
     for (int s = max_time - 1; s >= 0; s--) {
       const float* input_ptr_batch =
-          input->data.f + b * input_size * max_time + s * input_size;
+          input->data.f + s * input_size * batch_size;
       const float* aux_input_ptr_batch =
           (aux_input != nullptr)
-              ? aux_input->data.f + b * input_size * max_time + s * input_size
+              ? aux_input->data.f + s * input_size * batch_size
              : nullptr;
-      float* output_ptr_batch = bw_output_offset + s * bw_output_step;
+      float* output_ptr_batch =
+          bw_output->data.f + s * bw_num_units * batch_size;
 
       kernel_utils::RnnBatchStep(
           input_ptr_batch, bw_input_weights_ptr, aux_input_ptr_batch,
           bw_aux_input_weights_ptr, bw_recurrent_weights_ptr, bw_bias_ptr,
-          input_size, aux_input_size, bw_num_units, /*batch_size=*/1,
+          input_size, aux_input_size, bw_num_units, batch_size,
           params->activation, bw_hidden_state_ptr_batch, output_ptr_batch);
     }
+  } else {
+    for (int b = 0; b < batch_size; b++) {
+      // Forward cell.
+      float* fw_hidden_state_ptr_batch =
+          fw_hidden_state->data.f + b * fw_num_units;
+      float* fw_output_offset =
+          fw_output->data.f + b * fw_output_step * max_time;
+      for (int s = 0; s < max_time; s++) {
+        const float* input_ptr_batch =
+            input->data.f + b * input_size * max_time + s * input_size;
+        const float* aux_input_ptr_batch =
+            (aux_input != nullptr)
+                ? aux_input->data.f + b * input_size * max_time + s * input_size
+                : nullptr;
+        float* output_ptr_batch = fw_output_offset + s * fw_output_step;
+
+        kernel_utils::RnnBatchStep(
+            input_ptr_batch, fw_input_weights_ptr, aux_input_ptr_batch,
+            fw_aux_input_weights_ptr, fw_recurrent_weights_ptr, fw_bias_ptr,
+            input_size, aux_input_size, fw_num_units, /*batch_size=*/1,
+            params->activation, fw_hidden_state_ptr_batch, output_ptr_batch);
+      }
+      // Backward cell.
+      float* bw_hidden_state_ptr_batch =
+          bw_hidden_state->data.f + b * bw_num_units;
+      float* bw_output_offset =
+          params->merge_outputs
+              ? fw_output->data.f + b * bw_output_step * max_time + fw_num_units
+              : bw_output->data.f + b * bw_output_step * max_time;
+      for (int s = max_time - 1; s >= 0; s--) {
+        const float* input_ptr_batch =
+            input->data.f + b * input_size * max_time + s * input_size;
+        const float* aux_input_ptr_batch =
+            (aux_input != nullptr)
+                ? aux_input->data.f + b * input_size * max_time + s * input_size
+                : nullptr;
+        float* output_ptr_batch = bw_output_offset + s * bw_output_step;
+
+        kernel_utils::RnnBatchStep(
+            input_ptr_batch, bw_input_weights_ptr, aux_input_ptr_batch,
+            bw_aux_input_weights_ptr, bw_recurrent_weights_ptr, bw_bias_ptr,
+            input_size, aux_input_size, bw_num_units, /*batch_size=*/1,
+            params->activation, bw_hidden_state_ptr_batch, output_ptr_batch);
+      }
+    }
   }
   return kTfLiteOk;
 }
@@ -351,8 +400,11 @@ TfLiteStatus EvalHybrid(
     TfLiteTensor* fw_hidden_state, TfLiteTensor* fw_output,
     TfLiteTensor* bw_hidden_state_quantized, TfLiteTensor* bw_hidden_state,
     TfLiteTensor* bw_output) {
-  const int batch_size = input->dims->data[0];
-  const int max_time = input->dims->data[1];
+  const bool time_major = params->time_major;
+  const int batch_size =
+      (time_major) ? input->dims->data[1] : input->dims->data[0];
+  const int max_time =
+      (time_major) ? input->dims->data[0] : input->dims->data[1];
   const int input_size = input->dims->data[2];
   const int aux_input_size = (aux_input) ? aux_input->dims->data[2] : 0;
 
@@ -403,55 +455,104 @@ TfLiteStatus EvalHybrid(
       params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units;
   const int bw_output_step =
       params->merge_outputs ? fw_num_units + bw_num_units : bw_num_units;
-  for (int b = 0; b < batch_size; b++) {
-    // Forward cell.
-    float* fw_hidden_state_ptr_batch =
-        fw_hidden_state->data.f + b * fw_num_units;
-    float* fw_output_offset = fw_output->data.f + b * fw_output_step * max_time;
-    for (int s = 0; s < max_time; s++) {
-      const float* input_ptr_batch =
-          input->data.f + b * input_size * max_time + s * input_size;
-      const float* aux_input_ptr_batch =
-          (aux_input != nullptr)
-              ? aux_input->data.f + b * input_size * max_time + s * input_size
-              : nullptr;
-      float* output_ptr_batch = fw_output_offset + s * fw_output_step;
-
-      kernel_utils::RnnBatchStep(
-          input_ptr_batch, fw_input_weights_ptr, fw_input_weights_scale,
-          aux_input_ptr_batch, aux_fw_input_weights_ptr,
-          aux_fw_input_weights_scale, fw_recurrent_weights_ptr,
-          fw_recurrent_weights_scale, fw_bias_ptr, input_size, aux_input_size,
-          fw_num_units, /*batch_size=*/1, params->activation,
-          quantized_input_ptr, aux_quantized_input_ptr,
-          fw_quantized_hidden_state_ptr, scaling_factors_ptr,
-          fw_hidden_state_ptr_batch, output_ptr_batch);
+  if (time_major) {
+    // TODO(mirkov): add merge_outputs support for time_major inputs.
+    TF_LITE_ASSERT_EQ(params->merge_outputs, false);
+
+    // Forward cell.
+    float* fw_hidden_state_ptr_batch = fw_hidden_state->data.f;
+    for (int s = 0; s < max_time; s++) {
+      const float* input_ptr_batch =
+          input->data.f + s * input_size * batch_size;
+      const float* aux_input_ptr_batch =
+          (aux_input != nullptr)
+              ? aux_input->data.f + s * input_size * batch_size
+              : nullptr;
+      float* output_ptr_batch =
+          fw_output->data.f + s * fw_num_units * batch_size;
+
+      kernel_utils::RnnBatchStep(
+          input_ptr_batch, fw_input_weights_ptr, fw_input_weights_scale,
+          aux_input_ptr_batch, aux_fw_input_weights_ptr,
+          aux_fw_input_weights_scale, fw_recurrent_weights_ptr,
+          fw_recurrent_weights_scale, fw_bias_ptr, input_size, aux_input_size,
+          fw_num_units, batch_size, params->activation, quantized_input_ptr,
+          aux_quantized_input_ptr, fw_quantized_hidden_state_ptr,
+          scaling_factors_ptr, fw_hidden_state_ptr_batch, output_ptr_batch);
+    }
+    // Backward cell.
+    float* bw_hidden_state_ptr_batch = bw_hidden_state->data.f;
+    for (int s = max_time - 1; s >= 0; s--) {
+      const float* input_ptr_batch =
+          input->data.f + s * input_size * batch_size;
+      const float* aux_input_ptr_batch =
+          (aux_input != nullptr)
+              ? aux_input->data.f + s * input_size * batch_size
+              : nullptr;
+      float* output_ptr_batch =
+          bw_output->data.f + s * bw_num_units * batch_size;
+
+      kernel_utils::RnnBatchStep(
+          input_ptr_batch, bw_input_weights_ptr, bw_input_weights_scale,
+          aux_input_ptr_batch, aux_bw_input_weights_ptr,
+          aux_bw_input_weights_scale, bw_recurrent_weights_ptr,
+          bw_recurrent_weights_scale, bw_bias_ptr, input_size, aux_input_size,
+          bw_num_units, batch_size, params->activation, quantized_input_ptr,
+          aux_quantized_input_ptr, bw_quantized_hidden_state_ptr,
+          scaling_factors_ptr, bw_hidden_state_ptr_batch, output_ptr_batch);
     }
-    // Backward cell.
-    float* bw_hidden_state_ptr_batch =
-        bw_hidden_state->data.f + b * bw_num_units;
-    float* bw_output_offset =
-        params->merge_outputs
-            ? fw_output->data.f + b * bw_output_step * max_time
-            : bw_output->data.f + b * bw_output_step * max_time;
-    for (int s = max_time - 1; s >= 0; s--) {
-      const float* input_ptr_batch =
-          input->data.f + b * input_size * max_time + s * input_size;
-      const float* aux_input_ptr_batch =
-          (aux_input != nullptr)
-              ? aux_input->data.f + b * input_size * max_time + s * input_size
-              : nullptr;
-      float* output_ptr_batch = bw_output_offset + s * bw_output_step;
-
-      kernel_utils::RnnBatchStep(
-          input_ptr_batch, bw_input_weights_ptr, bw_input_weights_scale,
-          aux_input_ptr_batch, aux_bw_input_weights_ptr,
-          aux_bw_input_weights_scale, bw_recurrent_weights_ptr,
-          bw_recurrent_weights_scale, bw_bias_ptr, input_size, aux_input_size,
-          bw_num_units, /*batch_size=*/1, params->activation,
-          quantized_input_ptr, aux_quantized_input_ptr,
-          bw_quantized_hidden_state_ptr, scaling_factors_ptr,
-          bw_hidden_state_ptr_batch, output_ptr_batch);
+  } else {
+    for (int b = 0; b < batch_size; b++) {
+      // Forward cell.
+      float* fw_hidden_state_ptr_batch =
+          fw_hidden_state->data.f + b * fw_num_units;
+      float* fw_output_offset =
+          fw_output->data.f + b * fw_output_step * max_time;
+      for (int s = 0; s < max_time; s++) {
+        const float* input_ptr_batch =
+            input->data.f + b * input_size * max_time + s * input_size;
+        const float* aux_input_ptr_batch =
+            (aux_input != nullptr)
+                ? aux_input->data.f + b * input_size * max_time + s * input_size
+                : nullptr;
+        float* output_ptr_batch = fw_output_offset + s * fw_output_step;
+
+        kernel_utils::RnnBatchStep(
+            input_ptr_batch, fw_input_weights_ptr, fw_input_weights_scale,
+            aux_input_ptr_batch, aux_fw_input_weights_ptr,
+            aux_fw_input_weights_scale, fw_recurrent_weights_ptr,
+            fw_recurrent_weights_scale, fw_bias_ptr, input_size, aux_input_size,
+            fw_num_units, /*batch_size=*/1, params->activation,
+            quantized_input_ptr, aux_quantized_input_ptr,
+            fw_quantized_hidden_state_ptr, scaling_factors_ptr,
+            fw_hidden_state_ptr_batch, output_ptr_batch);
+      }
+      // Backward cell.
+      float* bw_hidden_state_ptr_batch =
+          bw_hidden_state->data.f + b * bw_num_units;
+      float* bw_output_offset =
+          params->merge_outputs
+              ? fw_output->data.f + b * bw_output_step * max_time + fw_num_units
+              : bw_output->data.f + b * bw_output_step * max_time;
+      for (int s = max_time - 1; s >= 0; s--) {
+        const float* input_ptr_batch =
+            input->data.f + b * input_size * max_time + s * input_size;
+        const float* aux_input_ptr_batch =
+            (aux_input != nullptr)
+                ? aux_input->data.f + b * input_size * max_time + s * input_size
+                : nullptr;
+        float* output_ptr_batch = bw_output_offset + s * bw_output_step;
+
+        kernel_utils::RnnBatchStep(
+            input_ptr_batch, bw_input_weights_ptr, bw_input_weights_scale,
+            aux_input_ptr_batch, aux_bw_input_weights_ptr,
+            aux_bw_input_weights_scale, bw_recurrent_weights_ptr,
+            bw_recurrent_weights_scale, bw_bias_ptr, input_size, aux_input_size,
+            bw_num_units, /*batch_size=*/1, params->activation,
+            quantized_input_ptr, aux_quantized_input_ptr,
+            bw_quantized_hidden_state_ptr, scaling_factors_ptr,
+            bw_hidden_state_ptr_batch, output_ptr_batch);
+      }
     }
   }
   return kTfLiteOk;
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc
index f555c472f5..6c179ca05d 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc
@@ -654,7 +654,8 @@ const std::initializer_list<float> recurrent_weights = {
 class BidirectionalRNNOpModel : public SingleOpModel {
  public:
   BidirectionalRNNOpModel(int batches, int sequence_len, int fw_units,
-                          int bw_units, int input_size, bool merge_outputs)
+                          int bw_units, int input_size, bool time_major,
+                          bool merge_outputs)
       : batches_(batches),
         sequence_len_(sequence_len),
         fw_units_(fw_units),
@@ -679,25 +680,29 @@ class BidirectionalRNNOpModel : public SingleOpModel {
       bw_output_ = AddOutput(TensorType_FLOAT32);
     }
 
-    SetBuiltinOp(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN,
-                 BuiltinOptions_BidirectionalSequenceRNNOptions,
-                 CreateBidirectionalSequenceRNNOptions(
-                     builder_, /*time_major=*/false,
-                     ActivationFunctionType_RELU, merge_outputs)
-                     .Union());
+    SetBuiltinOp(
+        BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN,
+        BuiltinOptions_BidirectionalSequenceRNNOptions,
+        CreateBidirectionalSequenceRNNOptions(
+            builder_, time_major, ActivationFunctionType_RELU, merge_outputs)
+            .Union());
+    const auto input_shape =
+        (time_major) ? std::vector<int>({sequence_len_, batches_, input_size_})
+                     : std::vector<int>({batches_, sequence_len_, input_size_});
+
     BuildInterpreter({
-        {batches_, sequence_len_, input_size_},  // input
-        {fw_units_, input_size_},                // fw_weights
-        {fw_units_, fw_units_},                  // fw_recurrent_weights
-        {fw_units_},                             // fw_bias
-        {batches_, fw_units_},                   // fw_hidden_state
-        {bw_units_, input_size_},                // bw_weights
-        {bw_units_, bw_units_},                  // bw_recurrent_weights
-        {bw_units_},                             // bw_bias
-        {batches_, bw_units_},                   // bw_hidden_state
-        {batches_, sequence_len_, 0},            // aux_input
-        {fw_units_, 0},                          // aux_fw_weights
-        {bw_units_, 0},                          // aux_bw_weights
+        input_shape,                   // input
+        {fw_units_, input_size_},      // fw_weights
+        {fw_units_, fw_units_},        // fw_recurrent_weights
+        {fw_units_},                   // fw_bias
+        {batches_, fw_units_},         // fw_hidden_state
+        {bw_units_, input_size_},      // bw_weights
+        {bw_units_, bw_units_},        // bw_recurrent_weights
+        {bw_units_},                   // bw_bias
+        {batches_, bw_units_},         // bw_hidden_state
+        {batches_, sequence_len_, 0},  // aux_input
+        {fw_units_, 0},                // aux_fw_weights
+        {bw_units_, 0},                // aux_bw_weights
     });
   }
 
@@ -770,7 +775,8 @@ class BidirectionalRNNOpModel : public SingleOpModel {
 TEST(BidirectionalRNNOpTest, BlackBoxTest) {
   BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
                               /*fw_units=*/16, /*bw_units=*/16,
-                              /*input_size=*/8, /*merge_outputs=*/false);
+                              /*input_size=*/8, /*time_major=*/false,
+                              /*merge_outputs=*/false);
   rnn.SetFwWeights(weights);
   rnn.SetBwWeights(weights);
   rnn.SetFwBias(biases);
@@ -803,11 +809,48 @@ TEST(BidirectionalRNNOpTest, BlackBoxTest) {
   EXPECT_THAT(rnn.GetBwOutput(), ElementsAreArray(ArrayFloatNear(bw_expected)));
 }
 
-// Same as the previous test, yet with merged outputs.
+// Same as BlackBox test, but input is reshuffled to time_major format.
+TEST(BidirectionalRNNOpTest, BlackBoxTestTimeMajor) {
+  BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
+                              /*fw_units=*/16, /*bw_units=*/16,
+                              /*input_size=*/8, /*time_major=*/true,
+                              /*merge_outputs=*/false);
+  rnn.SetFwWeights(weights);
+  rnn.SetBwWeights(weights);
+  rnn.SetFwBias(biases);
+  rnn.SetBwBias(biases);
+  rnn.SetFwRecurrentWeights(recurrent_weights);
+  rnn.SetBwRecurrentWeights(recurrent_weights);
+
+  // Insert the inputs in time_major format. The batch_major format is:
+  // [b0t0, b0t1, ..., b0t15, b1t0, b1t1, ..., b1t15]. This is reshuffled as:
+  // [b0t0, b1t0, b0t1, b1t1, ..., b0t15, b1t15].
+  for (int i = 0; i < rnn.sequence_len(); i++) {
+    float* batch_start = rnn_input + i * rnn.input_size();
+    float* batch_end = batch_start + rnn.input_size();
+    // The two batches are identical.
+    rnn.SetInput(2 * i * rnn.input_size(), batch_start, batch_end);
+    rnn.SetInput((2 * i + 1) * rnn.input_size(), batch_start, batch_end);
+  }
+
+  rnn.Invoke();
+
+  std::vector<float> fw_expected;
+  for (int i = 0; i < rnn.sequence_len(); i++) {
+    float* golden_fw_start = rnn_golden_fw_output + i * rnn.num_fw_units();
+    float* golden_fw_end = golden_fw_start + rnn.num_fw_units();
+    fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
+    fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
+  }
+  EXPECT_THAT(rnn.GetFwOutput(), ElementsAreArray(ArrayFloatNear(fw_expected)));
+}
+
+// Same as BlackBox test, yet with merged outputs.
 TEST(BidirectionalRNNOpTest, BlackBoxTestMergeOutputs) {
   BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
                               /*fw_units=*/16, /*bw_units=*/16,
-                              /*input_size=*/8, /*merge_outputs=*/true);
+                              /*input_size=*/8, /*time_major=*/false,
+                              /*merge_outputs=*/true);
   rnn.SetFwWeights(weights);
   rnn.SetBwWeights(weights);
   rnn.SetFwBias(biases);
@@ -845,7 +888,8 @@ TEST(BidirectionalRNNOpTest, BlackBoxTestMergeOutputs) {
 TEST(BidirectionalRNNOpTest, BlackBoxTestReverseInputs) {
   BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
                               /*fw_units=*/16, /*bw_units=*/16,
-                              /*input_size=*/8, /*merge_outputs=*/false);
+                              /*input_size=*/8, /*time_major=*/false,
+                              /*merge_outputs=*/false);
   rnn.SetFwWeights(weights);
   rnn.SetBwWeights(weights);
   rnn.SetFwBias(biases);
@@ -891,7 +935,8 @@ TEST(BidirectionalRNNOpTest, BlackBoxTestReverseInputs) {
 TEST(BidirectionalRNNOpTest, EndToEndTest) {
   BidirectionalRNNOpModel rnn(/*batches=*/1, /*sequence_len=*/4,
                               /*fw_units=*/16, /*bw_units=*/16,
-                              /*input_size=*/8, /*merge_outputs=*/false);
+                              /*input_size=*/8, /*time_major=*/false,
+                              /*merge_outputs=*/false);
   const int output_size = 4;
   float dnn_weights[] = {
       -0.5782342,  -0.052212059, 0.73036242,  -0.81216097, -0.80088139,
-- 
cgit v1.2.3
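
Note on the layout change (reviewer illustration; the helper functions below are hypothetical and not part of the patch): with batch-major input the tensor is [batch_size, max_time, input_size], so the slice for batch b at step s starts at b * input_size * max_time + s * input_size, and RnnBatchStep is invoked once per batch with batch_size 1. With time-major input the tensor is [max_time, batch_size, input_size], so all batches of step s are contiguous starting at s * input_size * batch_size, and a single RnnBatchStep call can process the whole batch at once. A minimal C++ sketch of the two offset computations, mirroring the arithmetic in EvalFloat above:

    // Batch-major layout [batch, time, input]: one (batch, step) slice,
    // input_size floats long.
    inline const float* BatchMajorStep(const float* data, int b, int s,
                                       int max_time, int input_size) {
      return data + b * input_size * max_time + s * input_size;
    }

    // Time-major layout [time, batch, input]: the returned slice covers all
    // batches at step s, batch_size * input_size floats long, which is why
    // the kernel can pass the real batch_size to RnnBatchStep and drop the
    // per-batch loop.
    inline const float* TimeMajorStep(const float* data, int s, int batch_size,
                                      int input_size) {
      return data + s * input_size * batch_size;
    }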