Diffstat (limited to 'tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc')
-rw-r--r--  tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc | 94
1 file changed, 70 insertions(+), 24 deletions(-)
diff --git a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc
index f555c472f5..6c179ca05d 100644
--- a/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc
+++ b/tensorflow/contrib/lite/kernels/bidirectional_sequence_rnn_test.cc
@@ -654,7 +654,8 @@ const std::initializer_list<float> recurrent_weights = {
class BidirectionalRNNOpModel : public SingleOpModel {
public:
BidirectionalRNNOpModel(int batches, int sequence_len, int fw_units,
- int bw_units, int input_size, bool merge_outputs)
+ int bw_units, int input_size, bool time_major,
+ bool merge_outputs)
: batches_(batches),
sequence_len_(sequence_len),
fw_units_(fw_units),
@@ -679,25 +680,29 @@ class BidirectionalRNNOpModel : public SingleOpModel {
bw_output_ = AddOutput(TensorType_FLOAT32);
}
- SetBuiltinOp(BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN,
- BuiltinOptions_BidirectionalSequenceRNNOptions,
- CreateBidirectionalSequenceRNNOptions(
- builder_, /*time_major=*/false,
- ActivationFunctionType_RELU, merge_outputs)
- .Union());
+ SetBuiltinOp(
+ BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN,
+ BuiltinOptions_BidirectionalSequenceRNNOptions,
+ CreateBidirectionalSequenceRNNOptions(
+ builder_, time_major, ActivationFunctionType_RELU, merge_outputs)
+ .Union());
+ const auto input_shape =
+ (time_major) ? std::vector<int>({sequence_len_, batches_, input_size_})
+ : std::vector<int>({batches_, sequence_len_, input_size_});
+
BuildInterpreter({
- {batches_, sequence_len_, input_size_}, // input
- {fw_units_, input_size_}, // fw_weights
- {fw_units_, fw_units_}, // fw_recurrent_weights
- {fw_units_}, // fw_bias
- {batches_, fw_units_}, // fw_hidden_state
- {bw_units_, input_size_}, // bw_weights
- {bw_units_, bw_units_}, // bw_recurrent_weights
- {bw_units_}, // bw_bias
- {batches_, bw_units_}, // bw_hidden_state
- {batches_, sequence_len_, 0}, // aux_input
- {fw_units_, 0}, // aux_fw_weights
- {bw_units_, 0}, // aux_bw_weights
+ input_shape, // input
+ {fw_units_, input_size_}, // fw_weights
+ {fw_units_, fw_units_}, // fw_recurrent_weights
+ {fw_units_}, // fw_bias
+ {batches_, fw_units_}, // fw_hidden_state
+ {bw_units_, input_size_}, // bw_weights
+ {bw_units_, bw_units_}, // bw_recurrent_weights
+ {bw_units_}, // bw_bias
+ {batches_, bw_units_}, // bw_hidden_state
+ {batches_, sequence_len_, 0}, // aux_input
+ {fw_units_, 0}, // aux_fw_weights
+ {bw_units_, 0}, // aux_bw_weights
});
}
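For reference, a minimal sketch (not part of this patch; the function names are hypothetical) of how the same element (batch b, time t, feature f) is addressed under the two input shapes the constructor above can now build:

#include <cstddef>

// Shape {batches, sequence_len, input_size}: batch-major, time_major == false.
size_t BatchMajorIndex(size_t b, size_t t, size_t f, size_t sequence_len,
                       size_t input_size) {
  return (b * sequence_len + t) * input_size + f;
}

// Shape {sequence_len, batches, input_size}: time-major, time_major == true.
size_t TimeMajorIndex(size_t b, size_t t, size_t f, size_t batches,
                      size_t input_size) {
  return (t * batches + b) * input_size + f;
}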
@@ -770,7 +775,8 @@ class BidirectionalRNNOpModel : public SingleOpModel {
TEST(BidirectionalRNNOpTest, BlackBoxTest) {
BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
/*fw_units=*/16, /*bw_units=*/16,
- /*input_size=*/8, /*merge_outputs=*/false);
+ /*input_size=*/8, /*time_major=*/false,
+ /*merge_outputs=*/false);
rnn.SetFwWeights(weights);
rnn.SetBwWeights(weights);
rnn.SetFwBias(biases);
@@ -803,11 +809,49 @@ TEST(BidirectionalRNNOpTest, BlackBoxTest) {
EXPECT_THAT(rnn.GetBwOutput(), ElementsAreArray(ArrayFloatNear(bw_expected)));
}
-// Same as the previous test, yet with merged outputs.
+// Same as BlackBox test, but input is reshuffled to time_major format.
+TEST(BidirectionalRNNOpTest, BlackBoxTestTimeMajor) {
+ BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
+ /*fw_units=*/16, /*bw_units=*/16,
+ /*input_size=*/8, /*time_major=*/true,
+ /*merge_outputs=*/false);
+ rnn.SetFwWeights(weights);
+ rnn.SetBwWeights(weights);
+ rnn.SetFwBias(biases);
+ rnn.SetBwBias(biases);
+ rnn.SetFwRecurrentWeights(recurrent_weights);
+ rnn.SetBwRecurrentWeights(recurrent_weights);
+
+ // const int input_sequence_size = rnn.input_size() * rnn.sequence_len();
+ // Insert the inputs in time_major format. The batch_major format is:
+ // [b0t0, b0t1, ..., b0t15, b1t0, b1t1, ..., b1t15]. This is reshuffled as:
+ // [b0t0, b1t0, b0t1, b1t1, ..., b0t15, b1t15].
+ for (int i = 0; i < rnn.sequence_len(); i++) {
+ float* batch_start = rnn_input + i * rnn.input_size();
+ float* batch_end = batch_start + rnn.input_size();
+ // The two batches are identical.
+ rnn.SetInput(2 * i * rnn.input_size(), batch_start, batch_end);
+ rnn.SetInput((2 * i + 1) * rnn.input_size(), batch_start, batch_end);
+ }
+
+ rnn.Invoke();
+
+ std::vector<float> fw_expected;
+ for (int i = 0; i < rnn.sequence_len(); i++) {
+ float* golden_fw_start = rnn_golden_fw_output + i * rnn.num_fw_units();
+ float* golden_fw_end = golden_fw_start + rnn.num_fw_units();
+ fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
+ fw_expected.insert(fw_expected.end(), golden_fw_start, golden_fw_end);
+ }
+ EXPECT_THAT(rnn.GetFwOutput(), ElementsAreArray(ArrayFloatNear(fw_expected)));
+}
+
+// Same as BlackBox test, yet with merged outputs.
TEST(BidirectionalRNNOpTest, BlackBoxTestMergeOutputs) {
BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
/*fw_units=*/16, /*bw_units=*/16,
- /*input_size=*/8, /*merge_outputs=*/true);
+ /*input_size=*/8, /*time_major=*/false,
+ /*merge_outputs=*/true);
rnn.SetFwWeights(weights);
rnn.SetBwWeights(weights);
rnn.SetFwBias(biases);
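The new BlackBoxTestTimeMajor above feeds the batch-major golden input into the time-major model by computing SetInput() offsets by hand. A sketch of the same reshuffle as a standalone helper (hypothetical, not part of the patch) looks like this:

#include <vector>

// Transpose a batch-major buffer of shape [batches, sequence_len, input_size]
// into time-major layout [sequence_len, batches, input_size].
std::vector<float> ToTimeMajor(const std::vector<float>& batch_major,
                               int batches, int sequence_len, int input_size) {
  std::vector<float> time_major(batch_major.size());
  for (int b = 0; b < batches; ++b) {
    for (int t = 0; t < sequence_len; ++t) {
      for (int f = 0; f < input_size; ++f) {
        time_major[(t * batches + b) * input_size + f] =
            batch_major[(b * sequence_len + t) * input_size + f];
      }
    }
  }
  return time_major;
}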
@@ -845,7 +889,8 @@ TEST(BidirectionalRNNOpTest, BlackBoxTestMergeOutputs) {
TEST(BidirectionalRNNOpTest, BlackBoxTestReverseInputs) {
BidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16,
/*fw_units=*/16, /*bw_units=*/16,
- /*input_size=*/8, /*merge_outputs=*/false);
+ /*input_size=*/8, /*time_major=*/false,
+ /*merge_outputs=*/false);
rnn.SetFwWeights(weights);
rnn.SetBwWeights(weights);
rnn.SetFwBias(biases);
@@ -891,7 +936,8 @@ TEST(BidirectionalRNNOpTest, BlackBoxTestReverseInputs) {
TEST(BidirectionalRNNOpTest, EndToEndTest) {
BidirectionalRNNOpModel rnn(/*batches=*/1, /*sequence_len=*/4,
/*fw_units=*/16, /*bw_units=*/16,
- /*input_size=*/8, /*merge_outputs=*/false);
+ /*input_size=*/8, /*time_major=*/false,
+ /*merge_outputs=*/false);
const int output_size = 4;
float dnn_weights[] = {
-0.5782342, -0.052212059, 0.73036242, -0.81216097, -0.80088139,