aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/jit
diff options
context:
space:
mode:
authorGravatar Benjamin Kramer <kramerb@google.com>2018-09-17 03:12:38 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-17 03:16:54 -0700
commitcac963862be3faa421c559f39033c9bfb3b27a51 (patch)
tree8418eb6b786f0c46d0738ca54084583330012a42 /tensorflow/compiler/jit
parentb1f4328517851e76cff3d4af8766e7e3446314ba (diff)
[XLA:TF] Enable int8 and uint8 support in the bridge for CPU/GPU
The test changes are awkward. None of these are XLA bugs, it's just that the op definitions in tensorflow are really inconsistent. I tried to infer whether the limitation is on signed types, index types or just arbitrary. In the latter case just int8/uint8 is blacklisted, we should probably lift that requirement at some point. PiperOrigin-RevId: 213243906
Diffstat (limited to 'tensorflow/compiler/jit')
-rw-r--r--tensorflow/compiler/jit/xla_cpu_device.cc5
-rw-r--r--tensorflow/compiler/jit/xla_gpu_device.cc6
2 files changed, 6 insertions, 5 deletions
diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc
index 7e159e3171..1afc305abe 100644
--- a/tensorflow/compiler/jit/xla_cpu_device.cc
+++ b/tensorflow/compiler/jit/xla_cpu_device.cc
@@ -65,8 +65,9 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_CPU, XlaCpuDeviceFactory);
// Kernel registrations
-constexpr std::array<DataType, 7> kAllXlaCpuTypes = {
- {DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}};
+constexpr std::array<DataType, 9> kAllXlaCpuTypes = {
+ {DT_UINT8, DT_INT8, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE,
+ DT_COMPLEX64, DT_BOOL}};
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_CPU, XlaLocalLaunchOp, kAllXlaCpuTypes);
REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_CPU, kAllXlaCpuTypes);
diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc
index ef4466f005..4cf556524d 100644
--- a/tensorflow/compiler/jit/xla_gpu_device.cc
+++ b/tensorflow/compiler/jit/xla_gpu_device.cc
@@ -74,9 +74,9 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_GPU, XlaGpuDeviceFactory);
// Kernel registrations
-constexpr std::array<DataType, 8> kAllXlaGpuTypes = {
- {DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL,
- DT_BFLOAT16}};
+constexpr std::array<DataType, 10> kAllXlaGpuTypes = {
+ {DT_UINT8, DT_INT8, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE,
+ DT_COMPLEX64, DT_BOOL, DT_BFLOAT16}};
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_GPU, XlaLocalLaunchOp, kAllXlaGpuTypes);
REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_GPU, kAllXlaGpuTypes);