diff options
author | 2018-10-03 13:25:22 -0700 | |
---|---|---|
committer | 2018-10-03 13:32:42 -0700 | |
commit | c2c8cfe22492cf7fab804d32283b623632270035 (patch) | |
tree | 6003bf547117f97cd65ed598c4cec39cba7d5510 /tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc | |
parent | 7566f3d5ad690c71c36e78611b1ae5913ec3e845 (diff) |
Add the option of merging bidirectional RNN and LSTM outputs into a single output tensor.
This is useful if the output of both directions will be passed to the next layer as a single output, as it avoids adding a concatenation op, which can be expensive on mobile devices where memory movement is relatively expensive.
PiperOrigin-RevId: 215616140
Diffstat (limited to 'tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc | 56 |
1 files changed, 48 insertions, 8 deletions
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc index 3e34ba6196..f555c472f5 100644 --- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc +++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc @@ -654,7 +654,7 @@ const std::initializer_list<float> recurrent_weights = { class BidirectionalRNNOpModel : public SingleOpModel { public: BidirectionalRNNOpModel(int batches, int sequence_len, int fw_units, - int bw_units, int input_size) + int bw_units, int input_size, bool merge_outputs) : batches_(batches), sequence_len_(sequence_len), fw_units_(fw_units), @@ -675,12 +675,15 @@ class BidirectionalRNNOpModel : public SingleOpModel { aux_bw_weights_ = AddNullInput(); fw_output_ = AddOutput(TensorType_FLOAT32); - bw_output_ = AddOutput(TensorType_FLOAT32); + if (!merge_outputs) { + bw_output_ = AddOutput(TensorType_FLOAT32); + } SetBuiltinOp(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN, - BuiltinOptions_SequenceRNNOptions, - CreateSequenceRNNOptions(builder_, /*time_major=*/false, - ActivationFunctionType_RELU) + BuiltinOptions_BidirectionalSequenceRNNOptions, + CreateBidirectionalSequenceRNNOptions( + builder_, /*time_major=*/false, + ActivationFunctionType_RELU, merge_outputs) .Union()); BuildInterpreter({ {batches_, sequence_len_, input_size_}, // input @@ -767,7 +770,7 @@ class BidirectionalRNNOpModel : public SingleOpModel { TEST(BidirectionalRNNOpTest, BlackBoxTest) { BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, /*fw_units=*/16, /*bw_units=*/16, - /*input_size=*/8); + /*input_size=*/8, /*merge_outputs=*/false); rnn.SetFwWeights(weights); rnn.SetBwWeights(weights); rnn.SetFwBias(biases); @@ -800,12 +803,49 @@ TEST(BidirectionalRNNOpTest, BlackBoxTest) { EXPECT_THAT(rnn.GetBwOutput(), ElementsAreArray(ArrayFloatNear(bw_expected))); } +// Same as the previous test, yet with merged outputs. +TEST(BidirectionalRNNOpTest, BlackBoxTestMergeOutputs) { + BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, + /*fw_units=*/16, /*bw_units=*/16, + /*input_size=*/8, /*merge_outputs=*/true); + rnn.SetFwWeights(weights); + rnn.SetBwWeights(weights); + rnn.SetFwBias(biases); + rnn.SetBwBias(biases); + rnn.SetFwRecurrentWeights(recurrent_weights); + rnn.SetBwRecurrentWeights(recurrent_weights); + + const int input_sequence_size = rnn.input_size() * rnn.sequence_len(); + float* batch_start = rnn_input; + float* batch_end = batch_start + input_sequence_size; + rnn.SetInput(0, batch_start, batch_end); + rnn.SetInput(input_sequence_size, batch_start, batch_end); + + rnn.Invoke(); + + std::vector<float> merged_expected; + for (int bid = 0; bid < rnn.num_batches(); bid++) { + for (int step = 0; step < rnn.sequence_len(); step++) { + merged_expected.insert( + merged_expected.end(), + rnn_golden_fw_output + rnn.num_fw_units() * step, + rnn_golden_fw_output + rnn.num_fw_units() * (step + 1)); + merged_expected.insert( + merged_expected.end(), + rnn_golden_bw_output + rnn.num_bw_units() * step, + rnn_golden_bw_output + rnn.num_bw_units() * (step + 1)); + } + } + EXPECT_THAT(rnn.GetFwOutput(), + ElementsAreArray(ArrayFloatNear(merged_expected))); +} + // Check that if the input sequence is reversed the outputs are the same just // forward and backward are swapped (and reversed). TEST(BidirectionalRNNOpTest, BlackBoxTestReverseInputs) { BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, /*fw_units=*/16, /*bw_units=*/16, - /*input_size=*/8); + /*input_size=*/8, /*merge_outputs=*/false); rnn.SetFwWeights(weights); rnn.SetBwWeights(weights); rnn.SetFwBias(biases); @@ -851,7 +891,7 @@ TEST(BidirectionalRNNOpTest, BlackBoxTestReverseInputs) { TEST(BidirectionalRNNOpTest, EndToEndTest) { BidirectionalRNNOpModel rnn(/*batches=*/1, /*sequence_len=*/4, /*fw_units=*/16, /*bw_units=*/16, - /*input_size=*/8); + /*input_size=*/8, /*merge_outputs=*/false); const int output_size = 4; float dnn_weights[] = { -0.5782342, -0.052212059, 0.73036242, -0.81216097, -0.80088139, |