aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/kernels/detection_postprocess_test.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-21 18:30:08 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-21 18:35:36 -0700
commitdb0e56a65173a571c03691c3f7e0720d25682d1e (patch)
tree6c91b287b02b778b8b4cd06a5fa0013fcc78da81 /tensorflow/contrib/lite/kernels/detection_postprocess_test.cc
parentf76ac33d424a8b4f520e632d05ca6dcb7b5317dc (diff)
Handle dequantization of anchors tensor in custom postprocessing op
PiperOrigin-RevId: 201622115
Diffstat (limited to 'tensorflow/contrib/lite/kernels/detection_postprocess_test.cc')
-rw-r--r--tensorflow/contrib/lite/kernels/detection_postprocess_test.cc14
1 files changed, 8 insertions, 6 deletions
diff --git a/tensorflow/contrib/lite/kernels/detection_postprocess_test.cc b/tensorflow/contrib/lite/kernels/detection_postprocess_test.cc
index e801c5ace3..4e0f8484a3 100644
--- a/tensorflow/contrib/lite/kernels/detection_postprocess_test.cc
+++ b/tensorflow/contrib/lite/kernels/detection_postprocess_test.cc
@@ -178,9 +178,10 @@ TEST(DetectionPostprocessOpTest, FloatTest) {
TEST(DetectionPostprocessOpTest, QuantizedTest) {
BaseDetectionPostprocessOpModel m(
{TensorType_UINT8, {1, 6, 4}, -1.0, 1.0},
- {TensorType_UINT8, {1, 6, 3}, 0.0, 1.0}, {TensorType_FLOAT32, {6, 4}},
+ {TensorType_UINT8, {1, 6, 3}, 0.0, 1.0},
+ {TensorType_UINT8, {6, 4}, 0.0, 100.5}, {TensorType_FLOAT32, {}},
{TensorType_FLOAT32, {}}, {TensorType_FLOAT32, {}},
- {TensorType_FLOAT32, {}}, {TensorType_FLOAT32, {}});
+ {TensorType_FLOAT32, {}});
// six boxes in center-size encoding
std::vector<std::initializer_list<float>> inputs1 = {
{0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, -1.0, 0.0, 0.0,
@@ -192,9 +193,10 @@ TEST(DetectionPostprocessOpTest, QuantizedTest) {
.2}};
m.QuantizeAndPopulate<uint8_t>(m.input2(), inputs2[0]);
// six anchors in center-size encoding
- m.SetInput3<float>({0.5, 0.5, 1.0, 1.0, 0.5, 0.5, 1.0, 1.0,
- 0.5, 0.5, 1.0, 1.0, 0.5, 10.5, 1.0, 1.0,
- 0.5, 10.5, 1.0, 1.0, 0.5, 100.5, 1.0, 1.0});
+ std::vector<std::initializer_list<float>> inputs3 = {
+ {0.5, 0.5, 1.0, 1.0, 0.5, 0.5, 1.0, 1.0, 0.5, 0.5, 1.0, 1.0,
+ 0.5, 10.5, 1.0, 1.0, 0.5, 10.5, 1.0, 1.0, 0.5, 100.5, 1.0, 1.0}};
+ m.QuantizeAndPopulate<uint8_t>(m.input3(), inputs3[0]);
m.Invoke();
// detection_boxes
// in center-size
@@ -204,7 +206,7 @@ TEST(DetectionPostprocessOpTest, QuantizedTest) {
m.GetOutput1<float>(),
ElementsAreArray(ArrayFloatNear(
{0.0, 10.0, 1.0, 11.0, 0.0, 0.0, 1.0, 1.0, 0.0, 100.0, 1.0, 101.0},
- 1e-1)));
+ 3e-1)));
// detection_classes
std::vector<int> output_shape2 = m.GetOutputShape2();
EXPECT_THAT(output_shape2, ElementsAre(1, 3));