diff options
Diffstat (limited to 'tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc | 94 |
1 files changed, 70 insertions, 24 deletions
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc index f555c472f5..6c179ca05d 100644 --- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc +++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc @@ -654,7 +654,8 @@ const std::initializer_list<float> recurrent_weights = { class BidirectionalRNNOpModel : public SingleOpModel { public: BidirectionalRNNOpModel(int batches, int sequence_len, int fw_units, - int bw_units, int input_size, bool merge_outputs) + int bw_units, int input_size, bool time_major, + bool merge_outputs) : batches_(batches), sequence_len_(sequence_len), fw_units_(fw_units), @@ -679,25 +680,29 @@ class BidirectionalRNNOpModel : public SingleOpModel { bw_output_ = AddOutput(TensorType_FLOAT32); } - SetBuiltinOp(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN, - BuiltinOptions_BidirectionalSequenceRNNOptions, - CreateBidirectionalSequenceRNNOptions( - builder_, /*time_major=*/false, - ActivationFunctionType_RELU, merge_outputs) - .Union()); + SetBuiltinOp( + BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN, + BuiltinOptions_BidirectionalSequenceRNNOptions, + CreateBidirectionalSequenceRNNOptions( + builder_, time_major, ActivationFunctionType_RELU, merge_outputs) + .Union()); + const auto input_shape = + (time_major) ? std::vector<int>({sequence_len_, batches_, input_size_}) + : std::vector<int>({batches_, sequence_len_, input_size_}); + BuildInterpreter({ - {batches_, sequence_len_, input_size_}, // input - {fw_units_, input_size_}, // fw_weights - {fw_units_, fw_units_}, // fw_recurrent_weights - {fw_units_}, // fw_bias - {batches_, fw_units_}, // fw_hidden_state - {bw_units_, input_size_}, // bw_weights - {bw_units_, bw_units_}, // bw_recurrent_weights - {bw_units_}, // bw_bias - {batches_, bw_units_}, // bw_hidden_state - {batches_, sequence_len_, 0}, // aux_input - {fw_units_, 0}, // aux_fw_weights - {bw_units_, 0}, // aux_bw_weights + input_shape, // input + {fw_units_, input_size_}, // fw_weights + {fw_units_, fw_units_}, // fw_recurrent_weights + {fw_units_}, // fw_bias + {batches_, fw_units_}, // fw_hidden_state + {bw_units_, input_size_}, // bw_weights + {bw_units_, bw_units_}, // bw_recurrent_weights + {bw_units_}, // bw_bias + {batches_, bw_units_}, // bw_hidden_state + {batches_, sequence_len_, 0}, // aux_input + {fw_units_, 0}, // aux_fw_weights + {bw_units_, 0}, // aux_bw_weights }); } @@ -770,7 +775,8 @@ class BidirectionalRNNOpModel : public SingleOpModel { TEST(BidirectionalRNNOpTest, BlackBoxTest) { BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, /*fw_units=*/16, /*bw_units=*/16, - /*input_size=*/8, /*merge_outputs=*/false); + /*input_size=*/8, /*time_major=*/false, + /*merge_outputs=*/false); rnn.SetFwWeights(weights); rnn.SetBwWeights(weights); rnn.SetFwBias(biases); @@ -803,11 +809,49 @@ TEST(BidirectionalRNNOpTest, BlackBoxTest) { EXPECT_THAT(rnn.GetBwOutput(), ElementsAreArray(ArrayFloatNear(bw_expected))); } -// Same as the previous test, yet with merged outputs. +// Same as BlackBox test, but input is reshuffled to time_major format. +TEST(BidirectionalRNNOpTest, BlackBoxTestTimeMajor) { + BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, + /*fw_units=*/16, /*bw_units=*/16, + /*input_size=*/8, /*time_major=*/true, + /*merge_outputs=*/false); + rnn.SetFwWeights(weights); + rnn.SetBwWeights(weights); + rnn.SetFwBias(biases); + rnn.SetBwBias(biases); + rnn.SetFwRecurrentWeights(recurrent_weights); + rnn.SetBwRecurrentWeights(recurrent_weights); + + // const int input_sequence_size = rnn.input_size() * rnn.sequence_len(); + // Insert the inputs in time_major format. The batch_major format is: + // [b0t0, b0t1, ..., b0t15, b1t0, b1t1, ..., b1t15]. This is reshuffled as: + // [b0t0, b1t0, b0t1, b1t1, ..., b0t15, b1t15]. + for (int i = 0; i < rnn.sequence_len(); i++) { + float* batch_start = rnn_input + i * rnn.input_size(); + float* batch_end = batch_start + rnn.input_size(); + // The two batches are identical. + rnn.SetInput(2 * i * rnn.input_size(), batch_start, batch_end); + rnn.SetInput((2 * i + 1) * rnn.input_size(), batch_start, batch_end); + } + + rnn.Invoke(); + + std::vector<float> fw_expected; + for (int i = 0; i < rnn.sequence_len(); i++) { + float* golden_fw_start = rnn_golden_fw_output + i * rnn.num_fw_units(); + float* golden_fw_end = golden_fw_start + rnn.num_fw_units(); + fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end); + fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end); + } + EXPECT_THAT(rnn.GetFwOutput(), ElementsAreArray(ArrayFloatNear(fw_expected))); +} + +// Same as BlackBox test, yet with merged outputs. TEST(BidirectionalRNNOpTest, BlackBoxTestMergeOutputs) { BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, /*fw_units=*/16, /*bw_units=*/16, - /*input_size=*/8, /*merge_outputs=*/true); + /*input_size=*/8, /*time_major=*/false, + /*merge_outputs=*/true); rnn.SetFwWeights(weights); rnn.SetBwWeights(weights); rnn.SetFwBias(biases); @@ -845,7 +889,8 @@ TEST(BidirectionalRNNOpTest, BlackBoxTestMergeOutputs) { TEST(BidirectionalRNNOpTest, BlackBoxTestReverseInputs) { BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, /*fw_units=*/16, /*bw_units=*/16, - /*input_size=*/8, /*merge_outputs=*/false); + /*input_size=*/8, /*time_major=*/false, + /*merge_outputs=*/false); rnn.SetFwWeights(weights); rnn.SetBwWeights(weights); rnn.SetFwBias(biases); @@ -891,7 +936,8 @@ TEST(BidirectionalRNNOpTest, BlackBoxTestReverseInputs) { TEST(BidirectionalRNNOpTest, EndToEndTest) { BidirectionalRNNOpModel rnn(/*batches=*/1, /*sequence_len=*/4, /*fw_units=*/16, /*bw_units=*/16, - /*input_size=*/8, /*merge_outputs=*/false); + /*input_size=*/8, /*time_major=*/false, + /*merge_outputs=*/false); const int output_size = 4; float dnn_weights[] = { -0.5782342, -0.052212059, 0.73036242, -0.81216097, -0.80088139, |