aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-10-04 14:18:22 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-04 14:23:27 -0700
commita2e48d849f5c7a97b788ba8d2499e95aaef95945 (patch)
tree1d90f19c64d57f513735948adbd0015939621823 /tensorflow/contrib
parent4c1da53840fed235409cb2c571ea081e28388f75 (diff)
Fix problem in quantized version of Comparison op handler
PiperOrigin-RevId: 215801773
Diffstat (limited to 'tensorflow/contrib')
-rw-r--r--tensorflow/contrib/lite/kernels/comparisons.cc16
-rw-r--r--tensorflow/contrib/lite/kernels/comparisons_test.cc11
2 files changed, 16 insertions, 11 deletions
diff --git a/tensorflow/contrib/lite/kernels/comparisons.cc b/tensorflow/contrib/lite/kernels/comparisons.cc
index f765235e04..3926af5b97 100644
--- a/tensorflow/contrib/lite/kernels/comparisons.cc
+++ b/tensorflow/contrib/lite/kernels/comparisons.cc
@@ -66,31 +66,25 @@ TfLiteStatus ComparisonPrepare(TfLiteContext* context, TfLiteNode* node) {
if (input1->type == kTfLiteUInt8) { \
auto input1_offset = -input1->params.zero_point; \
auto input2_offset = -input2->params.zero_point; \
- const int left_shift = 20; \
- const double twice_max_input_scale = \
- 2 * std::max(input1->params.scale, input2->params.scale); \
- const double real_input1_multiplier = \
- input1->params.scale / twice_max_input_scale; \
- const double real_input2_multiplier = \
- input2->params.scale / twice_max_input_scale; \
+ const int left_shift = 8; \
\
int32 input1_multiplier; \
int input1_shift; \
- QuantizeMultiplierSmallerThanOneExp(real_input1_multiplier, \
+ QuantizeMultiplierSmallerThanOneExp(input1->params.scale, \
&input1_multiplier, &input1_shift); \
int32 input2_multiplier; \
int input2_shift; \
- QuantizeMultiplierSmallerThanOneExp(real_input2_multiplier, \
+ QuantizeMultiplierSmallerThanOneExp(input2->params.scale, \
&input2_multiplier, &input2_shift); \
\
ComparisonParams op_params; \
op_params.left_shift = left_shift; \
op_params.input1_offset = input1_offset; \
op_params.input1_multiplier = input1_multiplier; \
- op_params.input1_shift = -input1_shift; \
+ op_params.input1_shift = input1_shift; \
op_params.input2_offset = input2_offset; \
op_params.input2_multiplier = input2_multiplier; \
- op_params.input2_shift = -input2_shift; \
+ op_params.input2_shift = input2_shift; \
if (requires_broadcast) { \
reference_ops::Broadcast4DSlow##opname##WithScaling( \
op_params, GetTensorShape(input1), GetTensorData<uint8_t>(input1), \
diff --git a/tensorflow/contrib/lite/kernels/comparisons_test.cc b/tensorflow/contrib/lite/kernels/comparisons_test.cc
index 67a91c17fd..04c8bf2e30 100644
--- a/tensorflow/contrib/lite/kernels/comparisons_test.cc
+++ b/tensorflow/contrib/lite/kernels/comparisons_test.cc
@@ -402,6 +402,17 @@ TEST(ComparisonsTest, GreaterQuantized) {
EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false));
}
+TEST(ComparisonsTest, GreaterQuantizedSmallRange) {
+ ComparisonOpModel model({TensorType_UINT8, {1, 2, 2, 1}, 0.0, 1.0},
+ {TensorType_UINT8, {1, 2, 2, 1}, 0.0, 2.0},
+ TensorType_UINT8, BuiltinOperator_GREATER);
+ model.QuantizeAndPopulate<uint8_t>(model.input1(), {1.0, 0.5, 0.35, 0.1});
+ model.QuantizeAndPopulate<uint8_t>(model.input2(), {1.01, 0.25, 0.3, 0.4});
+ model.Invoke();
+
+ EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false));
+}
+
TEST(ComparisonsTest, GreaterEqualQuantized) {
const float kMin = -1.f;
const float kMax = 128.f;