diff options
author | Benjamin Kramer <kramerb@google.com> | 2018-09-20 03:14:04 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-20 03:16:50 -0700 |
commit | 90d084e0c42232043c186e66093b67800fb30fba (patch) | |
tree | 230d7998ea42af3efd59b0d25312eaa54efce5de /tensorflow/compiler/tf2xla | |
parent | 9604413da7a27f5718bb88d407d13476dbef5b82 (diff) |
[XLA:TF] Whitelist quantized types for CPU/GPU
These have the same behavior as unquantized types so we can just pass them
through to XLA (which converts them to unquantized types). They're supposed to
be used with special ops, none of which are currently implemented by XLA.
Casting (without quantization) and basic math works fine though.
These do not have a corresponding numpy type, so only tests using TF types will
see them.
PiperOrigin-RevId: 213781650
Diffstat (limited to 'tensorflow/compiler/tf2xla')
-rw-r--r-- | tensorflow/compiler/tf2xla/xla_op_registry.h | 15 |
1 files changed, 8 insertions, 7 deletions
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h index a4b624820a..4b2c2bacd6 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.h +++ b/tensorflow/compiler/tf2xla/xla_op_registry.h @@ -51,13 +51,14 @@ constexpr std::array<DataType, 11> kNumericTypes = { {DT_UINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_INT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BFLOAT16}}; -constexpr std::array<DataType, 11> kCpuAllTypes = { - {DT_UINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_INT32, DT_INT64, DT_HALF, - DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}}; - -constexpr std::array<DataType, 12> kGpuAllTypes = { - {DT_UINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_INT32, DT_INT64, DT_HALF, - DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL, DT_BFLOAT16}}; +constexpr std::array<DataType, 14> kCpuAllTypes = { + {DT_UINT8, DT_QUINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8, DT_INT32, + DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL}}; + +constexpr std::array<DataType, 15> kGpuAllTypes = { + {DT_UINT8, DT_QUINT8, DT_UINT32, DT_UINT64, DT_INT8, DT_QINT8, DT_INT32, + DT_QINT32, DT_INT64, DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64, DT_BOOL, + DT_BFLOAT16}}; // Class that manages registrations of operators and devices for the XLA JIT. // Not thread-safe. |