diff options
author | 2018-05-23 09:16:52 -0700 | |
---|---|---|
committer | 2018-05-23 09:19:36 -0700 | |
commit | 7a82d0fd10901f4b59f38e838a24a04df8305f73 (patch) | |
tree | ffb2949fa729e90c29f0fe81be57224c00f1abb3 /tensorflow/contrib/lite/kernels/l2norm_test.cc | |
parent | d1f44e1c60d38cc36bc438b59338c3a4eecf0615 (diff) |
Support batch size > 1 in L2Normalization 8 bit quantized implementations.
PiperOrigin-RevId: 197736184
Diffstat (limited to 'tensorflow/contrib/lite/kernels/l2norm_test.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/l2norm_test.cc | 30 |
1 files changed, 28 insertions, 2 deletions
diff --git a/tensorflow/contrib/lite/kernels/l2norm_test.cc b/tensorflow/contrib/lite/kernels/l2norm_test.cc index 11cc666bad..070ed60040 100644 --- a/tensorflow/contrib/lite/kernels/l2norm_test.cc +++ b/tensorflow/contrib/lite/kernels/l2norm_test.cc @@ -67,7 +67,7 @@ class L2NormOpModel : public SingleOpModel { int output_; }; -TEST(L2NormOpTest, SimpleTest) { +TEST(L2NormOpTest, SimpleFloatTest) { L2NormOpModel m({1, 1, 1, 6}, TensorType_FLOAT32, ActivationFunctionType_NONE); m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1}); @@ -76,7 +76,7 @@ TEST(L2NormOpTest, SimpleTest) { ElementsAreArray({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05})); } -TEST(L2NormOpTest, MultipleBatchesTest) { +TEST(L2NormOpTest, MultipleBatchFloatTest) { L2NormOpModel m({3, 1, 1, 6}, TensorType_FLOAT32, ActivationFunctionType_NONE); m.SetInput({ @@ -105,6 +105,32 @@ TEST(L2NormOpTest, SimpleUint8Test) { ArrayFloatNear({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}, 0.1))); } +TEST(L2NormOpTest, MultipleBatchUint8Test) { + L2NormOpModel m({3, 1, 1, 6}, TensorType_UINT8, ActivationFunctionType_NONE); + + m.QuantizeAndPopulate<uint8_t>(m.input(), + { + -1.1, 0.6, 0.7, 1.2, -0.7, 0.1, // batch 1 + -1.1, 0.6, 0.7, 1.2, -0.7, 0.1, // batch 2 + -1.1, 0.6, 0.7, 1.2, -0.7, 0.1, // batch 3 + }); + m.Invoke(); + EXPECT_THAT(m.GetOutput<uint8_t>(), + ElementsAreArray({ + 58, 166, 173, 205, 83, 134, // batch 1 + 58, 166, 173, 205, 83, 134, // batch 2 + 58, 166, 173, 205, 83, 134, // batch 3 + })); + EXPECT_THAT(m.GetDequantizedOutput(), + ElementsAreArray(ArrayFloatNear( + { + -0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 1 + -0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 2 + -0.55, 0.3, 0.35, 0.6, -0.35, 0.05, // batch 3 + }, + 0.1))); +} + } // namespace } // namespace tflite |