diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-10-04 14:18:22 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-04 14:23:27 -0700 |
commit | a2e48d849f5c7a97b788ba8d2499e95aaef95945 (patch) | |
tree | 1d90f19c64d57f513735948adbd0015939621823 /tensorflow/contrib | |
parent | 4c1da53840fed235409cb2c571ea081e28388f75 (diff) |
Fix problem in quantized version of Comparison op handler
PiperOrigin-RevId: 215801773
Diffstat (limited to 'tensorflow/contrib')
-rw-r--r-- | tensorflow/contrib/lite/kernels/comparisons.cc | 16 | ||||
-rw-r--r-- | tensorflow/contrib/lite/kernels/comparisons_test.cc | 11 |
2 files changed, 16 insertions, 11 deletions
diff --git a/tensorflow/contrib/lite/kernels/comparisons.cc b/tensorflow/contrib/lite/kernels/comparisons.cc index f765235e04..3926af5b97 100644 --- a/tensorflow/contrib/lite/kernels/comparisons.cc +++ b/tensorflow/contrib/lite/kernels/comparisons.cc @@ -66,31 +66,25 @@ TfLiteStatus ComparisonPrepare(TfLiteContext* context, TfLiteNode* node) { if (input1->type == kTfLiteUInt8) { \ auto input1_offset = -input1->params.zero_point; \ auto input2_offset = -input2->params.zero_point; \ - const int left_shift = 20; \ - const double twice_max_input_scale = \ - 2 * std::max(input1->params.scale, input2->params.scale); \ - const double real_input1_multiplier = \ - input1->params.scale / twice_max_input_scale; \ - const double real_input2_multiplier = \ - input2->params.scale / twice_max_input_scale; \ + const int left_shift = 8; \ \ int32 input1_multiplier; \ int input1_shift; \ - QuantizeMultiplierSmallerThanOneExp(real_input1_multiplier, \ + QuantizeMultiplierSmallerThanOneExp(input1->params.scale, \ &input1_multiplier, &input1_shift); \ int32 input2_multiplier; \ int input2_shift; \ - QuantizeMultiplierSmallerThanOneExp(real_input2_multiplier, \ + QuantizeMultiplierSmallerThanOneExp(input2->params.scale, \ &input2_multiplier, &input2_shift); \ \ ComparisonParams op_params; \ op_params.left_shift = left_shift; \ op_params.input1_offset = input1_offset; \ op_params.input1_multiplier = input1_multiplier; \ - op_params.input1_shift = -input1_shift; \ + op_params.input1_shift = input1_shift; \ op_params.input2_offset = input2_offset; \ op_params.input2_multiplier = input2_multiplier; \ - op_params.input2_shift = -input2_shift; \ + op_params.input2_shift = input2_shift; \ if (requires_broadcast) { \ reference_ops::Broadcast4DSlow##opname##WithScaling( \ op_params, GetTensorShape(input1), GetTensorData<uint8_t>(input1), \ diff --git a/tensorflow/contrib/lite/kernels/comparisons_test.cc b/tensorflow/contrib/lite/kernels/comparisons_test.cc index 67a91c17fd..04c8bf2e30 100644 --- a/tensorflow/contrib/lite/kernels/comparisons_test.cc +++ b/tensorflow/contrib/lite/kernels/comparisons_test.cc @@ -402,6 +402,17 @@ TEST(ComparisonsTest, GreaterQuantized) { EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false)); } +TEST(ComparisonsTest, GreaterQuantizedSmallRange) { + ComparisonOpModel model({TensorType_UINT8, {1, 2, 2, 1}, 0.0, 1.0}, + {TensorType_UINT8, {1, 2, 2, 1}, 0.0, 2.0}, + TensorType_UINT8, BuiltinOperator_GREATER); + model.QuantizeAndPopulate<uint8_t>(model.input1(), {1.0, 0.5, 0.35, 0.1}); + model.QuantizeAndPopulate<uint8_t>(model.input2(), {1.01, 0.25, 0.3, 0.4}); + model.Invoke(); + + EXPECT_THAT(model.GetOutput(), ElementsAre(false, true, true, false)); +} + TEST(ComparisonsTest, GreaterEqualQuantized) { const float kMin = -1.f; const float kMax = 128.f; |