aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/jit
diff options
context:
space:
mode:
authorGravatar Benjamin Kramer <kramerb@google.com>2018-09-20 03:14:04 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-20 03:16:50 -0700
commit90d084e0c42232043c186e66093b67800fb30fba (patch)
tree230d7998ea42af3efd59b0d25312eaa54efce5de /tensorflow/compiler/jit
parent9604413da7a27f5718bb88d407d13476dbef5b82 (diff)
[XLA:TF] Whitelist quantized types for CPU/GPU
These have the same behavior as unquantized types so we can just pass them through to XLA (which converts them to unquantized types). They're supposed to be used with special ops, none of which are currently implemented by XLA. Casting (without quantization) and basic math works fine though. These do not have a corresponding numpy type, so only tests using TF types will see them. PiperOrigin-RevId: 213781650
Diffstat (limited to 'tensorflow/compiler/jit')
-rw-r--r--tensorflow/compiler/jit/xla_cpu_device.cc6
-rw-r--r--tensorflow/compiler/jit/xla_gpu_device.cc6
2 files changed, 6 insertions, 6 deletions
diff --git a/tensorflow/compiler/jit/xla_cpu_device.cc b/tensorflow/compiler/jit/xla_cpu_device.cc
index 1afc305abe..e26fa27b31 100644
--- a/tensorflow/compiler/jit/xla_cpu_device.cc
+++ b/tensorflow/compiler/jit/xla_cpu_device.cc
@@ -65,9 +65,9 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_CPU, XlaCpuDeviceFactory);
// Kernel registrations
-constexpr std::array<DataType, 9> kAllXlaCpuTypes = {
- {DT_UINT8, DT_INT8, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE,
- DT_COMPLEX64, DT_BOOL}};
+constexpr std::array<DataType, 12> kAllXlaCpuTypes = {
+ {DT_UINT8, DT_QUINT8, DT_INT8, DT_QINT8, DT_INT32, DT_QINT32, DT_INT64,
+ DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}};
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_CPU, XlaLocalLaunchOp, kAllXlaCpuTypes);
REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_CPU, kAllXlaCpuTypes);
diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc
index 4cf556524d..c386984930 100644
--- a/tensorflow/compiler/jit/xla_gpu_device.cc
+++ b/tensorflow/compiler/jit/xla_gpu_device.cc
@@ -74,9 +74,9 @@ REGISTER_LOCAL_DEVICE_FACTORY(DEVICE_XLA_GPU, XlaGpuDeviceFactory);
// Kernel registrations
-constexpr std::array<DataType, 10> kAllXlaGpuTypes = {
- {DT_UINT8, DT_INT8, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE,
- DT_COMPLEX64, DT_BOOL, DT_BFLOAT16}};
+constexpr std::array<DataType, 13> kAllXlaGpuTypes = {
+ {DT_UINT8, DT_QUINT8, DT_INT8, DT_QINT8, DT_INT32, DT_QINT32, DT_INT64,
+ DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL, DT_BFLOAT16}};
REGISTER_XLA_LAUNCH_KERNEL(DEVICE_XLA_GPU, XlaLocalLaunchOp, kAllXlaGpuTypes);
REGISTER_XLA_DEVICE_KERNELS(DEVICE_XLA_GPU, kAllXlaGpuTypes);