diff options
author | 2018-09-17 03:12:38 -0700 | |
---|---|---|
committer | 2018-09-17 03:16:54 -0700 | |
commit | cac963862be3faa421c559f39033c9bfb3b27a51 (patch) | |
tree | 8418eb6b786f0c46d0738ca54084583330012a42 /tensorflow/compiler/jit | |
parent | b1f4328517851e76cff3d4af8766e7e3446314ba (diff) |
[XLA:TF] Enable int8 and uint8 support in the bridge for CPU/GPU
The test changes are awkward. None of these are XLA bugs, it's just that the op
definitions in tensorflow are really inconsistent. I tried to infer whether the
limitation is on signed types, index types or just arbitrary. In the latter
case just int8/uint8 is blacklisted, we should probably lift that requirement
at some point.
PiperOrigin-RevId: 213243906
Diffstat (limited to 'tensorflow/compiler/jit')
-rw-r--r-- | tensorflow/compiler/jit/xla_cpu_device.cc | 5 | ||||
-rw-r--r-- | tensorflow/compiler/jit/xla_gpu_device.cc | 6 |
2 files changed, 6 insertions, 5 deletions
diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc index 7e159e3171..1afc305abe 100644 --- a/tensorflow/compiler/jit/xla_cpu_device.cc +++ b/tensorflow/compiler/jit/xla_cpu_device.cc @@ -65,8 +65,9 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_CPU, XlaCpuDeviceFactory); // Kernel registrations -constexpr std::array<DataType, 7> kAllXlaCpuTypes = { - {DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}}; +constexpr std::array<DataType, 9> kAllXlaCpuTypes = { + {DT_UINT8, DT_INT8, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, + DT_COMPLEX64, DT_BOOL}}; REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_CPU, XlaLocalLaunchOp, kAllXlaCpuTypes); REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_CPU, kAllXlaCpuTypes); diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc index ef4466f005..4cf556524d 100644 --- a/tensorflow/compiler/jit/xla_gpu_device.cc +++ b/tensorflow/compiler/jit/xla_gpu_device.cc @@ -74,9 +74,9 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_GPU, XlaGpuDeviceFactory); // Kernel registrations -constexpr std::array<DataType, 8> kAllXlaGpuTypes = { - {DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL, - DT_BFLOAT16}}; +constexpr std::array<DataType, 10> kAllXlaGpuTypes = { + {DT_UINT8, DT_INT8, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, + DT_COMPLEX64, DT_BOOL, DT_BFLOAT16}}; REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_GPU, XlaLocalLaunchOp, kAllXlaGpuTypes); REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_GPU, kAllXlaGpuTypes); |