aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tests
diff options
context:
space:
mode:
authorGravatar Benjamin Kramer <kramerb@google.com>2018-09-17 03:12:38 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-17 03:16:54 -0700
commitcac963862be3faa421c559f39033c9bfb3b27a51 (patch)
tree8418eb6b786f0c46d0738ca54084583330012a42 /tensorflow/compiler/tests
parentb1f4328517851e76cff3d4af8766e7e3446314ba (diff)
[XLA:TF] Enable int8 and uint8 support in the bridge for CPU/GPU
The test changes are awkward. None of these are XLA bugs, it's just that the op definitions in tensorflow are really inconsistent. I tried to infer whether the limitation is on signed types, index types or just arbitrary. In the latter case just int8/uint8 is blacklisted, we should probably lift that requirement at some point. PiperOrigin-RevId: 213243906
Diffstat (limited to 'tensorflow/compiler/tests')
-rw-r--r--tensorflow/compiler/tests/argminmax_test.py4
-rw-r--r--tensorflow/compiler/tests/binary_ops_test.py11
-rw-r--r--tensorflow/compiler/tests/build_defs.bzl4
-rw-r--r--tensorflow/compiler/tests/random_ops_test.py3
-rw-r--r--tensorflow/compiler/tests/reverse_sequence_op_test.py2
-rw-r--r--tensorflow/compiler/tests/unary_ops_test.py4
-rw-r--r--tensorflow/compiler/tests/xla_ops_test.py2
-rw-r--r--tensorflow/compiler/tests/xla_test.py6
8 files changed, 22 insertions, 14 deletions
diff --git a/tensorflow/compiler/tests/argminmax_test.py b/tensorflow/compiler/tests/argminmax_test.py
index 4155342787..68f52e796c 100644
--- a/tensorflow/compiler/tests/argminmax_test.py
+++ b/tensorflow/compiler/tests/argminmax_test.py
@@ -50,12 +50,12 @@ class ArgMinMaxTest(xla_test.XLATestCase):
def testArgMinMax(self):
# Complex numbers do not support argmin/argmax.
- minmax_types = set(self.numeric_types) - set(self.complex_types)
+ minmax_types = self.all_types & {np.int32, np.int64}
for dtype in minmax_types:
# output_type is a numpy data type that is used to specify the desired
# output type of the op as well as to convert the Python number to the
# array scalar of the type.
- for output_type in self.int_types:
+ for output_type in minmax_types:
self._assertOpOutputMatchesExpected(
math_ops.argmax,
axis=0,
diff --git a/tensorflow/compiler/tests/binary_ops_test.py b/tensorflow/compiler/tests/binary_ops_test.py
index 17280e445b..900e84ab58 100644
--- a/tensorflow/compiler/tests/binary_ops_test.py
+++ b/tensorflow/compiler/tests/binary_ops_test.py
@@ -210,7 +210,7 @@ class BinaryOpsTest(xla_test.XLATestCase):
equality_test=self.ListsAreClose)
def testIntOps(self):
- for dtype in self.int_types:
+ for dtype in self.signed_int_types:
self._testBinary(
gen_math_ops.truncate_div,
np.array([3, 3, -1, -9, -8], dtype=dtype),
@@ -287,7 +287,8 @@ class BinaryOpsTest(xla_test.XLATestCase):
dtype(7),
expected=np.array([[-6], [-5]], dtype=dtype))
- if dtype not in self.complex_types: # min/max not supported for complex
+ # min/max not supported for complex
+ if dtype not in self.complex_types | {np.uint8, np.int8}:
self._testBinary(
math_ops.maximum,
np.array([1, 2], dtype=dtype),
@@ -337,7 +338,7 @@ class BinaryOpsTest(xla_test.XLATestCase):
expected=np.array([[70], [14]], dtype=dtype))
# Complex support for squared_difference is incidental, see b/68205550
- if dtype not in self.complex_types:
+ if dtype not in self.complex_types | {np.uint8, np.int8}:
self._testBinary(
math_ops.squared_difference,
np.array([1, 2], dtype=dtype),
@@ -567,7 +568,7 @@ class BinaryOpsTest(xla_test.XLATestCase):
expected=np.array([1, -2, -1, -5, 2], dtype=dtype))
def testIntDivision(self):
- for dtype in self.int_types:
+ for dtype in self.signed_int_types:
self._testDivision(dtype)
def testFloatDivision(self):
@@ -588,7 +589,7 @@ class BinaryOpsTest(xla_test.XLATestCase):
expected=np.array([1, 1, -1, 0], dtype=dtype))
def testIntRemainder(self):
- for dtype in self.int_types:
+ for dtype in self.signed_int_types - {np.int8}:
self._testRemainder(dtype)
def testFloatRemainder(self):
diff --git a/tensorflow/compiler/tests/build_defs.bzl b/tensorflow/compiler/tests/build_defs.bzl
index a76f136736..114793352e 100644
--- a/tensorflow/compiler/tests/build_defs.bzl
+++ b/tensorflow/compiler/tests/build_defs.bzl
@@ -58,12 +58,12 @@ def tf_xla_py_test(
if backend == "cpu":
backend_args += [
"--test_device=XLA_CPU",
- "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64",
+ "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_UINT8,DT_INT8,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64",
]
elif backend == "gpu":
backend_args += [
"--test_device=XLA_GPU",
- "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64,DT_BFLOAT16",
+ "--types=DT_HALF,DT_FLOAT,DT_DOUBLE,DT_UINT8,DT_INT8,DT_INT32,DT_INT64,DT_BOOL,DT_COMPLEX64,DT_BFLOAT16",
]
backend_tags += ["requires-gpu-sm35"]
elif backend in plugins:
diff --git a/tensorflow/compiler/tests/random_ops_test.py b/tensorflow/compiler/tests/random_ops_test.py
index 6e18344117..41fe42a26b 100644
--- a/tensorflow/compiler/tests/random_ops_test.py
+++ b/tensorflow/compiler/tests/random_ops_test.py
@@ -35,7 +35,8 @@ class RandomOpsTest(xla_test.XLATestCase):
"""Test cases for random-number generating operators."""
def _random_types(self):
- return set(self.numeric_types) - set(self.complex_types)
+ return set(self.numeric_types) - set(
+ self.complex_types) - {np.uint8, np.int8}
def _testRngIsNotConstant(self, rng, dtype):
# Tests that 'rng' does not always return the same value.
diff --git a/tensorflow/compiler/tests/reverse_sequence_op_test.py b/tensorflow/compiler/tests/reverse_sequence_op_test.py
index 60c2337743..abc822ef36 100644
--- a/tensorflow/compiler/tests/reverse_sequence_op_test.py
+++ b/tensorflow/compiler/tests/reverse_sequence_op_test.py
@@ -85,7 +85,7 @@ class ReverseSequenceTest(xla_test.XLATestCase):
def testSeqLength(self):
for dtype in self.all_types:
- for seq_dtype in self.int_types:
+ for seq_dtype in self.all_types & {np.int32, np.int64}:
self._testBasic(dtype, seq_dtype)
diff --git a/tensorflow/compiler/tests/unary_ops_test.py b/tensorflow/compiler/tests/unary_ops_test.py
index 5b0e57f83f..04ea004fe7 100644
--- a/tensorflow/compiler/tests/unary_ops_test.py
+++ b/tensorflow/compiler/tests/unary_ops_test.py
@@ -84,7 +84,7 @@ class UnaryOpsTest(xla_test.XLATestCase):
self.assertAllClose(result[i], expected[i], rtol, atol)
def testAllTypeOps(self):
- for dtype in self.numeric_types:
+ for dtype in self.numeric_types - {np.int8, np.uint8}:
self._assertOpOutputMatchesExpected(
array_ops.diag, np.array([1, 2, 3, 4], dtype=dtype),
np.array(
@@ -633,7 +633,7 @@ class UnaryOpsTest(xla_test.XLATestCase):
expected=np.array([-1, 0, -2, -17, -43], dtype=dtype))
def testNumericOps(self):
- for dtype in self.numeric_types:
+ for dtype in self.numeric_types - {np.int8, np.uint8}:
self._assertOpOutputMatchesExpected(
math_ops.abs,
np.array([[2, -1]], dtype=dtype),
diff --git a/tensorflow/compiler/tests/xla_ops_test.py b/tensorflow/compiler/tests/xla_ops_test.py
index 1e600c44e9..4cf88fc523 100644
--- a/tensorflow/compiler/tests/xla_ops_test.py
+++ b/tensorflow/compiler/tests/xla_ops_test.py
@@ -181,7 +181,7 @@ class XlaOpsTest(xla_test.XLATestCase, parameterized.TestCase):
dtype=dtype))
def testNeg(self):
- for dtype in self.numeric_types:
+ for dtype in self.numeric_types - {np.uint8, np.int8}:
self._assertOpOutputMatchesExpected(
xla.neg,
args=(np.array([1, 2, 3], dtype=dtype),),
diff --git a/tensorflow/compiler/tests/xla_test.py b/tensorflow/compiler/tests/xla_test.py
index 88827cb53b..df5c81243a 100644
--- a/tensorflow/compiler/tests/xla_test.py
+++ b/tensorflow/compiler/tests/xla_test.py
@@ -101,6 +101,12 @@ class XLATestCase(test.TestCase):
self._all_types = set(
[dtype.as_numpy_dtype for dtype in self._all_tf_types])
self._int_types = set([dtype.as_numpy_dtype for dtype in self.int_tf_types])
+ self.signed_int_types = set(dtype.as_numpy_dtype
+ for dtype in self.int_tf_types
+ if not dtype.is_unsigned)
+ self.unsigned_int_types = set(dtype.as_numpy_dtype
+ for dtype in self.int_tf_types
+ if dtype.is_unsigned)
self._float_types = set(
[dtype.as_numpy_dtype for dtype in self._float_tf_types])
self.complex_types = set([