diff options
author | 2018-06-21 18:30:08 -0700 | |
---|---|---|
committer | 2018-06-21 18:35:36 -0700 | |
commit | db0e56a65173a571c03691c3f7e0720d25682d1e (patch) | |
tree | 6c91b287b02b778b8b4cd06a5fa0013fcc78da81 /tensorflow/contrib/lite/kernels/detection_postprocess_test.cc | |
parent | f76ac33d424a8b4f520e632d05ca6dcb7b5317dc (diff) |
Handle dequantization of anchors tensor in custom postprocessing op
PiperOrigin-RevId: 201622115
Diffstat (limited to 'tensorflow/contrib/lite/kernels/detection_postprocess_test.cc')
-rw-r--r-- | tensorflow/contrib/lite/kernels/detection_postprocess_test.cc | 14 |
1 files changed, 8 insertions, 6 deletions
diff --git a/tensorflow/contrib/lite/kernels/detection_postprocess_test.cc b/tensorflow/contrib/lite/kernels/detection_postprocess_test.cc index e801c5ace3..4e0f8484a3 100644 --- a/tensorflow/contrib/lite/kernels/detection_postprocess_test.cc +++ b/tensorflow/contrib/lite/kernels/detection_postprocess_test.cc @@ -178,9 +178,10 @@ TEST(DetectionPostprocessOpTest, FloatTest) { TEST(DetectionPostprocessOpTest, QuantizedTest) { BaseDetectionPostprocessOpModel m( {TensorType_UINT8, {1, 6, 4}, -1.0, 1.0}, - {TensorType_UINT8, {1, 6, 3}, 0.0, 1.0}, {TensorType_FLOAT32, {6, 4}}, + {TensorType_UINT8, {1, 6, 3}, 0.0, 1.0}, + {TensorType_UINT8, {6, 4}, 0.0, 100.5}, {TensorType_FLOAT32, {}}, {TensorType_FLOAT32, {}}, {TensorType_FLOAT32, {}}, - {TensorType_FLOAT32, {}}, {TensorType_FLOAT32, {}}); + {TensorType_FLOAT32, {}}); // six boxes in center-size encoding std::vector<std::initializer_list<float>> inputs1 = { {0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, -1.0, 0.0, 0.0, @@ -192,9 +193,10 @@ TEST(DetectionPostprocessOpTest, QuantizedTest) { .2}}; m.QuantizeAndPopulate<uint8_t>(m.input2(), inputs2[0]); // six anchors in center-size encoding - m.SetInput3<float>({0.5, 0.5, 1.0, 1.0, 0.5, 0.5, 1.0, 1.0, - 0.5, 0.5, 1.0, 1.0, 0.5, 10.5, 1.0, 1.0, - 0.5, 10.5, 1.0, 1.0, 0.5, 100.5, 1.0, 1.0}); + std::vector<std::initializer_list<float>> inputs3 = { + {0.5, 0.5, 1.0, 1.0, 0.5, 0.5, 1.0, 1.0, 0.5, 0.5, 1.0, 1.0, + 0.5, 10.5, 1.0, 1.0, 0.5, 10.5, 1.0, 1.0, 0.5, 100.5, 1.0, 1.0}}; + m.QuantizeAndPopulate<uint8_t>(m.input3(), inputs3[0]); m.Invoke(); // detection_boxes // in center-size @@ -204,7 +206,7 @@ TEST(DetectionPostprocessOpTest, QuantizedTest) { m.GetOutput1<float>(), ElementsAreArray(ArrayFloatNear( {0.0, 10.0, 1.0, 11.0, 0.0, 0.0, 1.0, 1.0, 0.0, 100.0, 1.0, 101.0}, - 1e-1))); + 3e-1))); // detection_classes std::vector<int> output_shape2 = m.GetOutputShape2(); EXPECT_THAT(output_shape2, ElementsAre(1, 3)); |