diff options
author | 2018-01-22 23:29:57 -0800 | |
---|---|---|
committer | 2018-01-22 23:34:04 -0800 | |
commit | b5a1ac0dbaed5b3e0e2a620379a68f42edc87fb8 (patch) | |
tree | 7c4e0139f943dea41eb79775a018177efa3845b1 /tensorflow | |
parent | 26cdf8e3fdb0a13d19b2aedfc6c0ef1eb94c4c44 (diff) |
[XLA] Inline definitions of NativeToPrimitiveType<T>.
These functions are hot in the evaluator, and that's silly, because
they're pure constants. Just inline them.
PiperOrigin-RevId: 182889992
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/compiler/xla/primitive_util.cc | 73 | ||||
-rw-r--r-- | tensorflow/compiler/xla/primitive_util.h | 62 |
2 files changed, 47 insertions, 88 deletions
diff --git a/tensorflow/compiler/xla/primitive_util.cc b/tensorflow/compiler/xla/primitive_util.cc index 2bce56b7bd..143c9a2366 100644 --- a/tensorflow/compiler/xla/primitive_util.cc +++ b/tensorflow/compiler/xla/primitive_util.cc @@ -20,79 +20,6 @@ limitations under the License. namespace xla { namespace primitive_util { -template <> -PrimitiveType NativeToPrimitiveType<bool>() { - return PRED; -} - -// Unsigned integer -template <> -PrimitiveType NativeToPrimitiveType<uint8>() { - return U8; -} - -template <> -PrimitiveType NativeToPrimitiveType<uint16>() { - return U16; -} - -template <> -PrimitiveType NativeToPrimitiveType<uint32>() { - return U32; -} - -template <> -PrimitiveType NativeToPrimitiveType<uint64>() { - return U64; -} - -// Signed integer -template <> -PrimitiveType NativeToPrimitiveType<int8>() { - return S8; -} - -template <> -PrimitiveType NativeToPrimitiveType<int16>() { - return S16; -} - -template <> -PrimitiveType NativeToPrimitiveType<int32>() { - return S32; -} - -template <> -PrimitiveType NativeToPrimitiveType<int64>() { - return S64; -} - -// Floating point -template <> -PrimitiveType NativeToPrimitiveType<float>() { - return F32; -} - -template <> -PrimitiveType NativeToPrimitiveType<double>() { - return F64; -} - -template <> -PrimitiveType NativeToPrimitiveType<bfloat16>() { - return BF16; -} - -template <> -PrimitiveType NativeToPrimitiveType<half>() { - return F16; -} - -template <> -PrimitiveType NativeToPrimitiveType<complex64>() { - return C64; -} - bool IsFloatingPointType(PrimitiveType type) { return type == F16 || type == F32 || type == F64 || type == BF16; } diff --git a/tensorflow/compiler/xla/primitive_util.h b/tensorflow/compiler/xla/primitive_util.h index cb4583d198..b26a10ade6 100644 --- a/tensorflow/compiler/xla/primitive_util.h +++ b/tensorflow/compiler/xla/primitive_util.h @@ -47,49 +47,81 @@ PrimitiveType NativeToPrimitiveType() { } // Declarations of specializations for each native type which correspond to a -// XLA primitive type. +// XLA primitive type. As an optimization, these are declared inline in the +// header. template <> -PrimitiveType NativeToPrimitiveType<bool>(); +inline PrimitiveType NativeToPrimitiveType<bool>() { + return PRED; +} // Unsigned integer template <> -PrimitiveType NativeToPrimitiveType<uint8>(); +inline PrimitiveType NativeToPrimitiveType<uint8>() { + return U8; +} template <> -PrimitiveType NativeToPrimitiveType<uint16>(); +inline PrimitiveType NativeToPrimitiveType<uint16>() { + return U16; +} template <> -PrimitiveType NativeToPrimitiveType<uint32>(); +inline PrimitiveType NativeToPrimitiveType<uint32>() { + return U32; +} template <> -PrimitiveType NativeToPrimitiveType<uint64>(); +inline PrimitiveType NativeToPrimitiveType<uint64>() { + return U64; +} // Signed integer template <> -PrimitiveType NativeToPrimitiveType<int8>(); +inline PrimitiveType NativeToPrimitiveType<int8>() { + return S8; +} template <> -PrimitiveType NativeToPrimitiveType<int16>(); +inline PrimitiveType NativeToPrimitiveType<int16>() { + return S16; +} template <> -PrimitiveType NativeToPrimitiveType<int32>(); +inline PrimitiveType NativeToPrimitiveType<int32>() { + return S32; +} template <> -PrimitiveType NativeToPrimitiveType<int64>(); +inline PrimitiveType NativeToPrimitiveType<int64>() { + return S64; +} // Floating point template <> -PrimitiveType NativeToPrimitiveType<float>(); +inline PrimitiveType NativeToPrimitiveType<float>() { + return F32; +} + template <> -PrimitiveType NativeToPrimitiveType<double>(); +inline PrimitiveType NativeToPrimitiveType<double>() { + return F64; +} + template <> -PrimitiveType NativeToPrimitiveType<half>(); +inline PrimitiveType NativeToPrimitiveType<half>() { + return F16; +} + template <> -PrimitiveType NativeToPrimitiveType<bfloat16>(); +inline PrimitiveType NativeToPrimitiveType<bfloat16>() { + return BF16; +} // Complex template <> -PrimitiveType NativeToPrimitiveType<complex64>(); +inline PrimitiveType NativeToPrimitiveType<complex64>() { + return C64; +} bool IsFloatingPointType(PrimitiveType type); |