aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-10-03 13:25:22 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-03 13:32:42 -0700
commitc2c8cfe22492cf7fab804d32283b623632270035 (patch)
tree6003bf547117f97cd65ed598c4cec39cba7d5510 /tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc
parent7566f3d5ad690c71c36e78611b1ae5913ec3e845 (diff)
Add the option of merging bidirectional RNN and LSTM outputs into a single output tensor.
This is useful if the output of both directions will be passed to the next layer as a single output, as it avoids adding a concatenation op, which can be expensive on mobile devices where memory movement is relatively expensive. PiperOrigin-RevId: 215616140
Diffstat (limited to 'tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc56
1 files changed, 48 insertions, 8 deletions
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc
index 3e34ba6196..f555c472f5 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc
@@ -654,7 +654,7 @@ const std::initializer_list<float> recurrent_weights = {
class BidirectionalRNNOpModel : public SingleOpModel {
public:
BidirectionalRNNOpModel(int batches, int sequence_len, int fw_units,
- int bw_units, int input_size)
+ int bw_units, int input_size, bool merge_outputs)
: batches_(batches),
sequence_len_(sequence_len),
fw_units_(fw_units),
@@ -675,12 +675,15 @@ class BidirectionalRNNOpModel : public SingleOpModel {
aux_bw_weights_ = AddNullInput();
fw_output_ = AddOutput(TensorType_FLOAT32);
- bw_output_ = AddOutput(TensorType_FLOAT32);
+ if (!merge_outputs) {
+ bw_output_ = AddOutput(TensorType_FLOAT32);
+ }
SetBuiltinOp(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN,
- BuiltinOptions_SequenceRNNOptions,
- CreateSequenceRNNOptions(builder_, /*time_major=*/false,
- ActivationFunctionType_RELU)
+ BuiltinOptions_BidirectionalSequenceRNNOptions,
+ CreateBidirectionalSequenceRNNOptions(
+ builder_, /*time_major=*/false,
+ ActivationFunctionType_RELU, merge_outputs)
.Union());
BuildInterpreter({
{batches_, sequence_len_, input_size_}, // input
@@ -767,7 +770,7 @@ class BidirectionalRNNOpModel : public SingleOpModel {
TEST(BidirectionalRNNOpTest, BlackBoxTest) {
BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
/*fw_units=*/16, /*bw_units=*/16,
- /*input_size=*/8);
+ /*input_size=*/8, /*merge_outputs=*/false);
rnn.SetFwWeights(weights);
rnn.SetBwWeights(weights);
rnn.SetFwBias(biases);
@@ -800,12 +803,49 @@ TEST(BidirectionalRNNOpTest, BlackBoxTest) {
EXPECT_THAT(rnn.GetBwOutput(), ElementsAreArray(ArrayFloatNear(bw_expected)));
}
+// Same as the previous test, yet with merged outputs.
+TEST(BidirectionalRNNOpTest, BlackBoxTestMergeOutputs) {
+ BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
+ /*fw_units=*/16, /*bw_units=*/16,
+ /*input_size=*/8, /*merge_outputs=*/true);
+ rnn.SetFwWeights(weights);
+ rnn.SetBwWeights(weights);
+ rnn.SetFwBias(biases);
+ rnn.SetBwBias(biases);
+ rnn.SetFwRecurrentWeights(recurrent_weights);
+ rnn.SetBwRecurrentWeights(recurrent_weights);
+
+ const int input_sequence_size = rnn.input_size() * rnn.sequence_len();
+ float* batch_start = rnn_input;
+ float* batch_end = batch_start + input_sequence_size;
+ rnn.SetInput(0, batch_start, batch_end);
+ rnn.SetInput(input_sequence_size, batch_start, batch_end);
+
+ rnn.Invoke();
+
+ std::vector<float> merged_expected;
+ for (int bid = 0; bid < rnn.num_batches(); bid++) {
+ for (int step = 0; step < rnn.sequence_len(); step++) {
+ merged_expected.insert(
+ merged_expected.end(),
+ rnn_golden_fw_output + rnn.num_fw_units() * step,
+ rnn_golden_fw_output + rnn.num_fw_units() * (step + 1));
+ merged_expected.insert(
+ merged_expected.end(),
+ rnn_golden_bw_output + rnn.num_bw_units() * step,
+ rnn_golden_bw_output + rnn.num_bw_units() * (step + 1));
+ }
+ }
+ EXPECT_THAT(rnn.GetFwOutput(),
+ ElementsAreArray(ArrayFloatNear(merged_expected)));
+}
+
// Check that if the input sequence is reversed the outputs are the same just
// forward and backward are swapped (and reversed).
TEST(BidirectionalRNNOpTest, BlackBoxTestReverseInputs) {
BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
/*fw_units=*/16, /*bw_units=*/16,
- /*input_size=*/8);
+ /*input_size=*/8, /*merge_outputs=*/false);
rnn.SetFwWeights(weights);
rnn.SetBwWeights(weights);
rnn.SetFwBias(biases);
@@ -851,7 +891,7 @@ TEST(BidirectionalRNNOpTest, BlackBoxTestReverseInputs) {
TEST(BidirectionalRNNOpTest, EndToEndTest) {
BidirectionalRNNOpModel rnn(/*batches=*/1, /*sequence_len=*/4,
/*fw_units=*/16, /*bw_units=*/16,
- /*input_size=*/8);
+ /*input_size=*/8, /*merge_outputs=*/false);
const int output_size = 4;
float dnn_weights[] = {
-0.5782342, -0.052212059, 0.73036242, -0.81216097, -0.80088139,