aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-10-09 14:12:25 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-09 14:23:18 -0700
commitb145f46b735fe1e383be6629cafaa5269b07b7fb (patch)
tree4dc2b50afcc0af01a1ea9d9502672522ca04aed1
parent4fa59ef694c19dc63d574b2d6a349cd753d9cdbd (diff)
Add support for time-major input in the bidirectional RNN Op.
PiperOrigin-RevId: 216419983
-rw-r--r--tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc251
-rw-r--r--tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc94
2 files changed, 247 insertions, 98 deletions
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
index c22a457a71..f544dd5ffa 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn.cc
@@ -114,8 +114,11 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
TF_LITE_ENSURE_EQ(context, input->type, kTfLiteFloat32);
TF_LITE_ENSURE_EQ(context, input->dims->size, 3);
- const int batch_size = input->dims->data[0];
- const int max_time = input->dims->data[1];
+ const bool time_major = params->time_major;
+ const int batch_size =
+ (time_major) ? input->dims->data[1] : input->dims->data[0];
+ const int max_time =
+ (time_major) ? input->dims->data[0] : input->dims->data[1];
const int fw_num_units = fw_input_weights->dims->data[0];
const int bw_num_units = bw_input_weights->dims->data[0];
TF_LITE_ASSERT_EQ(input->dims->data[2], fw_input_weights->dims->data[1]);
@@ -237,8 +240,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
// Resize outputs.
TfLiteTensor* fw_output = GetOutput(context, node, kFwOutputTensor);
TfLiteIntArray* fw_output_size_array = TfLiteIntArrayCreate(3);
- fw_output_size_array->data[0] = batch_size;
- fw_output_size_array->data[1] = max_time;
+ fw_output_size_array->data[0] = (time_major) ? max_time : batch_size;
+ fw_output_size_array->data[1] = (time_major) ? batch_size : max_time;
fw_output_size_array->data[2] =
params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units;
TF_LITE_ENSURE_OK(
@@ -266,8 +269,11 @@ TfLiteStatus EvalFloat(
const TfLiteBidirectionalSequenceRNNParams* params,
TfLiteTensor* fw_hidden_state, TfLiteTensor* fw_output,
TfLiteTensor* bw_hidden_state, TfLiteTensor* bw_output) {
- const int batch_size = input->dims->data[0];
- const int max_time = input->dims->data[1];
+ const bool time_major = params->time_major;
+ const int batch_size =
+ (time_major) ? input->dims->data[1] : input->dims->data[0];
+ const int max_time =
+ (time_major) ? input->dims->data[0] : input->dims->data[1];
const int input_size = input->dims->data[2];
const int aux_input_size = (aux_input) ? aux_input->dims->data[2] : 0;
@@ -292,48 +298,91 @@ TfLiteStatus EvalFloat(
params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units;
const int bw_output_step =
params->merge_outputs ? fw_num_units + bw_num_units : bw_num_units;
- for (int b = 0; b < batch_size; b++) {
+ if (time_major) {
+ // TODO(mirkov): add merge_outputs support for time_major inputs.
+ TF_LITE_ASSERT_EQ(params->merge_outputs, false);
+
// Forward cell.
- float* fw_hidden_state_ptr_batch =
- fw_hidden_state->data.f + b * fw_num_units;
- float* fw_output_offset = fw_output->data.f + b * fw_output_step * max_time;
+ float* fw_hidden_state_ptr_batch = fw_hidden_state->data.f;
for (int s = 0; s < max_time; s++) {
const float* input_ptr_batch =
- input->data.f + b * input_size * max_time + s * input_size;
+ input->data.f + s * input_size * batch_size;
const float* aux_input_ptr_batch =
(aux_input != nullptr)
- ? aux_input->data.f + b * input_size * max_time + s * input_size
+ ? aux_input->data.f + s * input_size * batch_size
: nullptr;
- float* output_ptr_batch = fw_output_offset + s * fw_output_step;
+ float* output_ptr_batch =
+ fw_output->data.f + s * fw_num_units * batch_size;
kernel_utils::RnnBatchStep(
input_ptr_batch, fw_input_weights_ptr, aux_input_ptr_batch,
fw_aux_input_weights_ptr, fw_recurrent_weights_ptr, fw_bias_ptr,
- input_size, aux_input_size, fw_num_units, /*batch_size=*/1,
+ input_size, aux_input_size, fw_num_units, batch_size,
params->activation, fw_hidden_state_ptr_batch, output_ptr_batch);
}
// Backward cell.
- float* bw_hidden_state_ptr_batch =
- bw_hidden_state->data.f + b * bw_num_units;
- float* bw_output_offset =
- params->merge_outputs
- ? fw_output->data.f + b * bw_output_step * max_time + fw_num_units
- : bw_output->data.f + b * bw_output_step * max_time;
+ float* bw_hidden_state_ptr_batch = bw_hidden_state->data.f;
for (int s = max_time - 1; s >= 0; s--) {
const float* input_ptr_batch =
- input->data.f + b * input_size * max_time + s * input_size;
+ input->data.f + s * input_size * batch_size;
const float* aux_input_ptr_batch =
(aux_input != nullptr)
- ? aux_input->data.f + b * input_size * max_time + s * input_size
+ ? aux_input->data.f + s * input_size * batch_size
: nullptr;
- float* output_ptr_batch = bw_output_offset + s * bw_output_step;
+ float* output_ptr_batch =
+ bw_output->data.f + s * bw_num_units * batch_size;
kernel_utils::RnnBatchStep(
input_ptr_batch, bw_input_weights_ptr, aux_input_ptr_batch,
bw_aux_input_weights_ptr, bw_recurrent_weights_ptr, bw_bias_ptr,
- input_size, aux_input_size, bw_num_units, /*batch_size=*/1,
+ input_size, aux_input_size, bw_num_units, batch_size,
params->activation, bw_hidden_state_ptr_batch, output_ptr_batch);
}
+ } else {
+ for (int b = 0; b < batch_size; b++) {
+ // Forward cell.
+ float* fw_hidden_state_ptr_batch =
+ fw_hidden_state->data.f + b * fw_num_units;
+ float* fw_output_offset =
+ fw_output->data.f + b * fw_output_step * max_time;
+ for (int s = 0; s < max_time; s++) {
+ const float* input_ptr_batch =
+ input->data.f + b * input_size * max_time + s * input_size;
+ const float* aux_input_ptr_batch =
+ (aux_input != nullptr)
+ ? aux_input->data.f + b * input_size * max_time + s * input_size
+ : nullptr;
+ float* output_ptr_batch = fw_output_offset + s * fw_output_step;
+
+ kernel_utils::RnnBatchStep(
+ input_ptr_batch, fw_input_weights_ptr, aux_input_ptr_batch,
+ fw_aux_input_weights_ptr, fw_recurrent_weights_ptr, fw_bias_ptr,
+ input_size, aux_input_size, fw_num_units, /*batch_size=*/1,
+ params->activation, fw_hidden_state_ptr_batch, output_ptr_batch);
+ }
+ // Backward cell.
+ float* bw_hidden_state_ptr_batch =
+ bw_hidden_state->data.f + b * bw_num_units;
+ float* bw_output_offset =
+ params->merge_outputs
+ ? fw_output->data.f + b * bw_output_step * max_time + fw_num_units
+ : bw_output->data.f + b * bw_output_step * max_time;
+ for (int s = max_time - 1; s >= 0; s--) {
+ const float* input_ptr_batch =
+ input->data.f + b * input_size * max_time + s * input_size;
+ const float* aux_input_ptr_batch =
+ (aux_input != nullptr)
+ ? aux_input->data.f + b * input_size * max_time + s * input_size
+ : nullptr;
+ float* output_ptr_batch = bw_output_offset + s * bw_output_step;
+
+ kernel_utils::RnnBatchStep(
+ input_ptr_batch, bw_input_weights_ptr, aux_input_ptr_batch,
+ bw_aux_input_weights_ptr, bw_recurrent_weights_ptr, bw_bias_ptr,
+ input_size, aux_input_size, bw_num_units, /*batch_size=*/1,
+ params->activation, bw_hidden_state_ptr_batch, output_ptr_batch);
+ }
+ }
}
return kTfLiteOk;
}
@@ -351,8 +400,11 @@ TfLiteStatus EvalHybrid(
TfLiteTensor* fw_hidden_state, TfLiteTensor* fw_output,
TfLiteTensor* bw_hidden_state_quantized, TfLiteTensor* bw_hidden_state,
TfLiteTensor* bw_output) {
- const int batch_size = input->dims->data[0];
- const int max_time = input->dims->data[1];
+ const bool time_major = params->time_major;
+ const int batch_size =
+ (time_major) ? input->dims->data[1] : input->dims->data[0];
+ const int max_time =
+ (time_major) ? input->dims->data[0] : input->dims->data[1];
const int input_size = input->dims->data[2];
const int aux_input_size = (aux_input) ? aux_input->dims->data[2] : 0;
@@ -403,55 +455,106 @@ TfLiteStatus EvalHybrid(
params->merge_outputs ? fw_num_units + bw_num_units : fw_num_units;
const int bw_output_step =
params->merge_outputs ? fw_num_units + bw_num_units : bw_num_units;
- for (int b = 0; b < batch_size; b++) {
- // Forward cell.
- float* fw_hidden_state_ptr_batch =
- fw_hidden_state->data.f + b * fw_num_units;
- float* fw_output_offset = fw_output->data.f + b * fw_output_step * max_time;
- for (int s = 0; s < max_time; s++) {
- const float* input_ptr_batch =
- input->data.f + b * input_size * max_time + s * input_size;
- const float* aux_input_ptr_batch =
- (aux_input != nullptr)
- ? aux_input->data.f + b * input_size * max_time + s * input_size
- : nullptr;
- float* output_ptr_batch = fw_output_offset + s * fw_output_step;
-
- kernel_utils::RnnBatchStep(
- input_ptr_batch, fw_input_weights_ptr, fw_input_weights_scale,
- aux_input_ptr_batch, aux_fw_input_weights_ptr,
- aux_fw_input_weights_scale, fw_recurrent_weights_ptr,
- fw_recurrent_weights_scale, fw_bias_ptr, input_size, aux_input_size,
- fw_num_units, /*batch_size=*/1, params->activation,
- quantized_input_ptr, aux_quantized_input_ptr,
- fw_quantized_hidden_state_ptr, scaling_factors_ptr,
- fw_hidden_state_ptr_batch, output_ptr_batch);
+ if (time_major) {
+ for (int t = 0; t < max_time; t++) {
+ // TODO(mirkov): add merge_outputs support for time_major inputs.
+ TF_LITE_ASSERT_EQ(params->merge_outputs, false);
+
+ // Forward cell.
+ float* fw_hidden_state_ptr_batch = fw_hidden_state->data.f;
+ for (int s = 0; s < max_time; s++) {
+ const float* input_ptr_batch =
+ input->data.f + s * input_size * batch_size;
+ const float* aux_input_ptr_batch =
+ (aux_input != nullptr)
+ ? aux_input->data.f + s * input_size * batch_size
+ : nullptr;
+ float* output_ptr_batch =
+ fw_output->data.f + s * fw_num_units * batch_size;
+
+ kernel_utils::RnnBatchStep(
+ input_ptr_batch, fw_input_weights_ptr, fw_input_weights_scale,
+ aux_input_ptr_batch, aux_fw_input_weights_ptr,
+ aux_fw_input_weights_scale, fw_recurrent_weights_ptr,
+ fw_recurrent_weights_scale, fw_bias_ptr, input_size, aux_input_size,
+ fw_num_units, batch_size, params->activation, quantized_input_ptr,
+ aux_quantized_input_ptr, fw_quantized_hidden_state_ptr,
+ scaling_factors_ptr, fw_hidden_state_ptr_batch, output_ptr_batch);
+ }
+ // Backward cell.
+ float* bw_hidden_state_ptr_batch = bw_hidden_state->data.f;
+ for (int s = max_time - 1; s >= 0; s--) {
+ const float* input_ptr_batch =
+ input->data.f + s * input_size * batch_size;
+ const float* aux_input_ptr_batch =
+ (aux_input != nullptr)
+ ? aux_input->data.f + s * input_size * batch_size
+ : nullptr;
+ float* output_ptr_batch =
+ bw_output->data.f + s * bw_num_units * batch_size;
+
+ kernel_utils::RnnBatchStep(
+ input_ptr_batch, bw_input_weights_ptr, bw_input_weights_scale,
+ aux_input_ptr_batch, aux_bw_input_weights_ptr,
+ aux_bw_input_weights_scale, bw_recurrent_weights_ptr,
+ bw_recurrent_weights_scale, bw_bias_ptr, input_size, aux_input_size,
+ bw_num_units, batch_size, params->activation, quantized_input_ptr,
+ aux_quantized_input_ptr, bw_quantized_hidden_state_ptr,
+ scaling_factors_ptr, bw_hidden_state_ptr_batch, output_ptr_batch);
+ }
}
- // Backward cell.
- float* bw_hidden_state_ptr_batch =
- bw_hidden_state->data.f + b * bw_num_units;
- float* bw_output_offset =
- params->merge_outputs
- ? fw_output->data.f + b * bw_output_step * max_time
- : bw_output->data.f + b * bw_output_step * max_time;
- for (int s = max_time - 1; s >= 0; s--) {
- const float* input_ptr_batch =
- input->data.f + b * input_size * max_time + s * input_size;
- const float* aux_input_ptr_batch =
- (aux_input != nullptr)
- ? aux_input->data.f + b * input_size * max_time + s * input_size
- : nullptr;
- float* output_ptr_batch = bw_output_offset + s * bw_output_step;
-
- kernel_utils::RnnBatchStep(
- input_ptr_batch, bw_input_weights_ptr, bw_input_weights_scale,
- aux_input_ptr_batch, aux_bw_input_weights_ptr,
- aux_bw_input_weights_scale, bw_recurrent_weights_ptr,
- bw_recurrent_weights_scale, bw_bias_ptr, input_size, aux_input_size,
- bw_num_units, /*batch_size=*/1, params->activation,
- quantized_input_ptr, aux_quantized_input_ptr,
- bw_quantized_hidden_state_ptr, scaling_factors_ptr,
- bw_hidden_state_ptr_batch, output_ptr_batch);
+ } else {
+ for (int b = 0; b < batch_size; b++) {
+ // Forward cell.
+ float* fw_hidden_state_ptr_batch =
+ fw_hidden_state->data.f + b * fw_num_units;
+ float* fw_output_offset =
+ fw_output->data.f + b * fw_output_step * max_time;
+ for (int s = 0; s < max_time; s++) {
+ const float* input_ptr_batch =
+ input->data.f + b * input_size * max_time + s * input_size;
+ const float* aux_input_ptr_batch =
+ (aux_input != nullptr)
+ ? aux_input->data.f + b * input_size * max_time + s * input_size
+ : nullptr;
+ float* output_ptr_batch = fw_output_offset + s * fw_output_step;
+
+ kernel_utils::RnnBatchStep(
+ input_ptr_batch, fw_input_weights_ptr, fw_input_weights_scale,
+ aux_input_ptr_batch, aux_fw_input_weights_ptr,
+ aux_fw_input_weights_scale, fw_recurrent_weights_ptr,
+ fw_recurrent_weights_scale, fw_bias_ptr, input_size, aux_input_size,
+ fw_num_units, /*batch_size=*/1, params->activation,
+ quantized_input_ptr, aux_quantized_input_ptr,
+ fw_quantized_hidden_state_ptr, scaling_factors_ptr,
+ fw_hidden_state_ptr_batch, output_ptr_batch);
+ }
+ // Backward cell.
+ float* bw_hidden_state_ptr_batch =
+ bw_hidden_state->data.f + b * bw_num_units;
+ float* bw_output_offset =
+ params->merge_outputs
+ ? fw_output->data.f + b * bw_output_step * max_time
+ : bw_output->data.f + b * bw_output_step * max_time;
+ for (int s = max_time - 1; s >= 0; s--) {
+ const float* input_ptr_batch =
+ input->data.f + b * input_size * max_time + s * input_size;
+ const float* aux_input_ptr_batch =
+ (aux_input != nullptr)
+ ? aux_input->data.f + b * input_size * max_time + s * input_size
+ : nullptr;
+ float* output_ptr_batch = bw_output_offset + s * bw_output_step;
+
+ kernel_utils::RnnBatchStep(
+ input_ptr_batch, bw_input_weights_ptr, bw_input_weights_scale,
+ aux_input_ptr_batch, aux_bw_input_weights_ptr,
+ aux_bw_input_weights_scale, bw_recurrent_weights_ptr,
+ bw_recurrent_weights_scale, bw_bias_ptr, input_size, aux_input_size,
+ bw_num_units, /*batch_size=*/1, params->activation,
+ quantized_input_ptr, aux_quantized_input_ptr,
+ bw_quantized_hidden_state_ptr, scaling_factors_ptr,
+ bw_hidden_state_ptr_batch, output_ptr_batch);
+ }
}
}
return kTfLiteOk;
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc
index f555c472f5..6c179ca05d 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc
@@ -654,7 +654,8 @@ const std::initializer_list<float> recurrent_weights = {
class BidirectionalRNNOpModel : public SingleOpModel {
public:
BidirectionalRNNOpModel(int batches, int sequence_len, int fw_units,
- int bw_units, int input_size, bool merge_outputs)
+ int bw_units, int input_size, bool time_major,
+ bool merge_outputs)
: batches_(batches),
sequence_len_(sequence_len),
fw_units_(fw_units),
@@ -679,25 +680,29 @@ class BidirectionalRNNOpModel : public SingleOpModel {
bw_output_ = AddOutput(TensorType_FLOAT32);
}
- SetBuiltinOp(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN,
- BuiltinOptions_BidirectionalSequenceRNNOptions,
- CreateBidirectionalSequenceRNNOptions(
- builder_, /*time_major=*/false,
- ActivationFunctionType_RELU, merge_outputs)
- .Union());
+ SetBuiltinOp(
+ BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN,
+ BuiltinOptions_BidirectionalSequenceRNNOptions,
+ CreateBidirectionalSequenceRNNOptions(
+ builder_, time_major, ActivationFunctionType_RELU, merge_outputs)
+ .Union());
+ const auto input_shape =
+ (time_major) ? std::vector<int>({sequence_len_, batches_, input_size_})
+ : std::vector<int>({batches_, sequence_len_, input_size_});
+
BuildInterpreter({
- {batches_, sequence_len_, input_size_}, // input
- {fw_units_, input_size_}, // fw_weights
- {fw_units_, fw_units_}, // fw_recurrent_weights
- {fw_units_}, // fw_bias
- {batches_, fw_units_}, // fw_hidden_state
- {bw_units_, input_size_}, // bw_weights
- {bw_units_, bw_units_}, // bw_recurrent_weights
- {bw_units_}, // bw_bias
- {batches_, bw_units_}, // bw_hidden_state
- {batches_, sequence_len_, 0}, // aux_input
- {fw_units_, 0}, // aux_fw_weights
- {bw_units_, 0}, // aux_bw_weights
+ input_shape, // input
+ {fw_units_, input_size_}, // fw_weights
+ {fw_units_, fw_units_}, // fw_recurrent_weights
+ {fw_units_}, // fw_bias
+ {batches_, fw_units_}, // fw_hidden_state
+ {bw_units_, input_size_}, // bw_weights
+ {bw_units_, bw_units_}, // bw_recurrent_weights
+ {bw_units_}, // bw_bias
+ {batches_, bw_units_}, // bw_hidden_state
+ {batches_, sequence_len_, 0}, // aux_input
+ {fw_units_, 0}, // aux_fw_weights
+ {bw_units_, 0}, // aux_bw_weights
});
}
@@ -770,7 +775,8 @@ class BidirectionalRNNOpModel : public SingleOpModel {
TEST(BidirectionalRNNOpTest, BlackBoxTest) {
BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
/*fw_units=*/16, /*bw_units=*/16,
- /*input_size=*/8, /*merge_outputs=*/false);
+ /*input_size=*/8, /*time_major=*/false,
+ /*merge_outputs=*/false);
rnn.SetFwWeights(weights);
rnn.SetBwWeights(weights);
rnn.SetFwBias(biases);
@@ -803,11 +809,49 @@ TEST(BidirectionalRNNOpTest, BlackBoxTest) {
EXPECT_THAT(rnn.GetBwOutput(), ElementsAreArray(ArrayFloatNear(bw_expected)));
}
-// Same as the previous test, yet with merged outputs.
+// Same as BlackBox test, but input is reshuffled to time_major format.
+TEST(BidirectionalRNNOpTest, BlackBoxTestTimeMajor) {
+ BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
+ /*fw_units=*/16, /*bw_units=*/16,
+ /*input_size=*/8, /*time_major=*/true,
+ /*merge_outputs=*/false);
+ rnn.SetFwWeights(weights);
+ rnn.SetBwWeights(weights);
+ rnn.SetFwBias(biases);
+ rnn.SetBwBias(biases);
+ rnn.SetFwRecurrentWeights(recurrent_weights);
+ rnn.SetBwRecurrentWeights(recurrent_weights);
+
+ // const int input_sequence_size = rnn.input_size() * rnn.sequence_len();
+ // Insert the inputs in time_major format. The batch_major format is:
+ // [b0t0, b0t1, ..., b0t15, b1t0, b1t1, ..., b1t15]. This is reshuffled as:
+ // [b0t0, b1t0, b0t1, b1t1, ..., b0t15, b1t15].
+ for (int i = 0; i < rnn.sequence_len(); i++) {
+ float* batch_start = rnn_input + i * rnn.input_size();
+ float* batch_end = batch_start + rnn.input_size();
+ // The two batches are identical.
+ rnn.SetInput(2 * i * rnn.input_size(), batch_start, batch_end);
+ rnn.SetInput((2 * i + 1) * rnn.input_size(), batch_start, batch_end);
+ }
+
+ rnn.Invoke();
+
+ std::vector<float> fw_expected;
+ for (int i = 0; i < rnn.sequence_len(); i++) {
+ float* golden_fw_start = rnn_golden_fw_output + i * rnn.num_fw_units();
+ float* golden_fw_end = golden_fw_start + rnn.num_fw_units();
+ fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
+ fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
+ }
+ EXPECT_THAT(rnn.GetFwOutput(), ElementsAreArray(ArrayFloatNear(fw_expected)));
+}
+
+// Same as BlackBox test, yet with merged outputs.
TEST(BidirectionalRNNOpTest, BlackBoxTestMergeOutputs) {
BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
/*fw_units=*/16, /*bw_units=*/16,
- /*input_size=*/8, /*merge_outputs=*/true);
+ /*input_size=*/8, /*time_major=*/false,
+ /*merge_outputs=*/true);
rnn.SetFwWeights(weights);
rnn.SetBwWeights(weights);
rnn.SetFwBias(biases);
@@ -845,7 +889,8 @@ TEST(BidirectionalRNNOpTest, BlackBoxTestMergeOutputs) {
TEST(BidirectionalRNNOpTest, BlackBoxTestReverseInputs) {
BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
/*fw_units=*/16, /*bw_units=*/16,
- /*input_size=*/8, /*merge_outputs=*/false);
+ /*input_size=*/8, /*time_major=*/false,
+ /*merge_outputs=*/false);
rnn.SetFwWeights(weights);
rnn.SetBwWeights(weights);
rnn.SetFwBias(biases);
@@ -891,7 +936,8 @@ TEST(BidirectionalRNNOpTest, BlackBoxTestReverseInputs) {
TEST(BidirectionalRNNOpTest, EndToEndTest) {
BidirectionalRNNOpModel rnn(/*batches=*/1, /*sequence_len=*/4,
/*fw_units=*/16, /*bw_units=*/16,
- /*input_size=*/8, /*merge_outputs=*/false);
+ /*input_size=*/8, /*time_major=*/false,
+ /*merge_outputs=*/false);
const int output_size = 4;
float dnn_weights[] = {
-0.5782342, -0.052212059, 0.73036242, -0.81216097, -0.80088139,