diff options
author | Benjamin Kramer <kramerb@google.com> | 2018-09-17 03:12:38 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-17 03:16:54 -0700 |
commit | cac963862be3faa421c559f39033c9bfb3b27a51 (patch) | |
tree | 8418eb6b786f0c46d0738ca54084583330012a42 /tensorflow/compiler/tests | |
parent | b1f4328517851e76cff3d4af8766e7e3446314ba (diff) |
[XLA:TF] Enable int8 and uint8 support in the bridge for CPU/GPU
The test changes are awkward. None of these are XLA bugs, it's just that the op
definitions in tensorflow are really inconsistent. I tried to infer whether the
limitation is on signed types, on index types, or is just arbitrary. In the latter
case only int8/uint8 are blacklisted; we should probably lift that requirement
at some point.
PiperOrigin-RevId: 213243906
Diffstat (limited to 'tensorflow/compiler/tests')
-rw-r--r-- | tensorflow/compiler/tests/argminmax_test.py | 4 | ||||
-rw-r--r-- | tensorflow/compiler/tests/binary_ops_test.py | 11 | ||||
-rw-r--r-- | tensorflow/compiler/tests/build_defs.bzl | 4 | ||||
-rw-r--r-- | tensorflow/compiler/tests/random_ops_test.py | 3 | ||||
-rw-r--r-- | tensorflow/compiler/tests/reverse_sequence_op_test.py | 2 | ||||
-rw-r--r-- | tensorflow/compiler/tests/unary_ops_test.py | 4 | ||||
-rw-r--r-- | tensorflow/compiler/tests/xla_ops_test.py | 2 | ||||
-rw-r--r-- | tensorflow/compiler/tests/xla_test.py | 6 |
8 files changed, 22 insertions, 14 deletions
diff --git a/tensorflow/compiler/tests/argminmax_test.py b/tensorflow/compiler/tests/argminmax_test.py index 4155342787..68f52e796c 100644 --- a/tensorflow/compiler/tests/argminmax_test.py +++ b/tensorflow/compiler/tests/argminmax_test.py @@ -50,12 +50,12 @@ class ArgMinMaxTest(xla_test.XLATestCase): def testArgMinMax(self): # Complex numbers do not support argmin/argmax. - minmax_types = set(self.numeric_types) - set(self.complex_types) + minmax_types = self.all_types & {np.int32, np.int64} for dtype in minmax_types: # output_type is a numpy data type that is used to specify the desired # output type of the op as well as to convert the Python number to the # array scalar of the type. - for output_type in self.int_types: + for output_type in minmax_types: self._assertOpOutputMatchesExpected( math_ops.argmax, axis=0, diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py index 17280e445b..900e84ab58 100644 --- a/tensorflow/compiler/tests/binary_ops_test.py +++ b/tensorflow/compiler/tests/binary_ops_test.py @@ -210,7 +210,7 @@ class BinaryOpsTest(xla_test.XLATestCase): equality_test=self.ListsAreClose) def testIntOps(self): - for dtype in self.int_types: + for dtype in self.signed_int_types: self._testBinary( gen_math_ops.truncate_div, np.array([3, 3, -1, -9, -8], dtype=dtype), @@ -287,7 +287,8 @@ class BinaryOpsTest(xla_test.XLATestCase): dtype(7), expected=np.array([[-6], [-5]], dtype=dtype)) - if dtype not in self.complex_types: # min/max not supported for complex + # min/max not supported for complex + if dtype not in self.complex_types | {np.uint8, np.int8}: self._testBinary( math_ops.maximum, np.array([1, 2], dtype=dtype), @@ -337,7 +338,7 @@ class BinaryOpsTest(xla_test.XLATestCase): expected=np.array([[70], [14]], dtype=dtype)) # Complex support for squared_difference is incidental, see b/68205550 - if dtype not in self.complex_types: + if dtype not in self.complex_types | {np.uint8, np.int8}: 
self._testBinary( math_ops.squared_difference, np.array([1, 2], dtype=dtype), @@ -567,7 +568,7 @@ class BinaryOpsTest(xla_test.XLATestCase): expected=np.array([1, -2, -1, -5, 2], dtype=dtype)) def testIntDivision(self): - for dtype in self.int_types: + for dtype in self.signed_int_types: self._testDivision(dtype) def testFloatDivision(self): @@ -588,7 +589,7 @@ class BinaryOpsTest(xla_test.XLATestCase): expected=np.array([1, 1, -1, 0], dtype=dtype)) def testIntRemainder(self): - for dtype in self.int_types: + for dtype in self.signed_int_types - {np.int8}: self._testRemainder(dtype) def testFloatRemainder(self): diff --git a/tensorflow/compiler/tests/build_defs.bzl b/tensorflow/compiler/tests/build_defs.bzl index a76f136736..114793352e 100644 --- a/tensorflow/compiler/tests/build_defs.bzl +++ b/tensorflow/compiler/tests/build_defs.bzl @@ -58,12 +58,12 @@ def tf_xla_py_test( if backend == "cpu": backend_args += [ "--test_device=XLA_CPU", - "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64", + "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_UINT8,DT_INT8,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64", ] elif backend == "gpu": backend_args += [ "--test_device=XLA_GPU", - "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64,DT_BFLOAT16", + "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_UINT8,DT_INT8,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64,DT_BFLOAT16", ] backend_tags += ["requires-gpu-sm35"] elif backend in plugins: diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py index 6e18344117..41fe42a26b 100644 --- a/tensorflow/compiler/tests/random_ops_test.py +++ b/tensorflow/compiler/tests/random_ops_test.py @@ -35,7 +35,8 @@ class RandomOpsTest(xla_test.XLATestCase): """Test cases for random-number generating operators.""" def _random_types(self): - return set(self.numeric_types) - set(self.complex_types) + return set(self.numeric_types) - set( + self.complex_types) - {np.uint8, np.int8} def 
_testRngIsNotConstant(self, rng, dtype): # Tests that 'rng' does not always return the same value. diff --git a/tensorflow/compiler/tests/reverse_sequence_op_test.py b/tensorflow/compiler/tests/reverse_sequence_op_test.py index 60c2337743..abc822ef36 100644 --- a/tensorflow/compiler/tests/reverse_sequence_op_test.py +++ b/tensorflow/compiler/tests/reverse_sequence_op_test.py @@ -85,7 +85,7 @@ class ReverseSequenceTest(xla_test.XLATestCase): def testSeqLength(self): for dtype in self.all_types: - for seq_dtype in self.int_types: + for seq_dtype in self.all_types & {np.int32, np.int64}: self._testBasic(dtype, seq_dtype) diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py index 5b0e57f83f..04ea004fe7 100644 --- a/tensorflow/compiler/tests/unary_ops_test.py +++ b/tensorflow/compiler/tests/unary_ops_test.py @@ -84,7 +84,7 @@ class UnaryOpsTest(xla_test.XLATestCase): self.assertAllClose(result[i], expected[i], rtol, atol) def testAllTypeOps(self): - for dtype in self.numeric_types: + for dtype in self.numeric_types - {np.int8, np.uint8}: self._assertOpOutputMatchesExpected( array_ops.diag, np.array([1, 2, 3, 4], dtype=dtype), np.array( @@ -633,7 +633,7 @@ class UnaryOpsTest(xla_test.XLATestCase): expected=np.array([-1, 0, -2, -17, -43], dtype=dtype)) def testNumericOps(self): - for dtype in self.numeric_types: + for dtype in self.numeric_types - {np.int8, np.uint8}: self._assertOpOutputMatchesExpected( math_ops.abs, np.array([[2, -1]], dtype=dtype), diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py index 1e600c44e9..4cf88fc523 100644 --- a/tensorflow/compiler/tests/xla_ops_test.py +++ b/tensorflow/compiler/tests/xla_ops_test.py @@ -181,7 +181,7 @@ class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase): dtype=dtype)) def testNeg(self): - for dtype in self.numeric_types: + for dtype in self.numeric_types - {np.uint8, np.int8}: self._assertOpOutputMatchesExpected( 
xla.neg, args=(np.array([1, 2, 3], dtype=dtype),), diff --git a/tensorflow/compiler/tests/xla_test.py b/tensorflow/compiler/tests/xla_test.py index 88827cb53b..df5c81243a 100644 --- a/tensorflow/compiler/tests/xla_test.py +++ b/tensorflow/compiler/tests/xla_test.py @@ -101,6 +101,12 @@ class XLATestCase(test.TestCase): self._all_types = set( [dtype.as_numpy_dtype for dtype in self._all_tf_types]) self._int_types = set([dtype.as_numpy_dtype for dtype in self.int_tf_types]) + self.signed_int_types = set(dtype.as_numpy_dtype + for dtype in self.int_tf_types + if not dtype.is_unsigned) + self.unsigned_int_types = set(dtype.as_numpy_dtype + for dtype in self.int_tf_types + if dtype.is_unsigned) self._float_types = set( [dtype.as_numpy_dtype for dtype in self._float_tf_types]) self.complex_types = set([ |