aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/tests/argminmax_test.py
diff options
context:
space:
mode:
authorGravatar Benjamin Kramer <kramerb@google.com>2018-09-17 03:12:38 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-17 03:16:54 -0700
commitcac963862be3faa421c559f39033c9bfb3b27a51 (patch)
tree8418eb6b786f0c46d0738ca54084583330012a42 /tensorflow/compiler/tests/argminmax_test.py
parentb1f4328517851e76cff3d4af8766e7e3446314ba (diff)
[XLA:TF] Enable int8 and uint8 support in the bridge for CPU/GPU
The test changes are awkward. None of these are XLA bugs, it's just that the op definitions in tensorflow are really inconsistent. I tried to infer whether the limitation is on signed types, index types or just arbitrary. In the latter case just int8/uint8 is blacklisted, we should probably lift that requirement at some point. PiperOrigin-RevId: 213243906
Diffstat (limited to 'tensorflow/compiler/tests/argminmax_test.py')
-rw-r--r--tensorflow/compiler/tests/argminmax_test.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/tensorflow/compiler/tests/argminmax_test.py b/tensorflow/compiler/tests/argminmax_test.py
index 4155342787..68f52e796c 100644
--- a/tensorflow/compiler/tests/argminmax_test.py
+++ b/tensorflow/compiler/tests/argminmax_test.py
@@ -50,12 +50,12 @@ class ArgMinMaxTest(xla_test.XLATestCase):
def testArgMinMax(self):
# Complex numbers do not support argmin/argmax.
- minmax_types = set(self.numeric_types) - set(self.complex_types)
+ minmax_types = self.all_types & {np.int32, np.int64}
for dtype in minmax_types:
# output_type is a numpy data type that is used to specify the desired
# output type of the op as well as to convert the Python number to the
# array scalar of the type.
- for output_type in self.int_types:
+ for output_type in minmax_types:
self._assertOpOutputMatchesExpected(
math_ops.argmax,
axis=0,