path: root/tensorflow/contrib/lite/kernels/basic_rnn_test.cc
author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-05-10 11:22:20 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-05-10 11:26:57 -0700
commit    68ee0e153c5318a79dae612647f27a31f6c2f59c (patch)
tree      270011a00c12161ff30a977e1d266677f964385f /tensorflow/contrib/lite/kernels/basic_rnn_test.cc
parent    0013b6953547fe17865c21155bdebe4cfe656e74 (diff)
Implementation of the basic_rnn TFLite Op using symmetric quantization.
PiperOrigin-RevId: 196144379
Diffstat (limited to 'tensorflow/contrib/lite/kernels/basic_rnn_test.cc')
-rw-r--r--  tensorflow/contrib/lite/kernels/basic_rnn_test.cc  | 155
1 file changed, 104 insertions(+), 51 deletions(-)
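Note: the symmetric quantization referred to in the commit message can be illustrated with a minimal per-tensor sketch. This is not the actual TFLite helper behind SymmetricQuantizeAndPopulate, just an assumed-equivalent illustration: float weights are mapped to 8-bit integers with a single scale and an implicit zero point of 0.

    // Minimal sketch of per-tensor symmetric quantization, assuming 8-bit
    // values clamped to [-127, 127]; the real TFLite implementation differs
    // in details (e.g. storage type and helper names).
    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    void SymmetricQuantizeSketch(const std::vector<float>& values,
                                 std::vector<int8_t>* quantized, float* scale) {
      float max_abs = 0.f;
      for (float v : values) max_abs = std::max(max_abs, std::fabs(v));
      // One scale for the whole tensor; symmetric, so the zero point is 0.
      *scale = max_abs / 127.f;
      quantized->resize(values.size());
      for (std::size_t i = 0; i < values.size(); ++i) {
        const float q = (*scale > 0.f) ? std::round(values[i] / *scale) : 0.f;
        (*quantized)[i] =
            static_cast<int8_t>(std::min(127.f, std::max(-127.f, q)));
      }
    }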
diff --git a/tensorflow/contrib/lite/kernels/basic_rnn_test.cc b/tensorflow/contrib/lite/kernels/basic_rnn_test.cc
index fa7ef525db..96465fcaf0 100644
--- a/tensorflow/contrib/lite/kernels/basic_rnn_test.cc
+++ b/tensorflow/contrib/lite/kernels/basic_rnn_test.cc
@@ -14,7 +14,9 @@ limitations under the License.
==============================================================================*/
// Unit test for TFLite RNN op.
-#include <iomanip>
+#include <string.h>
+#include <initializer_list>
+#include <memory>
#include <vector>
#include <gmock/gmock.h>
@@ -122,13 +124,62 @@ static float rnn_golden_output[] = {
0, 2.02616, 0, 0.728256, 0.84183, 0.0907453,
0.628881, 3.58099, 1.49974, 0};
+static std::initializer_list<float> rnn_weights = {
+ 0.461459, 0.153381, 0.529743, -0.00371218, 0.676267, -0.211346,
+ 0.317493, 0.969689, -0.343251, 0.186423, 0.398151, 0.152399,
+ 0.448504, 0.317662, 0.523556, -0.323514, 0.480877, 0.333113,
+ -0.757714, -0.674487, -0.643585, 0.217766, -0.0251462, 0.79512,
+ -0.595574, -0.422444, 0.371572, -0.452178, -0.556069, -0.482188,
+ -0.685456, -0.727851, 0.841829, 0.551535, -0.232336, 0.729158,
+ -0.00294906, -0.69754, 0.766073, -0.178424, 0.369513, -0.423241,
+ 0.548547, -0.0152023, -0.757482, -0.85491, 0.251331, -0.989183,
+ 0.306261, -0.340716, 0.886103, -0.0726757, -0.723523, -0.784303,
+ 0.0354295, 0.566564, -0.485469, -0.620498, 0.832546, 0.697884,
+ -0.279115, 0.294415, -0.584313, 0.548772, 0.0648819, 0.968726,
+ 0.723834, -0.0080452, -0.350386, -0.272803, 0.115121, -0.412644,
+ -0.824713, -0.992843, -0.592904, -0.417893, 0.863791, -0.423461,
+ -0.147601, -0.770664, -0.479006, 0.654782, 0.587314, -0.639158,
+ 0.816969, -0.337228, 0.659878, 0.73107, 0.754768, -0.337042,
+ 0.0960841, 0.368357, 0.244191, -0.817703, -0.211223, 0.442012,
+ 0.37225, -0.623598, -0.405423, 0.455101, 0.673656, -0.145345,
+ -0.511346, -0.901675, -0.81252, -0.127006, 0.809865, -0.721884,
+ 0.636255, 0.868989, -0.347973, -0.10179, -0.777449, 0.917274,
+ 0.819286, 0.206218, -0.00785118, 0.167141, 0.45872, 0.972934,
+ -0.276798, 0.837861, 0.747958, -0.0151566, -0.330057, -0.469077,
+ 0.277308, 0.415818};
+
+static std::initializer_list<float> rnn_recurrent_weights = {
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0.1};
+
+static std::initializer_list<float> rnn_bias = {
+ 0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068, -0.23566568,
+ -0.389184, 0.47481549, -0.4791103, 0.29931796, 0.10463274, 0.83918178,
+ 0.37197268, 0.61957061, 0.3956964, -0.37609905};
+
class RNNOpModel : public SingleOpModel {
public:
- RNNOpModel(int batches, int units, int size)
+ RNNOpModel(int batches, int units, int size,
+ const TensorType& weights = TensorType_FLOAT32,
+ const TensorType& recurrent_weights = TensorType_FLOAT32)
: batches_(batches), units_(units), input_size_(size) {
input_ = AddInput(TensorType_FLOAT32);
- weights_ = AddInput(TensorType_FLOAT32);
- recurrent_weights_ = AddInput(TensorType_FLOAT32);
+ weights_ = AddInput(weights);
+ recurrent_weights_ = AddInput(recurrent_weights);
bias_ = AddInput(TensorType_FLOAT32);
hidden_state_ = AddOutput(TensorType_FLOAT32);
output_ = AddOutput(TensorType_FLOAT32);
@@ -173,7 +224,7 @@ class RNNOpModel : public SingleOpModel {
int num_units() { return units_; }
int num_batches() { return batches_; }
- private:
+ protected:
int input_;
int weights_;
int recurrent_weights_;
@@ -186,53 +237,26 @@ class RNNOpModel : public SingleOpModel {
int input_size_;
};
-TEST(FullyConnectedOpTest, BlackBoxTest) {
+// The hybrid model has quantized weights and recurrent_weights.
+class HybridRNNOpModel : public RNNOpModel {
+ public:
+ HybridRNNOpModel(int batches, int units, int size)
+ : RNNOpModel(batches, units, size, TensorType_UINT8, TensorType_UINT8) {}
+
+ void SetWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(weights_, f);
+ }
+
+ void SetRecurrentWeights(std::initializer_list<float> f) {
+ SymmetricQuantizeAndPopulate(recurrent_weights_, f);
+ }
+};
+
+TEST(RnnOpTest, BlackBoxTest) {
RNNOpModel rnn(2, 16, 8);
- rnn.SetWeights(
- {0.461459, 0.153381, 0.529743, -0.00371218, 0.676267, -0.211346,
- 0.317493, 0.969689, -0.343251, 0.186423, 0.398151, 0.152399,
- 0.448504, 0.317662, 0.523556, -0.323514, 0.480877, 0.333113,
- -0.757714, -0.674487, -0.643585, 0.217766, -0.0251462, 0.79512,
- -0.595574, -0.422444, 0.371572, -0.452178, -0.556069, -0.482188,
- -0.685456, -0.727851, 0.841829, 0.551535, -0.232336, 0.729158,
- -0.00294906, -0.69754, 0.766073, -0.178424, 0.369513, -0.423241,
- 0.548547, -0.0152023, -0.757482, -0.85491, 0.251331, -0.989183,
- 0.306261, -0.340716, 0.886103, -0.0726757, -0.723523, -0.784303,
- 0.0354295, 0.566564, -0.485469, -0.620498, 0.832546, 0.697884,
- -0.279115, 0.294415, -0.584313, 0.548772, 0.0648819, 0.968726,
- 0.723834, -0.0080452, -0.350386, -0.272803, 0.115121, -0.412644,
- -0.824713, -0.992843, -0.592904, -0.417893, 0.863791, -0.423461,
- -0.147601, -0.770664, -0.479006, 0.654782, 0.587314, -0.639158,
- 0.816969, -0.337228, 0.659878, 0.73107, 0.754768, -0.337042,
- 0.0960841, 0.368357, 0.244191, -0.817703, -0.211223, 0.442012,
- 0.37225, -0.623598, -0.405423, 0.455101, 0.673656, -0.145345,
- -0.511346, -0.901675, -0.81252, -0.127006, 0.809865, -0.721884,
- 0.636255, 0.868989, -0.347973, -0.10179, -0.777449, 0.917274,
- 0.819286, 0.206218, -0.00785118, 0.167141, 0.45872, 0.972934,
- -0.276798, 0.837861, 0.747958, -0.0151566, -0.330057, -0.469077,
- 0.277308, 0.415818});
-
- rnn.SetBias({0.065691948, -0.69055247, 0.1107955, -0.97084129, -0.23957068,
- -0.23566568, -0.389184, 0.47481549, -0.4791103, 0.29931796,
- 0.10463274, 0.83918178, 0.37197268, 0.61957061, 0.3956964,
- -0.37609905});
-
- rnn.SetRecurrentWeights({0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
- 0.1});
+ rnn.SetWeights(rnn_weights);
+ rnn.SetBias(rnn_bias);
+ rnn.SetRecurrentWeights(rnn_recurrent_weights);
rnn.ResetHiddenState();
const int input_sequence_size = sizeof(rnn_input) / sizeof(float) /
@@ -256,6 +280,35 @@ TEST(FullyConnectedOpTest, BlackBoxTest) {
}
}
+TEST(HybridRnnOpTest, BlackBoxTest) {
+ HybridRNNOpModel rnn(2, 16, 8);
+ rnn.SetWeights(rnn_weights);
+ rnn.SetBias(rnn_bias);
+ rnn.SetRecurrentWeights(rnn_recurrent_weights);
+
+ rnn.ResetHiddenState();
+ const int input_sequence_size = sizeof(rnn_input) / sizeof(float) /
+ (rnn.input_size() * rnn.num_batches());
+
+ for (int i = 0; i < input_sequence_size; i++) {
+ float* batch_start = rnn_input + i * rnn.input_size();
+ float* batch_end = batch_start + rnn.input_size();
+ rnn.SetInput(0, batch_start, batch_end);
+ rnn.SetInput(rnn.input_size(), batch_start, batch_end);
+
+ rnn.Invoke();
+
+ float* golden_start = rnn_golden_output + i * rnn.num_units();
+ float* golden_end = golden_start + rnn.num_units();
+ std::vector<float> expected;
+ expected.insert(expected.end(), golden_start, golden_end);
+ expected.insert(expected.end(), golden_start, golden_end);
+
+ EXPECT_THAT(rnn.GetOutput(), ElementsAreArray(ArrayFloatNear(
+ expected, /*max_abs_error=*/0.0104)));
+ }
+}
+
} // namespace
} // namespace tflite
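
The hybrid test above compares against the same float golden outputs as the float test, but with a relaxed tolerance (max_abs_error = 0.0104) to absorb weight-quantization error. As a rough illustration of why the per-weight round-trip error stays within scale / 2, here is a sketch reusing SymmetricQuantizeSketch from the note near the top of this page; the names are illustrative, not TFLite APIs.

    // Round-trip (quantize then dequantize) error per weight is bounded by
    // scale / 2, because std::round moves each value by at most half a step.
    float DequantizeSketch(int8_t q, float scale) { return q * scale; }

    float MaxRoundTripError(const std::vector<float>& values) {
      std::vector<int8_t> quantized;
      float scale = 0.f;
      SymmetricQuantizeSketch(values, &quantized, &scale);
      float max_err = 0.f;
      for (std::size_t i = 0; i < values.size(); ++i) {
        max_err = std::max(
            max_err,
            std::fabs(values[i] - DequantizeSketch(quantized[i], scale)));
      }
      return max_err;
    }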