diff options
author | Benjamin Kramer <kramerb@google.com> | 2018-09-17 03:12:38 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-17 03:16:54 -0700 |
commit | cac963862be3faa421c559f39033c9bfb3b27a51 (patch) | |
tree | 8418eb6b786f0c46d0738ca54084583330012a42 /tensorflow/compiler/tests/argminmax_test.py | |
parent | b1f4328517851e76cff3d4af8766e7e3446314ba (diff) |
[XLA:TF] Enable int8 and uint8 support in the bridge for CPU/GPU
The test changes are awkward. None of these are XLA bugs, it's just that the op
definitions in tensorflow are really inconsistent. I tried to infer whether the
limitation is on signed types, index types or just arbitrary. In the latter
case just int8/uint8 is blacklisted, we should probably lift that requirement
at some point.
PiperOrigin-RevId: 213243906
Diffstat (limited to 'tensorflow/compiler/tests/argminmax_test.py')
-rw-r--r-- | tensorflow/compiler/tests/argminmax_test.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/tensorflow/compiler/tests/argminmax_test.py b/tensorflow/compiler/tests/argminmax_test.py index 4155342787..68f52e796c 100644 --- a/tensorflow/compiler/tests/argminmax_test.py +++ b/tensorflow/compiler/tests/argminmax_test.py @@ -50,12 +50,12 @@ class ArgMinMaxTest(xla_test.XLATestCase): def testArgMinMax(self): # Complex numbers do not support argmin/argmax. - minmax_types = set(self.numeric_types) - set(self.complex_types) + minmax_types = self.all_types & {np.int32, np.int64} for dtype in minmax_types: # output_type is a numpy data type that is used to specify the desired # output type of the op as well as to convert the Python number to the # array scalar of the type. - for output_type in self.int_types: + for output_type in minmax_types: self._assertOpOutputMatchesExpected( math_ops.argmax, axis=0, |