aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/l2norm_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-23 09:16:52 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-23 09:19:36 -0700
commit7a82d0fd10901f4b59f38e838a24a04df8305f73 (patch)
treeffb2949fa729e90c29f0fe81be57224c00f1abb3 /tensorflow/contrib/lite/kernels/l2norm_test.cc
parentd1f44e1c60d38cc36bc438b59338c3a4eecf0615 (diff)
Support batch size > 1 in L2Normalization 8 bit quantized implementations.
PiperOrigin-RevId: 197736184
Diffstat (limited to 'tensorflow/contrib/lite/kernels/l2norm_test.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/l2norm_test.cc30
1 files changed, 28 insertions, 2 deletions
diff --git a/tensorflow/contrib/lite/kernels/l2norm_test.cc b/tensorflow/contrib/lite/kernels/l2norm_test.cc
index 11cc666bad..070ed60040 100644
--- a/tensorflow/contrib/lite/kernels/l2norm_test.cc
+++ b/tensorflow/contrib/lite/kernels/l2norm_test.cc
@@ -67,7 +67,7 @@ class L2NormOpModel : public SingleOpModel {
int output_;
};
-TEST(L2NormOpTest, SimpleTest) {
+TEST(L2NormOpTest, SimpleFloatTest) {
L2NormOpModel m({1, 1, 1, 6}, TensorType_FLOAT32,
ActivationFunctionType_NONE);
m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1});
@@ -76,7 +76,7 @@ TEST(L2NormOpTest, SimpleTest) {
ElementsAreArray({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}));
}
-TEST(L2NormOpTest, MultipleBatchesTest) {
+TEST(L2NormOpTest, MultipleBatchFloatTest) {
L2NormOpModel m({3, 1, 1, 6}, TensorType_FLOAT32,
ActivationFunctionType_NONE);
m.SetInput({
@@ -105,6 +105,32 @@ TEST(L2NormOpTest, SimpleUint8Test) {
ArrayFloatNear({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}, 0.1)));
}
+TEST(L2NormOpTest, MultipleBatchUint8Test) {
+ L2NormOpModel m({3, 1, 1, 6}, TensorType_UINT8, ActivationFunctionType_NONE);
+
+ m.QuantizeAndPopulate<uint8_t>(m.input(),
+ {
+ -1.1, 0.6, 0.7, 1.2, -0.7, 0.1, // batch 1
+ -1.1, 0.6, 0.7, 1.2, -0.7, 0.1, // batch 2
+ -1.1, 0.6, 0.7, 1.2, -0.7, 0.1, // batch 3
+ });
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput<uint8_t>(),
+ ElementsAreArray({
+ 58, 166, 173, 205, 83, 134, // batch 1
+ 58, 166, 173, 205, 83, 134, // batch 2
+ 58, 166, 173, 205, 83, 134, // batch 3
+ }));
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(ArrayFloatNear(
+ {
+ -0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 1
+ -0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 2
+ -0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 3
+ },
+ 0.1)));
+}
+
} // namespace
} // namespace tflite