diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-05-10 12:14:52 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-05-10 12:17:27 -0700 |
commit | bd95d55a2886677ba194351197d93c8b1408cc85 (patch) | |
tree | dab3692368df669482f035bbd97726a1980bca37 /tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc | |
parent | 3ffa132c03ff02decc86a31d8bf888e9381278a7 (diff) |
Implementation of the unidirectional_sequence_rnn TFLite Op using the symmetric quantization.
PiperOrigin-RevId: 196152754
Diffstat (limited to 'tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc | 243 |
1 files changed, 141 insertions, 102 deletions
diff --git a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc index 7e32969763..0adab837b0 100644 --- a/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc +++ b/tensorflow/contrib/lite/kernels/unidirectional_sequence_rnn_test.cc @@ -122,17 +122,66 @@ static float rnn_golden_output[] = { 0, 2.02616, 0, 0.728256, 0.84183, 0.0907453, 0.628881, 3.58099, 1.49974, 0}; +static std::initializer_list<float> rnn_weights = { + 0.461459, 0.153381, 0.529743, -0.00371218, 0.676267, -0.211346, + 0.317493, 0.969689, -0.343251, 0.186423, 0.398151, 0.152399, + 0.448504, 0.317662, 0.523556, -0.323514, 0.480877, 0.333113, + -0.757714, -0.674487, -0.643585, 0.217766, -0.0251462, 0.79512, + -0.595574, -0.422444, 0.371572, -0.452178, -0.556069, -0.482188, + -0.685456, -0.727851, 0.841829, 0.551535, -0.232336, 0.729158, + -0.00294906, -0.69754, 0.766073, -0.178424, 0.369513, -0.423241, + 0.548547, -0.0152023, -0.757482, -0.85491, 0.251331, -0.989183, + 0.306261, -0.340716, 0.886103, -0.0726757, -0.723523, -0.784303, + 0.0354295, 0.566564, -0.485469, -0.620498, 0.832546, 0.697884, + -0.279115, 0.294415, -0.584313, 0.548772, 0.0648819, 0.968726, + 0.723834, -0.0080452, -0.350386, -0.272803, 0.115121, -0.412644, + -0.824713, -0.992843, -0.592904, -0.417893, 0.863791, -0.423461, + -0.147601, -0.770664, -0.479006, 0.654782, 0.587314, -0.639158, + 0.816969, -0.337228, 0.659878, 0.73107, 0.754768, -0.337042, + 0.0960841, 0.368357, 0.244191, -0.817703, -0.211223, 0.442012, + 0.37225, -0.623598, -0.405423, 0.455101, 0.673656, -0.145345, + -0.511346, -0.901675, -0.81252, -0.127006, 0.809865, -0.721884, + 0.636255, 0.868989, -0.347973, -0.10179, -0.777449, 0.917274, + 0.819286, 0.206218, -0.00785118, 0.167141, 0.45872, 0.972934, + -0.276798, 0.837861, 0.747958, -0.0151566, -0.330057, -0.469077, + 0.277308, 0.415818}; + +static std::initializer_list<float> rnn_recurrent_weights = { + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0.1}; + +static std::initializer_list<float> rnn_bias = { + 0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068, -0.23566568, + -0.389184, 0.47481549, -0.4791103, 0.29931796, 0.10463274, 0.83918178, + 0.37197268, 0.61957061, 0.3956964, -0.37609905}; + class UnidirectionalRNNOpModel : public SingleOpModel { public: - UnidirectionalRNNOpModel(int batches, int sequence_len, int units, int size, - bool time_major) + UnidirectionalRNNOpModel( + int batches, int sequence_len, int units, int size, bool time_major, + const TensorType& weights = TensorType_FLOAT32, + const TensorType& recurrent_weights = TensorType_FLOAT32) : batches_(batches), sequence_len_(sequence_len), units_(units), input_size_(size) { input_ = AddInput(TensorType_FLOAT32); - weights_ = AddInput(TensorType_FLOAT32); - recurrent_weights_ = AddInput(TensorType_FLOAT32); + weights_ = AddInput(weights); + recurrent_weights_ = AddInput(recurrent_weights); bias_ = AddInput(TensorType_FLOAT32); hidden_state_ = AddOutput(TensorType_FLOAT32); output_ = AddOutput(TensorType_FLOAT32); @@ -187,7 +236,7 @@ class UnidirectionalRNNOpModel : public SingleOpModel { int num_batches() { return batches_; } int sequence_len() { return sequence_len_; } - private: + protected: int input_; int weights_; int recurrent_weights_; @@ -201,58 +250,31 @@ class UnidirectionalRNNOpModel : public SingleOpModel { int input_size_; }; -// TODO(mirkov): add another test which directly compares to TF once TOCO -// supports the conversion from dynamic_rnn with BasicRNNCell. -TEST(FullyConnectedOpTest, BlackBoxTest) { +// The hybrid model has quantized weights and recurrent_weights. +class HybridUnidirectionalRNNOpModel : public UnidirectionalRNNOpModel { + public: + HybridUnidirectionalRNNOpModel(int batches, int sequence_len, int units, + int size, bool time_major) + : UnidirectionalRNNOpModel(batches, sequence_len, units, size, time_major, + TensorType_UINT8, TensorType_UINT8) {} + + void SetWeights(std::initializer_list<float> f) { + SymmetricQuantizeAndPopulate(weights_, f); + } + + void SetRecurrentWeights(std::initializer_list<float> f) { + SymmetricQuantizeAndPopulate(recurrent_weights_, f); + } +}; + +TEST(UnidirectionalRNNOpTest, BlackBoxTest) { UnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, /*units=*/16, /*size=*/8, /*time_major=*/false); - rnn.SetWeights( - {0.461459, 0.153381, 0.529743, -0.00371218, 0.676267, -0.211346, - 0.317493, 0.969689, -0.343251, 0.186423, 0.398151, 0.152399, - 0.448504, 0.317662, 0.523556, -0.323514, 0.480877, 0.333113, - -0.757714, -0.674487, -0.643585, 0.217766, -0.0251462, 0.79512, - -0.595574, -0.422444, 0.371572, -0.452178, -0.556069, -0.482188, - -0.685456, -0.727851, 0.841829, 0.551535, -0.232336, 0.729158, - -0.00294906, -0.69754, 0.766073, -0.178424, 0.369513, -0.423241, - 0.548547, -0.0152023, -0.757482, -0.85491, 0.251331, -0.989183, - 0.306261, -0.340716, 0.886103, -0.0726757, -0.723523, -0.784303, - 0.0354295, 0.566564, -0.485469, -0.620498, 0.832546, 0.697884, - -0.279115, 0.294415, -0.584313, 0.548772, 0.0648819, 0.968726, - 0.723834, -0.0080452, -0.350386, -0.272803, 0.115121, -0.412644, - -0.824713, -0.992843, -0.592904, -0.417893, 0.863791, -0.423461, - -0.147601, -0.770664, -0.479006, 0.654782, 0.587314, -0.639158, - 0.816969, -0.337228, 0.659878, 0.73107, 0.754768, -0.337042, - 0.0960841, 0.368357, 0.244191, -0.817703, -0.211223, 0.442012, - 0.37225, -0.623598, -0.405423, 0.455101, 0.673656, -0.145345, - -0.511346, -0.901675, -0.81252, -0.127006, 0.809865, -0.721884, - 0.636255, 0.868989, -0.347973, -0.10179, -0.777449, 0.917274, - 0.819286, 0.206218, -0.00785118, 0.167141, 0.45872, 0.972934, - -0.276798, 0.837861, 0.747958, -0.0151566, -0.330057, -0.469077, - 0.277308, 0.415818}); - - rnn.SetBias({0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068, - -0.23566568, -0.389184, 0.47481549, -0.4791103, 0.29931796, - 0.10463274, 0.83918178, 0.37197268, 0.61957061, 0.3956964, - -0.37609905}); - - rnn.SetRecurrentWeights({0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1}); - + rnn.SetWeights(rnn_weights); + rnn.SetBias(rnn_bias); + rnn.SetRecurrentWeights(rnn_recurrent_weights); rnn.ResetHiddenState(); + const int input_sequence_size = rnn.input_size() * rnn.sequence_len(); float* batch_start = rnn_input; float* batch_end = batch_start + input_sequence_size; @@ -270,56 +292,42 @@ TEST(FullyConnectedOpTest, BlackBoxTest) { EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); } -TEST(FullyConnectedOpTest, TimeMajorBlackBoxTest) { - UnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, - /*units=*/16, /*size=*/8, /*time_major=*/true); - rnn.SetWeights( - {0.461459, 0.153381, 0.529743, -0.00371218, 0.676267, -0.211346, - 0.317493, 0.969689, -0.343251, 0.186423, 0.398151, 0.152399, - 0.448504, 0.317662, 0.523556, -0.323514, 0.480877, 0.333113, - -0.757714, -0.674487, -0.643585, 0.217766, -0.0251462, 0.79512, - -0.595574, -0.422444, 0.371572, -0.452178, -0.556069, -0.482188, - -0.685456, -0.727851, 0.841829, 0.551535, -0.232336, 0.729158, - -0.00294906, -0.69754, 0.766073, -0.178424, 0.369513, -0.423241, - 0.548547, -0.0152023, -0.757482, -0.85491, 0.251331, -0.989183, - 0.306261, -0.340716, 0.886103, -0.0726757, -0.723523, -0.784303, - 0.0354295, 0.566564, -0.485469, -0.620498, 0.832546, 0.697884, - -0.279115, 0.294415, -0.584313, 0.548772, 0.0648819, 0.968726, - 0.723834, -0.0080452, -0.350386, -0.272803, 0.115121, -0.412644, - -0.824713, -0.992843, -0.592904, -0.417893, 0.863791, -0.423461, - -0.147601, -0.770664, -0.479006, 0.654782, 0.587314, -0.639158, - 0.816969, -0.337228, 0.659878, 0.73107, 0.754768, -0.337042, - 0.0960841, 0.368357, 0.244191, -0.817703, -0.211223, 0.442012, - 0.37225, -0.623598, -0.405423, 0.455101, 0.673656, -0.145345, - -0.511346, -0.901675, -0.81252, -0.127006, 0.809865, -0.721884, - 0.636255, 0.868989, -0.347973, -0.10179, -0.777449, 0.917274, - 0.819286, 0.206218, -0.00785118, 0.167141, 0.45872, 0.972934, - -0.276798, 0.837861, 0.747958, -0.0151566, -0.330057, -0.469077, - 0.277308, 0.415818}); - - rnn.SetBias({0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068, - -0.23566568, -0.389184, 0.47481549, -0.4791103, 0.29931796, - 0.10463274, 0.83918178, 0.37197268, 0.61957061, 0.3956964, - -0.37609905}); - - rnn.SetRecurrentWeights({0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, - 0.1}); +TEST(HybridUnidirectionalRNNOpModelOpTest, BlackBoxTest) { + HybridUnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, + /*units=*/16, /*size=*/8, + /*time_major=*/false); + rnn.SetWeights(rnn_weights); + rnn.SetBias(rnn_bias); + rnn.SetRecurrentWeights(rnn_recurrent_weights); + rnn.ResetHiddenState(); + + const int input_sequence_size = rnn.input_size() * rnn.sequence_len(); + float* batch_start = rnn_input; + float* batch_end = batch_start + input_sequence_size; + rnn.SetInput(0, batch_start, batch_end); + rnn.SetInput(input_sequence_size, batch_start, batch_end); + + rnn.Invoke(); + + float* golden_start = rnn_golden_output; + float* golden_end = golden_start + rnn.num_units() * rnn.sequence_len(); + std::vector<float> expected; + expected.insert(expected.end(), golden_start, golden_end); + expected.insert(expected.end(), golden_start, golden_end); + + EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear( + expected, /*max_abs_error=*/0.013))); +} +TEST(UnidirectionalRNNOpTest, TimeMajorBlackBoxTest) { + UnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, + /*units=*/16, /*size=*/8, + /*time_major=*/true); + rnn.SetWeights(rnn_weights); + rnn.SetBias(rnn_bias); + rnn.SetRecurrentWeights(rnn_recurrent_weights); rnn.ResetHiddenState(); + for (int i = 0; i < rnn.sequence_len(); i++) { float* batch_start = rnn_input + i * rnn.input_size(); float* batch_end = batch_start + rnn.input_size(); @@ -341,6 +349,37 @@ TEST(FullyConnectedOpTest, TimeMajorBlackBoxTest) { EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(expected))); } +TEST(HybridUnidirectionalRNNOpModelOpTest, TimeMajorBlackBoxTest) { + HybridUnidirectionalRNNOpModel rnn(/*batches=*/2, /*sequence_len=*/16, + /*units=*/16, /*size=*/8, + /*time_major=*/true); + rnn.SetWeights(rnn_weights); + rnn.SetBias(rnn_bias); + rnn.SetRecurrentWeights(rnn_recurrent_weights); + rnn.ResetHiddenState(); + + for (int i = 0; i < rnn.sequence_len(); i++) { + float* batch_start = rnn_input + i * rnn.input_size(); + float* batch_end = batch_start + rnn.input_size(); + // The two batches are identical. + rnn.SetInput(2 * i * rnn.input_size(), batch_start, batch_end); + rnn.SetInput((2 * i + 1) * rnn.input_size(), batch_start, batch_end); + } + + rnn.Invoke(); + + std::vector<float> expected; + for (int i = 0; i < rnn.sequence_len(); i++) { + float* golden_batch_start = rnn_golden_output + i * rnn.num_units(); + float* golden_batch_end = golden_batch_start + rnn.num_units(); + expected.insert(expected.end(), golden_batch_start, golden_batch_end); + expected.insert(expected.end(), golden_batch_start, golden_batch_end); + } + + EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear( + expected, /*max_abs_error=*/0.013))); +} + } // namespace } // namespace tflite |