commit 145cc9060896fb1f1f17eb15fc90fd6fbac959f2
tree 82e3f558e42f01c2831b7b921a1469e58923899a
parent b619ef976190f88fed7e18b07ed8ff5413d27bcb
author Peter Hawkins <phawkins@google.com> 2018-06-27 17:03:28 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-06-27 17:07:13 -0700
[TF:XLA] Implement QuantizeAndDequantizeV3.

Change the XLA lowering of QuantizeAndDequantizeV2/V3 to match the TF kernel
much more closely. The main exception is that the min_quantized and
max_quantized values are calculated as floats, to avoid the need for 64-bit
integer math, which is not present on all accelerators.

This change reformats unary_ops_test.py in passing; on the whole I don't mind
the reformatting.

PiperOrigin-RevId: 202395114
 tensorflow/compiler/tests/unary_ops_test.py                      | 256
 tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc | 136
 2 files changed, 214 insertions(+), 178 deletions(-)
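
The new lowering mirrors the reference kernel in
tensorflow/core/kernels/quantize_and_dequantize_op.h. As a rough illustration
of the math, here is a minimal NumPy sketch of the signed, range_given=False
case exercised by the updated tests; the function name and structure are
illustrative only, not part of the change:

import numpy as np

def quantize_and_dequantize(x, num_bits=8):
  # Quantized limits computed as floats, e.g. [-128.0, 127.0] for signed
  # 8-bit, so no 64-bit integer arithmetic is needed on the accelerator.
  min_quantized = -(2.0 ** (num_bits - 1))
  max_quantized = 2.0 ** (num_bits - 1) - 1.0
  # range_given=False: measure the range from the tensor itself.
  min_range, max_range = float(x.min()), float(x.max())
  # Largest scale that maps [min_range, max_range] into
  # [min_quantized, max_quantized] while keeping 0 fixed.
  big = np.finfo(np.float64).max
  scale_from_min = min_quantized / min_range if min_quantized * min_range > 0 else big
  scale_from_max = max_quantized / max_range if max_quantized * max_range > 0 else big
  if scale_from_min < scale_from_max:
    scale, inverse_scale = scale_from_min, min_range / min_quantized
  else:
    scale, inverse_scale = scale_from_max, max_range / max_quantized
    min_range = min_quantized * inverse_scale  # re-derive the other side
  # Round half up via floor(v + 0.5), then map back to the original range.
  return np.floor((x - min_range) * scale + 0.5) * inverse_scale + min_range

print(quantize_and_dequantize(np.array([-1.0, -0.5, 0.0, 0.3])))
# -> [-1., -0.5, 0., 0.296875], the values the updated tests expect.

Computing the quantized limits with Pow on floats is what lets the new XLA
kernel avoid the 1LL << num_bits shifts of the old lowering, which required
64-bit integers.
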
diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py
index e610b63e30..a24abd7547 100644
--- a/tensorflow/compiler/tests/unary_ops_test.py
+++ b/tensorflow/compiler/tests/unary_ops_test.py
@@ -47,8 +47,13 @@ def nhwc_to_format(x, data_format):
class UnaryOpsTest(XLATestCase):
"""Test cases for unary operators."""
- def _assertOpOutputMatchesExpected(self, op, inp, expected,
- equality_test=None, rtol=1e-3, atol=1e-5):
+ def _assertOpOutputMatchesExpected(self,
+ op,
+ inp,
+ expected,
+ equality_test=None,
+ rtol=1e-3,
+ atol=1e-5):
"""Verifies that 'op' produces 'expected' when fed input 'inp' .
Args:
@@ -81,10 +86,10 @@ class UnaryOpsTest(XLATestCase):
def testAllTypeOps(self):
for dtype in self.numeric_types:
self._assertOpOutputMatchesExpected(
- array_ops.diag,
- np.array([1, 2, 3, 4], dtype=dtype),
- np.array([[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4]],
- dtype=dtype))
+ array_ops.diag, np.array([1, 2, 3, 4], dtype=dtype),
+ np.array(
+ [[1, 0, 0, 0], [0, 2, 0, 0], [0, 0, 3, 0], [0, 0, 0, 4]],
+ dtype=dtype))
self._assertOpOutputMatchesExpected(
array_ops.diag_part,
np.arange(36).reshape([2, 3, 2, 3]).astype(dtype),
@@ -102,8 +107,7 @@ class UnaryOpsTest(XLATestCase):
expected=np.array([[-1, 1]], dtype=dtype))
self._assertOpOutputMatchesExpected(
- array_ops.matrix_diag,
- np.array([[1, 2], [3, 4]], dtype=dtype),
+ array_ops.matrix_diag, np.array([[1, 2], [3, 4]], dtype=dtype),
np.array([[[1, 0], [0, 2]], [[3, 0], [0, 4]]], dtype=dtype))
self._assertOpOutputMatchesExpected(
array_ops.matrix_diag, np.array([1, 2, 3, 4], dtype=dtype),
@@ -115,10 +119,10 @@ class UnaryOpsTest(XLATestCase):
np.array(
[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]], dtype=dtype),
np.array(
- [[[[1, 0, 0], [0, 2, 0], [0, 0, 3]],
- [[4, 0, 0], [0, 5, 0], [0, 0, 6]]],
- [[[7, 0, 0], [0, 8, 0], [0, 0, 9]],
- [[10, 0, 0], [0, 11, 0], [0, 0, 12]]]],
+ [[[[1, 0, 0], [0, 2, 0], [0, 0, 3]], [[4, 0, 0], [0, 5, 0], [
+ 0, 0, 6
+ ]]], [[[7, 0, 0], [0, 8, 0], [0, 0, 9]], [[10, 0, 0], [0, 11, 0],
+ [0, 0, 12]]]],
dtype=dtype))
self._assertOpOutputMatchesExpected(
array_ops.matrix_diag_part,
@@ -159,36 +163,30 @@ class UnaryOpsTest(XLATestCase):
continue
x = np.arange(-0.90, 0.90, 0.25)
self._assertOpOutputMatchesExpected(
- math_ops.acos,
- x.astype(dtype),
- expected=np.arccos(x).astype(dtype))
+ math_ops.acos, x.astype(dtype), expected=np.arccos(x).astype(dtype))
self._assertOpOutputMatchesExpected(
- math_ops.asin,
- x.astype(dtype),
- expected=np.arcsin(x).astype(dtype))
+ math_ops.asin, x.astype(dtype), expected=np.arcsin(x).astype(dtype))
x = np.arange(-3, 3).reshape(1, 3, 2)
self._assertOpOutputMatchesExpected(
- math_ops.atan,
- x.astype(dtype),
- expected=np.arctan(x).astype(dtype))
+ math_ops.atan, x.astype(dtype), expected=np.arctan(x).astype(dtype))
self._assertOpOutputMatchesExpected(
math_ops.acosh,
np.array([1, 2, 3, 4], dtype=dtype),
- expected=np.array([0, 1.3169579, 1.76274717, 2.06343707],
- dtype=dtype))
+ expected=np.array(
+ [0, 1.3169579, 1.76274717, 2.06343707], dtype=dtype))
self._assertOpOutputMatchesExpected(
math_ops.asinh,
np.array([1, 2, 3, 4], dtype=dtype),
- expected=np.array([0.88137359, 1.44363548, 1.81844646, 2.09471255],
- dtype=dtype))
+ expected=np.array(
+ [0.88137359, 1.44363548, 1.81844646, 2.09471255], dtype=dtype))
self._assertOpOutputMatchesExpected(
math_ops.atanh,
np.array([0.1, 0.2, 0.3, 0.4], dtype=dtype),
- expected=np.array([0.10033535, 0.20273255, 0.3095196, 0.42364893],
- dtype=dtype))
+ expected=np.array(
+ [0.10033535, 0.20273255, 0.3095196, 0.42364893], dtype=dtype))
self._assertOpOutputMatchesExpected(
math_ops.ceil,
@@ -198,8 +196,8 @@ class UnaryOpsTest(XLATestCase):
self._assertOpOutputMatchesExpected(
math_ops.cosh,
np.array([1, 2, 3, 4], dtype=dtype),
- expected=np.array([1.54308063, 3.76219569, 10.067662, 27.30823284],
- dtype=dtype))
+ expected=np.array(
+ [1.54308063, 3.76219569, 10.067662, 27.30823284], dtype=dtype))
# Disable float16 testing for now
if dtype != np.float16:
@@ -229,8 +227,8 @@ class UnaryOpsTest(XLATestCase):
self._assertOpOutputMatchesExpected(
math_ops.is_finite,
- np.array([[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]],
- dtype=dtype),
+ np.array(
+ [[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]], dtype=dtype),
expected=np.array([[0, 1, 1, 1, 1, 1, 1, 0, 0]], dtype=np.bool))
# Tests for tf.nn ops.
@@ -271,16 +269,20 @@ class UnaryOpsTest(XLATestCase):
self._assertOpOutputMatchesExpected(
math_ops.rint,
- np.array([[-1.7, 1.2, 4.0, 0.0], [-3.5, -2.5, -1.5, -0.5],
- [0.5, 1.5, 2.5, 3.5]], dtype=dtype),
- expected=np.array([[-2, 1, 4, 0], [-4, -2, -2, 0], [0, 2, 2, 4]],
- dtype=dtype))
+ np.array(
+ [[-1.7, 1.2, 4.0, 0.0], [-3.5, -2.5, -1.5, -0.5],
+ [0.5, 1.5, 2.5, 3.5]],
+ dtype=dtype),
+ expected=np.array(
+ [[-2, 1, 4, 0], [-4, -2, -2, 0], [0, 2, 2, 4]], dtype=dtype))
self._assertOpOutputMatchesExpected(
math_ops.round,
- np.array([[-1.7, 1.2, 4.0, 0.0], [-3.5, -2.5, -1.5, -0.5],
- [0.5, 1.5, 2.5, 3.5]], dtype=dtype),
- expected=np.array([[-2, 1, 4, 0], [-4, -2, -2, 0], [0, 2, 2, 4]],
- dtype=dtype))
+ np.array(
+ [[-1.7, 1.2, 4.0, 0.0], [-3.5, -2.5, -1.5, -0.5],
+ [0.5, 1.5, 2.5, 3.5]],
+ dtype=dtype),
+ expected=np.array(
+ [[-2, 1, 4, 0], [-4, -2, -2, 0], [0, 2, 2, 4]], dtype=dtype))
self._assertOpOutputMatchesExpected(
math_ops.rsqrt,
@@ -289,10 +291,7 @@ class UnaryOpsTest(XLATestCase):
self._assertOpOutputMatchesExpected(
math_ops.sigmoid,
- np.array(
- [[1, 1, 1, 1],
- [1, 2, 3, 4]],
- dtype=dtype),
+ np.array([[1, 1, 1, 1], [1, 2, 3, 4]], dtype=dtype),
expected=np.array(
[[0.7310586, 0.7310586, 0.7310586, 0.7310586],
[0.7310586, 0.880797, 0.95257413, 0.98201376]],
@@ -306,8 +305,8 @@ class UnaryOpsTest(XLATestCase):
self._assertOpOutputMatchesExpected(
math_ops.sinh,
np.array([1, 2, 3, 4], dtype=dtype),
- expected=np.array([1.17520119, 3.62686041, 10.01787493, 27.2899172],
- dtype=dtype))
+ expected=np.array(
+ [1.17520119, 3.62686041, 10.01787493, 27.2899172], dtype=dtype))
self._assertOpOutputMatchesExpected(
math_ops.sqrt,
@@ -317,15 +316,12 @@ class UnaryOpsTest(XLATestCase):
self._assertOpOutputMatchesExpected(
math_ops.tan,
np.array([1, 2, 3, 4], dtype=dtype),
- expected=np.array([1.55740772, -2.18503986, -0.14254654, 1.15782128],
- dtype=dtype))
+ expected=np.array(
+ [1.55740772, -2.18503986, -0.14254654, 1.15782128], dtype=dtype))
self._assertOpOutputMatchesExpected(
math_ops.tanh,
- np.array(
- [[1, 1, 1, 1],
- [1, 2, 3, 4]],
- dtype=dtype),
+ np.array([[1, 1, 1, 1], [1, 2, 3, 4]], dtype=dtype),
expected=np.array(
[[0.76159418, 0.76159418, 0.76159418, 0.76159418],
[0.76159418, 0.96402758, 0.99505478, 0.99932933]],
@@ -333,10 +329,7 @@ class UnaryOpsTest(XLATestCase):
self._assertOpOutputMatchesExpected(
nn_ops.log_softmax,
- np.array(
- [[1, 1, 1, 1],
- [1, 2, 3, 4]],
- dtype=dtype),
+ np.array([[1, 1, 1, 1], [1, 2, 3, 4]], dtype=dtype),
expected=np.array(
[[-1.3862944, -1.3862944, -1.3862944, -1.3862944],
[-3.4401896, -2.4401896, -1.4401897, -0.44018969]],
@@ -370,10 +363,7 @@ class UnaryOpsTest(XLATestCase):
self._assertOpOutputMatchesExpected(
nn_ops.softmax,
- np.array(
- [[1, 1, 1, 1],
- [1, 2, 3, 4]],
- dtype=dtype),
+ np.array([[1, 1, 1, 1], [1, 2, 3, 4]], dtype=dtype),
expected=np.array(
[[0.25, 0.25, 0.25, 0.25],
[0.032058604, 0.087144323, 0.23688284, 0.64391428]],
@@ -382,8 +372,8 @@ class UnaryOpsTest(XLATestCase):
self._assertOpOutputMatchesExpected(
nn_ops.softsign,
np.array([[-2, -1, 0, 1, 2]], dtype=dtype),
- expected=np.array([[-0.66666669, -0.5, 0, 0.5, 0.66666669]],
- dtype=dtype))
+ expected=np.array(
+ [[-0.66666669, -0.5, 0, 0.5, 0.66666669]], dtype=dtype))
self._assertOpOutputMatchesExpected(
math_ops.is_finite,
@@ -392,10 +382,23 @@ class UnaryOpsTest(XLATestCase):
expected=np.array(
[[True, False, True], [False, True, True]], dtype=np.bool))
+ def quantize_and_dequantize_v2(x):
+ return array_ops.quantize_and_dequantize_v2(
+ x, -127, 127, signed_input=True, num_bits=8)
+
+ self._assertOpOutputMatchesExpected(
+ quantize_and_dequantize_v2,
+ np.array([-1, -0.5, 0, 0.3], dtype=dtype),
+ expected=np.array([-1., -0.5, 0., 0.296875], dtype=dtype))
+
+ def quantize_and_dequantize_v3(x):
+ return array_ops.quantize_and_dequantize_v3(
+ x, -127, 127, num_bits=8, signed_input=True, range_given=False)
+
self._assertOpOutputMatchesExpected(
- lambda x: array_ops.quantize_and_dequantize_v2(x, -127, 127, True, 8),
+ quantize_and_dequantize_v3,
np.array([-1, -0.5, 0, 0.3], dtype=dtype),
- expected=np.array([-1, -64.0 / 127, 0, 38.0 / 127], dtype=dtype))
+ expected=np.array([-1., -0.5, 0., 0.296875], dtype=dtype))
def testComplexOps(self):
for dtype in self.complex_types:
@@ -576,13 +579,13 @@ class UnaryOpsTest(XLATestCase):
for dtype in self.float_types:
self._assertOpOutputMatchesExpected(
math_ops.is_inf,
- np.array([[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]],
- dtype=dtype),
+ np.array(
+ [[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]], dtype=dtype),
expected=np.array([[1, 0, 0, 0, 0, 0, 0, 1, 0]], dtype=np.bool))
self._assertOpOutputMatchesExpected(
math_ops.is_nan,
- np.array([[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]],
- dtype=dtype),
+ np.array(
+ [[np.NINF, -2, -1, 0, 0.5, 1, 2, np.inf, np.nan]], dtype=dtype),
expected=np.array([[0, 0, 0, 0, 0, 0, 0, 0, 1]], dtype=np.bool))
def testLogicalOps(self):
@@ -599,14 +602,15 @@ class UnaryOpsTest(XLATestCase):
self._assertOpOutputMatchesExpected(
lambda x: gen_nn_ops.bias_add_grad(x, data_format="NCHW"),
- np.array([[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]]],
- dtype=np.float32),
+ np.array(
+ [[[1., 2.], [3., 4.]], [[5., 6.], [7., 8.]]], dtype=np.float32),
expected=np.array([10., 26.], dtype=np.float32))
def testCast(self):
shapes = [[], [4], [2, 3], [2, 0, 4]]
- types = (set([dtypes.bool, dtypes.int32, dtypes.float32]) |
- self.complex_tf_types)
+ types = (
+ set([dtypes.bool, dtypes.int32, dtypes.float32])
+ | self.complex_tf_types)
for shape in shapes:
for src_type in types:
for dst_type in types:
@@ -648,14 +652,11 @@ class UnaryOpsTest(XLATestCase):
self._assertOpOutputMatchesExpected(
rank_op, dtype(7), expected=np.int32(0))
self._assertOpOutputMatchesExpected(
- rank_op, np.array(
- [[], []], dtype=dtype), expected=np.int32(2))
+ rank_op, np.array([[], []], dtype=dtype), expected=np.int32(2))
self._assertOpOutputMatchesExpected(
- rank_op, np.array(
- [-1, 1], dtype=dtype), expected=np.int32(1))
+ rank_op, np.array([-1, 1], dtype=dtype), expected=np.int32(1))
self._assertOpOutputMatchesExpected(
- rank_op, np.array(
- [[-1, 1]], dtype=dtype), expected=np.int32(2))
+ rank_op, np.array([[-1, 1]], dtype=dtype), expected=np.int32(2))
self._assertOpOutputMatchesExpected(
rank_op,
np.array([[-1], [1], [4]], dtype=dtype),
@@ -720,97 +721,97 @@ class UnaryOpsTest(XLATestCase):
equality_test=self.ListsAreClose)
def testDepthToSpace(self):
+
def make_op(data_format):
+
def op(x):
- return array_ops.depth_to_space(x, block_size=2,
- data_format=data_format)
+ return array_ops.depth_to_space(
+ x, block_size=2, data_format=data_format)
+
return op
for dtype in self.numeric_types:
for data_format in ["NCHW", "NHWC"]:
self._assertOpOutputMatchesExpected(
make_op(data_format),
- nhwc_to_format(np.array([[[[1, 2, 3, 4]]]], dtype=dtype),
- data_format),
- expected=nhwc_to_format(np.array([[[[1], [2]],
- [[3], [4]]]], dtype=dtype),
- data_format))
+ nhwc_to_format(
+ np.array([[[[1, 2, 3, 4]]]], dtype=dtype), data_format),
+ expected=nhwc_to_format(
+ np.array([[[[1], [2]], [[3], [4]]]], dtype=dtype), data_format))
self._assertOpOutputMatchesExpected(
make_op(data_format),
nhwc_to_format(
- np.array([[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]],
- dtype=dtype),
+ np.array(
+ [[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]], dtype=dtype),
data_format),
expected=nhwc_to_format(
- np.array([[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]],
- dtype=dtype),
- data_format))
+ np.array(
+ [[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]],
+ dtype=dtype), data_format))
self._assertOpOutputMatchesExpected(
make_op(data_format),
nhwc_to_format(
- np.array([[[[1, 2, 3, 4],
- [5, 6, 7, 8]],
- [[9, 10, 11, 12],
- [13, 14, 15, 16]]]], dtype=dtype),
- data_format),
+ np.array(
+ [[[[1, 2, 3, 4], [5, 6, 7, 8]], [[9, 10, 11, 12],
+ [13, 14, 15, 16]]]],
+ dtype=dtype), data_format),
expected=nhwc_to_format(
- np.array([[[[1], [2], [5], [6]],
- [[3], [4], [7], [8]],
- [[9], [10], [13], [14]],
- [[11], [12], [15], [16]]]], dtype=dtype),
- data_format))
+ np.array(
+ [[[[1], [2], [5], [6]], [[3], [4], [7], [8]],
+ [[9], [10], [13], [14]], [[11], [12], [15], [16]]]],
+ dtype=dtype), data_format))
def testSpaceToDepth(self):
+
def make_op(data_format):
+
def op(x):
- return array_ops.space_to_depth(x, block_size=2,
- data_format=data_format)
+ return array_ops.space_to_depth(
+ x, block_size=2, data_format=data_format)
+
return op
for dtype in self.numeric_types:
for data_format in ["NCHW", "NHWC"]:
self._assertOpOutputMatchesExpected(
make_op(data_format),
- nhwc_to_format(np.array([[[[1], [2]],
- [[3], [4]]]], dtype=dtype),
- data_format),
- expected=nhwc_to_format(np.array([[[[1, 2, 3, 4]]]], dtype=dtype),
- data_format))
+ nhwc_to_format(
+ np.array([[[[1], [2]], [[3], [4]]]], dtype=dtype), data_format),
+ expected=nhwc_to_format(
+ np.array([[[[1, 2, 3, 4]]]], dtype=dtype), data_format))
self._assertOpOutputMatchesExpected(
make_op(data_format),
- nhwc_to_format(np.array([[[[1, 2, 3], [4, 5, 6]],
- [[7, 8, 9], [10, 11, 12]]]], dtype=dtype),
- data_format),
+ nhwc_to_format(
+ np.array(
+ [[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]],
+ dtype=dtype), data_format),
expected=nhwc_to_format(
- np.array([[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]],
- dtype=dtype),
+ np.array(
+ [[[[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]]]], dtype=dtype),
data_format))
self._assertOpOutputMatchesExpected(
make_op(data_format),
- nhwc_to_format(np.array([[[[1], [2], [5], [6]],
- [[3], [4], [7], [8]],
- [[9], [10], [13], [14]],
- [[11], [12], [15], [16]]]], dtype=dtype),
- data_format),
+ nhwc_to_format(
+ np.array(
+ [[[[1], [2], [5], [6]], [[3], [4], [7], [8]],
+ [[9], [10], [13], [14]], [[11], [12], [15], [16]]]],
+ dtype=dtype), data_format),
expected=nhwc_to_format(
- np.array([[[[1, 2, 3, 4],
- [5, 6, 7, 8]],
- [[9, 10, 11, 12],
- [13, 14, 15, 16]]]], dtype=dtype),
- data_format))
+ np.array(
+ [[[[1, 2, 3, 4], [5, 6, 7, 8]], [[9, 10, 11, 12],
+ [13, 14, 15, 16]]]],
+ dtype=dtype), data_format))
def _assertSoftplusMatchesExpected(self, features, dtype):
features = np.array(features, dtype=dtype)
zero = np.asarray(0).astype(dtype)
expected = np.logaddexp(zero, features)
self._assertOpOutputMatchesExpected(
- nn_ops.softplus, features, expected=expected,
- rtol=1e-6,
- atol=9.1e-6)
+ nn_ops.softplus, features, expected=expected, rtol=1e-6, atol=9.1e-6)
def testSoftplus(self):
for dtype in self.float_types:
@@ -824,9 +825,10 @@ class UnaryOpsTest(XLATestCase):
one = dtype(1)
ten = dtype(10)
self._assertSoftplusMatchesExpected([
- log_eps, log_eps - one, log_eps + one, log_eps - ten,
- log_eps + ten, -log_eps, -log_eps - one, -log_eps + one,
- -log_eps - ten, -log_eps + ten], dtype)
+ log_eps, log_eps - one, log_eps + one, log_eps - ten, log_eps + ten,
+ -log_eps, -log_eps - one, -log_eps + one, -log_eps - ten,
+ -log_eps + ten
+ ], dtype)
if __name__ == "__main__":
diff --git a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc
index 9576354c5f..02293796e4 100644
--- a/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/quantize_and_dequantize_op.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -29,82 +30,115 @@ class QuantizeAndDequantizeOp : public XlaOpKernel {
: XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("signed_input", &signed_input_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("range_given", &range_given_));
- OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits_));
- OP_REQUIRES(ctx, num_bits_ > 0 && num_bits_ < (signed_input_ ? 62 : 63),
- errors::InvalidArgument("num_bits is out of range: ", num_bits_,
- " with signed_input_ ", signed_input_));
}
void Compile(XlaOpKernelContext* ctx) override {
xla::XlaOp input = ctx->Input(0);
const DataType data_type = ctx->input_type(0);
- // Comments taken from semantics description at
- // https://www.tensorflow.org/versions/r1.0/api_docs/cc/class/tensorflow/ops/quantize-and-dequantize
- //
- // ... we find m such that
- //
- // m = max(abs(input_min), abs(input_max)) if range_given is true,
- // m = max(abs(min_elem(input)),
- // abs(max_elem(input))) otherwise.
+ xla::PrimitiveType xla_type;
+ OP_REQUIRES_OK(ctx, DataTypeToPrimitiveType(data_type, &xla_type));
+
xla::XlaBuilder* b = ctx->builder();
- xla::XlaOp input_min, input_max;
+
+ // The implementation follows
+ // tensorflow/core/kernels/quantize_and_dequantize_op.h closely.
+ xla::XlaOp min_range, max_range;
if (range_given_) {
- double input_min_value, input_max_value;
- OP_REQUIRES_OK(ctx, ctx->ConstantInputAsFloatScalar(1, &input_min_value));
- OP_REQUIRES_OK(ctx, ctx->ConstantInputAsFloatScalar(2, &input_max_value));
- input_min = XlaHelpers::FloatLiteral(b, data_type, input_min_value);
- input_max = XlaHelpers::FloatLiteral(b, data_type, input_max_value);
+ min_range = ctx->Input(1);
+ max_range = ctx->Input(2);
} else {
const xla::XlaComputation* fmax = ctx->GetOrCreateMax(data_type);
const xla::XlaComputation* fmin = ctx->GetOrCreateMin(data_type);
- input_min =
- xla::ReduceAll(input, XlaHelpers::MaxValue(b, data_type), *fmin);
- input_max =
- xla::ReduceAll(input, XlaHelpers::MinValue(b, data_type), *fmax);
+ min_range = ReduceAll(input, XlaHelpers::MaxValue(b, data_type), *fmin);
+ max_range = ReduceAll(input, XlaHelpers::MinValue(b, data_type), *fmax);
}
- xla::XlaOp m = xla::Max(xla::Abs(input_min), xla::Abs(input_max));
-
- // Next, we choose our fixed-point quantization buckets, [min_fixed,
- // max_fixed]. If signed_input is true, this is
- //
- // [min_fixed, max_fixed ] = [-((1 << (num_bits - 1)) - 1),
- // (1 << (num_bits - 1)) - 1].
- //
- // Otherwise, if signed_input is false, the fixed-point range is
- //
- // [min_fixed, max_fixed] = [0, (1 << num_bits) - 1].
- int64 min_fixed, max_fixed;
+
+ xla::XlaOp num_bits;
+ if (num_bits_ < 0) {
+ OP_REQUIRES(
+ ctx, ctx->num_inputs() == 4,
+ errors::Internal("Expected 4 inputs to QuantizeAndDequantize"));
+ num_bits = ctx->Input(3);
+ } else {
+ num_bits = xla::ConstantR0<int32>(b, num_bits_);
+ }
+
+ const xla::XlaOp zero = XlaHelpers::Zero(b, data_type);
+ const xla::XlaOp one = XlaHelpers::One(b, data_type);
+ const xla::XlaOp two = XlaHelpers::FloatLiteral(b, data_type, 2.0);
+ const xla::XlaOp half = XlaHelpers::FloatLiteral(b, data_type, 0.5);
+
+ // Calculate the range for the simulated integer quantization:
+ // e.g. [-128,127] for signed = true, num_bits = 8,
+ // or [0, 255] for signed = false, num_bits = 8.
+ // We do this in floating point for hardware that does not have 64-bit
+ // integer support.
+ xla::XlaOp min_quantized, max_quantized;
if (signed_input_) {
- min_fixed = -((1LL << (num_bits_ - 1)) - 1);
- max_fixed = (1LL << (num_bits_ - 1)) - 1;
+ min_quantized =
+ -Pow(two, ConvertElementType(num_bits - xla::ConstantR0<int32>(b, 1),
+ xla_type));
+ max_quantized =
+ Pow(two, ConvertElementType(num_bits - xla::ConstantR0<int32>(b, 1),
+ xla_type)) -
+ one;
} else {
- min_fixed = 0;
- max_fixed = (1LL << num_bits_) - 1;
+ min_quantized = zero;
+ max_quantized = Pow(two, ConvertElementType(num_bits, xla_type)) - one;
}
- // From this we compute our scaling factor, s:
- //
- // s = (max_fixed - min_fixed) / (2 * m).
- xla::XlaOp s =
- xla::Div(XlaHelpers::FloatLiteral(b, data_type, max_fixed - min_fixed),
- xla::Mul(XlaHelpers::FloatLiteral(b, data_type, 2.0), m));
+ // Determine the maximum scaling factor that would scale
+ // [min_range, max_range] to not exceed [min_quantized, max_quantized],
+ // while keeping 0 unchanged.
+ xla::XlaOp scale_from_min_side =
+ Select(Gt(min_quantized * min_range, zero), min_quantized / min_range,
+ XlaHelpers::MaxFiniteValue(b, data_type));
+ xla::XlaOp scale_from_max_side =
+ Select(Gt(max_quantized * max_range, zero), max_quantized / max_range,
+ XlaHelpers::MaxFiniteValue(b, data_type));
- // Now we can quantize and dequantize the elements of our tensor. An element
- // e is transformed into e':
- //
- // e' = (e * s).round_to_nearest() / s.
- xla::XlaOp result = xla::Div(xla::Round(xla::Mul(input, s)), s);
+ // Note: Avoids changing the side of the range that determines scale.
+ xla::XlaOp cond = Lt(scale_from_min_side, scale_from_max_side);
+ xla::XlaOp scale = Select(cond, scale_from_min_side, scale_from_max_side);
+ xla::XlaOp inverse_scale =
+ Select(cond, min_range / min_quantized, max_range / max_quantized);
+ min_range = Select(cond, min_range, min_quantized * inverse_scale);
+ max_range = Select(cond, max_quantized * inverse_scale, max_range);
+ if (range_given_) {
+ // Note: The clamping here is to avoid overflow in the quantized type.
+ // The semantics of the op does not guarantee to clamp to the specified
+ // min_range and max_range - because we may have changed either min_range
+ // or max_range.
+ // No need to clamp to min_range and max_range if range_given_ == false as
+ // in that case they were measured from the tensor.
+ input = Clamp(min_range, input, max_range);
+ }
+ xla::XlaOp result =
+ Floor((input - min_range) * scale + half) * inverse_scale + min_range;
ctx->SetOutput(0, result);
}
- int64 num_bits_;
+ protected:
+ int64 num_bits_ = -1;
bool signed_input_;
bool range_given_;
};
-REGISTER_XLA_OP(Name("QuantizeAndDequantizeV2"), QuantizeAndDequantizeOp);
+class QuantizeAndDequantizeV2Op : public QuantizeAndDequantizeOp {
+ public:
+ explicit QuantizeAndDequantizeV2Op(OpKernelConstruction* ctx)
+ : QuantizeAndDequantizeOp(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("num_bits", &num_bits_));
+ OP_REQUIRES(ctx, num_bits_ > 0 && num_bits_ < (signed_input_ ? 62 : 63),
+ errors::InvalidArgument("num_bits is out of range: ", num_bits_,
+ " with signed_input_ ", signed_input_));
+ }
+};
+
+REGISTER_XLA_OP(Name("QuantizeAndDequantizeV2"), QuantizeAndDequantizeV2Op);
+REGISTER_XLA_OP(Name("QuantizeAndDequantizeV3"), QuantizeAndDequantizeOp);
} // namespace
} // namespace tensorflow
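
Why the expected test values change: the old lowering used
s = (max_fixed - min_fixed) / (2 * m) with m = max(|input_min|, |input_max|),
which for the test vector gives s = 254 / 2 = 127, so an input of -0.5
dequantized to -64/127 ≈ -0.504 rather than -0.5. A quick check of the old
math in NumPy, assuming np.round rounds these values the same way xla::Round
did:

import numpy as np

x = np.array([-1.0, -0.5, 0.0, 0.3])
m = max(abs(x.min()), abs(x.max()))   # range measured from the tensor: m = 1.0
s = (127.0 - (-127.0)) / (2.0 * m)    # signed 8-bit fixed-point range: s = 127.0
print(np.round(x * s) / s)
# -> [-1., -0.50393701, 0., 0.2992126], i.e. the old expectations
#    [-1, -64.0/127, 0, 38.0/127].

The kernel-matching math instead yields [-1, -0.5, 0, 0.296875], which is why
both the V2 and V3 tests now expect those values.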