author     A. Unique TensorFlower <gardener@tensorflow.org>  2018-04-04 14:55:16 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-04-04 14:57:47 -0700
commit  a9cb0e19b9d96935c653f9cf89cebb6407564e5b (patch)
tree    ca4e99b7f64a2e5ff29a64c0226802b894f8d97e /tensorflow/contrib/lite/kernels/l2norm_test.cc
parent  e8882f768127b71e03efbf193a9c3152ab84802a (diff)
Add quantized uint8 L2Normalization Kernel.
PiperOrigin-RevId: 191652174
Diffstat (limited to 'tensorflow/contrib/lite/kernels/l2norm_test.cc')
-rw-r--r--  tensorflow/contrib/lite/kernels/l2norm_test.cc  49
1 file changed, 42 insertions(+), 7 deletions(-)
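
A minimal sketch of the affine uint8 quantization the new test relies on, assuming the usual TFLite convention real_value = scale * (quantized_value - zero_point). The helper names below are hypothetical illustrations and are not the SingleOpModel API; the exact scale/zero_point that SingleOpModel derives from a (min, max) range is likewise an assumption.

#include <algorithm>
#include <cmath>
#include <cstdint>

// Hypothetical helpers sketching TFLite-style affine uint8 quantization:
//   real_value = scale * (quantized_value - zero_point)
struct QuantParams {
  float scale;
  int32_t zero_point;
};

// Maps a real value onto [0, 255] under the given parameters.
inline uint8_t Quantize(float real, QuantParams p) {
  int32_t q = p.zero_point + static_cast<int32_t>(std::round(real / p.scale));
  return static_cast<uint8_t>(std::min(255, std::max(0, q)));
}

// Inverse mapping, conceptually what GetDequantizedOutput() does below.
inline float Dequantize(uint8_t q, QuantParams p) {
  return p.scale * (static_cast<int32_t>(q) - p.zero_point);
}

With the output range [-1.0, 127.0/128.0] set in the diff below, the derived parameters work out to roughly scale = 1/128 and zero_point = 128, which is why the expected uint8 values in SimpleUint8Test land close to the float reference output.
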
diff --git a/tensorflow/contrib/lite/kernels/l2norm_test.cc b/tensorflow/contrib/lite/kernels/l2norm_test.cc
index 30e103f330..042314ccf5 100644
--- a/tensorflow/contrib/lite/kernels/l2norm_test.cc
+++ b/tensorflow/contrib/lite/kernels/l2norm_test.cc
@@ -25,10 +25,22 @@ using ::testing::ElementsAreArray;
class L2NormOpModel : public SingleOpModel {
public:
- L2NormOpModel(std::initializer_list<int> input_shape,
- ActivationFunctionType activation_type) {
- input_ = AddInput(TensorType_FLOAT32);
- output_ = AddOutput(TensorType_FLOAT32);
+ L2NormOpModel(const std::initializer_list<int> input_shape,
+ const TensorType tensor_type,
+ const ActivationFunctionType activation_type) {
+ TensorData data = TensorData{tensor_type};
+ if (tensor_type != TensorType_FLOAT32) {
+ data.min = -2.0;
+ data.max = 2.0;
+ data.scale = 2.0;
+ data.zero_point = 128;
+ }
+ input_ = AddInput(data);
+ if (tensor_type != TensorType_FLOAT32) {
+ data.min = -1.0;
+ data.max = 127.0 / 128.0;
+ }
+ output_ = AddOutput(data);
SetBuiltinOp(BuiltinOperator_L2_NORMALIZATION, BuiltinOptions_L2NormOptions,
CreateL2NormOptions(builder_, activation_type).Union());
BuildInterpreter({input_shape});
@@ -38,7 +50,17 @@ class L2NormOpModel : public SingleOpModel {
PopulateTensor(input_, data);
}
- std::vector<float> GetOutput() { return ExtractVector<float>(output_); }
+ template <typename T>
+ std::vector<T> GetOutput() {
+ return ExtractVector<T>(output_);
+ }
+
+ std::vector<float> GetDequantizedOutput() {
+ return Dequantize<uint8_t>(ExtractVector<uint8_t>(output_),
+ GetScale(output_), GetZeroPoint(output_));
+ }
+
+ int input() const { return input_; }
private:
int input_;
@@ -46,13 +68,26 @@ class L2NormOpModel : public SingleOpModel {
};
TEST(L2NormOpTest, SimpleTest) {
- L2NormOpModel m({1, 1, 1, 6}, ActivationFunctionType_NONE);
+ L2NormOpModel m({1, 1, 1, 6}, TensorType_FLOAT32,
+ ActivationFunctionType_NONE);
m.SetInput({-1.1, 0.6, 0.7, 1.2, -0.7, 0.1});
m.Invoke();
- EXPECT_THAT(m.GetOutput(),
+ EXPECT_THAT(m.GetOutput<float>(),
ElementsAreArray({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}));
}
+TEST(L2NormOpTest, SimpleUint8Test) {
+ L2NormOpModel m({1, 1, 1, 6}, TensorType_UINT8, ActivationFunctionType_NONE);
+
+ m.QuantizeAndPopulate<uint8_t>(m.input(), {-1.1, 0.6, 0.7, 1.2, -0.7, 0.1});
+ m.Invoke();
+ EXPECT_THAT(m.GetOutput<uint8_t>(),
+ ElementsAreArray({58, 166, 173, 205, 83, 134}));
+ EXPECT_THAT(m.GetDequantizedOutput(),
+ ElementsAreArray(
+ ArrayFloatNear({-0.55, 0.3, 0.35, 0.6, -0.35, 0.05}, 0.1)));
+}
+
} // namespace
} // namespace tflite
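
As a quick sanity check of the expected values in SimpleUint8Test, here is a small stand-alone sketch that dequantizes them the same way GetDequantizedOutput() does, under the assumption that the output tensor's parameters are scale = 1/128 and zero_point = 128 (derived from the [-1.0, 127.0/128.0] range above).

#include <cstdint>
#include <cstdio>

// Dequantizes the expected uint8 outputs from the test, assuming
// output scale = 1/128 and zero_point = 128.
int main() {
  const float scale = 1.0f / 128.0f;
  const int zero_point = 128;
  const uint8_t quantized[] = {58, 166, 173, 205, 83, 134};
  const float reference[] = {-0.55f, 0.3f, 0.35f, 0.6f, -0.35f, 0.05f};
  for (int i = 0; i < 6; ++i) {
    // e.g. 58 -> -0.546875 and 166 -> 0.296875, both within the 0.1
    // tolerance passed to ArrayFloatNear in the test.
    float dequantized = scale * (quantized[i] - zero_point);
    std::printf("%3d -> %+.6f (reference %+.2f)\n", quantized[i], dequantized,
                reference[i]);
  }
  return 0;
}
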