aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/dequantize_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/dequantize_op.cc')
-rw-r--r--tensorflow/core/kernels/dequantize_op.cc26
1 files changed, 8 insertions, 18 deletions
diff --git a/tensorflow/core/kernels/dequantize_op.cc b/tensorflow/core/kernels/dequantize_op.cc
index 3f644a61bf..42fbf95cd3 100644
--- a/tensorflow/core/kernels/dequantize_op.cc
+++ b/tensorflow/core/kernels/dequantize_op.cc
@@ -96,27 +96,17 @@ class DequantizeOp : public OpKernel {
output);
}
} else if (mode_ == QUANTIZE_MODE_SCALED) {
- // The quantization logic for mode SCALED matches that of
- // QuantizeAndDequantizeV2 and QuantizeAndDequantizeV3.
- static constexpr int num_bits = sizeof(T) * 8;
- const float max_abs = std::max(std::abs(min_range), std::abs(max_range));
- bool is_signed = std::is_signed<T>::value;
- // If it is signed, we try to keep 0.0 being 0 and drop one bucket. For
- // example, if it is 8 bits, we have the range [-127, 127]. So for input
- // range of [-x, x], the scale should be 254/(2*x).
- //
- // If it is unsigned and num_bits == 8, the range with 8 bits is [0, 255].
- // If the input range is [0, x], then the scale is x/255 instead of 254 as
- // in the case above.
- const int target_bits = is_signed ? (num_bits - 1) : num_bits;
- const float target_range =
- static_cast<float>((uint64_t{1} << target_bits) - 1);
- const float scale_factor = max_abs / target_range;
+ // TODO(pauldonnelly): Update QuantizeAndDequantizeV2 and
+ // QuantizeAndDequantizeV3 to match this SCALED mode again.
+ const float scale_factor =
+ std::numeric_limits<T>::min() == 0
+ ? (max_range / std::numeric_limits<T>::max())
+ : std::max(min_range / std::numeric_limits<T>::min(),
+ max_range / std::numeric_limits<T>::max());
float* out_ptr = output->flat<float>().data();
const T* in_ptr = input.flat<T>().data();
-
const int64 num_elements = input.NumElements();
- for (int i = 0; i < num_elements; ++i) {
+ for (int64 i = 0; i < num_elements; ++i) {
out_ptr[i] = static_cast<int>(in_ptr[i]) * scale_factor;
}
}