diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-01-19 12:08:47 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-01-19 12:13:53 -0800 |
commit | 83b751621439cc2b8a85450972414cf2f92a58cf (patch) | |
tree | 860f7f40a104388113121c79f7c6cb51cf7b1198 /tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc | |
parent | 39ae44e9822ed76639bae3ccf800b36039d1da55 (diff) |
Add support for time_major shape format to the sequential RNN Op in TF Lite.
This option, if set, changes the shape format of the inputs and outputs to
[max_time, batch_size, depth]. If false, it uses [batch_size, max_time, depth].
By default, it is set to false.
PiperOrigin-RevId: 182569507
Diffstat (limited to 'tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc | 102 |
1 files changed, 92 insertions, 10 deletions
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc index a1c1eda160..82c680ec3d 100644 --- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc +++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -// Unit test for TFLite RNN op. +// Unit test for TFLite Sequential RNN op. #include <vector> #include <iomanip> @@ -125,7 +125,8 @@ static float rnn_golden_output[] = { class UnidirectionalRNNOpModel : public SingleOpModel { public: - UnidirectionalRNNOpModel(int batches, int sequence_len, int units, int size) + UnidirectionalRNNOpModel(int batches, int sequence_len, int units, int size, + bool time_major) : batches_(batches), sequence_len_(sequence_len), units_(units), @@ -136,13 +137,22 @@ class UnidirectionalRNNOpModel : public SingleOpModel { bias_ = AddInput(TensorType_FLOAT32); hidden_state_ = AddOutput(TensorType_FLOAT32); output_ = AddOutput(TensorType_FLOAT32); - SetBuiltinOp( - BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN, BuiltinOptions_RNNOptions, - CreateRNNOptions(builder_, ActivationFunctionType_RELU).Union()); - BuildInterpreter({{batches_, sequence_len_, input_size_}, - {units_, input_size_}, - {units_, units_}, - {units_}}); + SetBuiltinOp(BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN, + BuiltinOptions_SequenceRNNOptions, + CreateSequenceRNNOptions(builder_, time_major, + ActivationFunctionType_RELU) + .Union()); + if (time_major) { + BuildInterpreter({{sequence_len_, batches_, input_size_}, + {units_, input_size_}, + {units_, units_}, + {units_}}); + } else { + BuildInterpreter({{batches_, sequence_len_, input_size_}, + {units_, input_size_}, + {units_, units_}, + {units_}}); + } } void SetBias(std::initializer_list<float> f) { PopulateTensor(bias_, f); } @@ -195,7 +205,8 @@ class UnidirectionalRNNOpModel : public SingleOpModel { // TODO(mirkov): add another test which directly compares to TF once TOCO // supports the conversion from dynamic_rnn with BasicRNNCell. TEST(FullyConnectedOpTest, BlackBoxTest) { - UnidirectionalRNNOpModel rnn(2, 16, 16, 8); + UnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, + /*units=*/16, /*size=*/8, /*time_major=*/false); rnn.SetWeights( {0.461459, 0.153381, 0.529743, -0.00371218, 0.676267, -0.211346, 0.317493, 0.969689, -0.343251, 0.186423, 0.398151, 0.152399, @@ -260,6 +271,77 @@ TEST(FullyConnectedOpTest, BlackBoxTest) { EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); } +TEST(FullyConnectedOpTest, TimeMajorBlackBoxTest) { + UnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, + /*units=*/16, /*size=*/8, /*time_major=*/true); + rnn.SetWeights( + {0.461459, 0.153381, 0.529743, -0.00371218, 0.676267, -0.211346, + 0.317493, 0.969689, -0.343251, 0.186423, 0.398151, 0.152399, + 0.448504, 0.317662, 0.523556, -0.323514, 0.480877, 0.333113, + -0.757714, -0.674487, -0.643585, 0.217766, -0.0251462, 0.79512, + -0.595574, -0.422444, 0.371572, -0.452178, -0.556069, -0.482188, + -0.685456, -0.727851, 0.841829, 0.551535, -0.232336, 0.729158, + -0.00294906, -0.69754, 0.766073, -0.178424, 0.369513, -0.423241, + 0.548547, -0.0152023, -0.757482, -0.85491, 0.251331, -0.989183, + 0.306261, -0.340716, 0.886103, -0.0726757, -0.723523, -0.784303, + 0.0354295, 0.566564, -0.485469, -0.620498, 0.832546, 0.697884, + -0.279115, 0.294415, -0.584313, 0.548772, 0.0648819, 0.968726, + 0.723834, -0.0080452, -0.350386, -0.272803, 0.115121, -0.412644, + -0.824713, -0.992843, -0.592904, -0.417893, 0.863791, -0.423461, + -0.147601, -0.770664, -0.479006, 0.654782, 0.587314, -0.639158, + 0.816969, -0.337228, 0.659878, 0.73107, 0.754768, -0.337042, + 0.0960841, 0.368357, 0.244191, -0.817703, -0.211223, 0.442012, + 0.37225, -0.623598, -0.405423, 0.455101, 0.673656, -0.145345, + -0.511346, -0.901675, -0.81252, -0.127006, 0.809865, -0.721884, + 0.636255, 0.868989, -0.347973, -0.10179, -0.777449, 0.917274, + 0.819286, 0.206218, -0.00785118, 0.167141, 0.45872, 0.972934, + -0.276798, 0.837861, 0.747958, -0.0151566, -0.330057, -0.469077, + 0.277308, 0.415818}); + + rnn.SetBias({0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068, + -0.23566568, -0.389184, 0.47481549, -0.4791103, 0.29931796, + 0.10463274, 0.83918178, 0.37197268, 0.61957061, 0.3956964, + -0.37609905}); + + rnn.SetRecurrentWeights({0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1}); + + rnn.ResetHiddenState(); + for (int i = 0; i < rnn.sequence_len(); i++) { + float* batch_start = rnn_input + i * rnn.input_size(); + float* batch_end = batch_start + rnn.input_size(); + // The two batches are identical. + rnn.SetInput(2 * i * rnn.input_size(), batch_start, batch_end); + rnn.SetInput((2 * i + 1) * rnn.input_size(), batch_start, batch_end); + } + + rnn.Invoke(); + + std::vector<float> expected; + for (int i = 0; i < rnn.sequence_len(); i++) { + float* golden_batch_start = rnn_golden_output + i * rnn.num_units(); + float* golden_batch_end = golden_batch_start + rnn.num_units(); + expected.insert(expected.end(), golden_batch_start, golden_batch_end); + expected.insert(expected.end(), golden_batch_start, golden_batch_end); + } + + EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); +} + } // namespace } // namespace tflite |