diff options
Diffstat (limited to 'tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc')
-rw-r--r-- | tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc | 10 |
1 files changed, 6 insertions, 4 deletions
diff --git a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc index 02293796e4..2e632e185d 100644 --- a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc @@ -18,7 +18,9 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "tensorflow/compiler/xla/client/lib/arithmetic.h" +#include "tensorflow/compiler/xla/client/lib/constants.h" #include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" +#include "tensorflow/compiler/xla/client/xla_computation.h" #include "tensorflow/core/platform/macros.h" namespace tensorflow { @@ -50,8 +52,8 @@ class QuantizeAndDequantizeOp : public XlaOpKernel { } else { const xla::XlaComputation* fmax = ctx->GetOrCreateMax(data_type); const xla::XlaComputation* fmin = ctx->GetOrCreateMin(data_type); - min_range = ReduceAll(input, XlaHelpers::MaxValue(b, data_type), *fmin); - max_range = ReduceAll(input, XlaHelpers::MinValue(b, data_type), *fmax); + min_range = ReduceAll(input, xla::MaxValue(b, xla_type), *fmin); + max_range = ReduceAll(input, xla::MinValue(b, xla_type), *fmax); } xla::XlaOp num_bits; @@ -93,10 +95,10 @@ class QuantizeAndDequantizeOp : public XlaOpKernel { // while keeping 0 unchanged. xla::XlaOp scale_from_min_side = Select(Gt(min_quantized * min_range, zero), min_quantized / min_range, - XlaHelpers::MaxFiniteValue(b, data_type)); + xla::MaxFiniteValue(b, xla_type)); xla::XlaOp scale_from_max_side = Select(Gt(max_quantized * max_range, zero), max_quantized / max_range, - XlaHelpers::MaxFiniteValue(b, data_type)); + xla::MaxFiniteValue(b, xla_type)); // Note: Avoids changing the side of the range that determines scale. xla::XlaOp cond = Lt(scale_from_min_side, scale_from_max_side); |