aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Justin Lebar <jlebar@google.com>2018-01-22 23:29:57 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-22 23:34:04 -0800
commitb5a1ac0dbaed5b3e0e2a620379a68f42edc87fb8 (patch)
tree7c4e0139f943dea41eb79775a018177efa3845b1 /tensorflow
parent26cdf8e3fdb0a13d19b2aedfc6c0ef1eb94c4c44 (diff)
[XLA] Inline definitions of NativeToPrimitiveType<T>.
These functions are hot in the evaluator, and that's silly, because they're pure constants. Just inline them. PiperOrigin-RevId: 182889992
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/compiler/xla/primitive_util.cc73
-rw-r--r--tensorflow/compiler/xla/primitive_util.h62
2 files changed, 47 insertions, 88 deletions
diff --git a/tensorflow/compiler/xla/primitive_util.cc b/tensorflow/compiler/xla/primitive_util.cc
index 2bce56b7bd..143c9a2366 100644
--- a/tensorflow/compiler/xla/primitive_util.cc
+++ b/tensorflow/compiler/xla/primitive_util.cc
@@ -20,79 +20,6 @@ limitations under the License.
namespace xla {
namespace primitive_util {
-template <>
-PrimitiveType NativeToPrimitiveType<bool>() {
- return PRED;
-}
-
-// Unsigned integer
-template <>
-PrimitiveType NativeToPrimitiveType<uint8>() {
- return U8;
-}
-
-template <>
-PrimitiveType NativeToPrimitiveType<uint16>() {
- return U16;
-}
-
-template <>
-PrimitiveType NativeToPrimitiveType<uint32>() {
- return U32;
-}
-
-template <>
-PrimitiveType NativeToPrimitiveType<uint64>() {
- return U64;
-}
-
-// Signed integer
-template <>
-PrimitiveType NativeToPrimitiveType<int8>() {
- return S8;
-}
-
-template <>
-PrimitiveType NativeToPrimitiveType<int16>() {
- return S16;
-}
-
-template <>
-PrimitiveType NativeToPrimitiveType<int32>() {
- return S32;
-}
-
-template <>
-PrimitiveType NativeToPrimitiveType<int64>() {
- return S64;
-}
-
-// Floating point
-template <>
-PrimitiveType NativeToPrimitiveType<float>() {
- return F32;
-}
-
-template <>
-PrimitiveType NativeToPrimitiveType<double>() {
- return F64;
-}
-
-template <>
-PrimitiveType NativeToPrimitiveType<bfloat16>() {
- return BF16;
-}
-
-template <>
-PrimitiveType NativeToPrimitiveType<half>() {
- return F16;
-}
-
-template <>
-PrimitiveType NativeToPrimitiveType<complex64>() {
- return C64;
-}
-
bool IsFloatingPointType(PrimitiveType type) {
return type == F16 || type == F32 || type == F64 || type == BF16;
}
diff --git a/tensorflow/compiler/xla/primitive_util.h b/tensorflow/compiler/xla/primitive_util.h
index cb4583d198..b26a10ade6 100644
--- a/tensorflow/compiler/xla/primitive_util.h
+++ b/tensorflow/compiler/xla/primitive_util.h
@@ -47,49 +47,81 @@ PrimitiveType NativeToPrimitiveType() {
}
// Declarations of specializations for each native type which correspond to a
-// XLA primitive type.
+// XLA primitive type. As an optimization, these are declared inline in the
+// header.
template <>
-PrimitiveType NativeToPrimitiveType<bool>();
+inline PrimitiveType NativeToPrimitiveType<bool>() {
+ return PRED;
+}
// Unsigned integer
template <>
-PrimitiveType NativeToPrimitiveType<uint8>();
+inline PrimitiveType NativeToPrimitiveType<uint8>() {
+ return U8;
+}
template <>
-PrimitiveType NativeToPrimitiveType<uint16>();
+inline PrimitiveType NativeToPrimitiveType<uint16>() {
+ return U16;
+}
template <>
-PrimitiveType NativeToPrimitiveType<uint32>();
+inline PrimitiveType NativeToPrimitiveType<uint32>() {
+ return U32;
+}
template <>
-PrimitiveType NativeToPrimitiveType<uint64>();
+inline PrimitiveType NativeToPrimitiveType<uint64>() {
+ return U64;
+}
// Signed integer
template <>
-PrimitiveType NativeToPrimitiveType<int8>();
+inline PrimitiveType NativeToPrimitiveType<int8>() {
+ return S8;
+}
template <>
-PrimitiveType NativeToPrimitiveType<int16>();
+inline PrimitiveType NativeToPrimitiveType<int16>() {
+ return S16;
+}
template <>
-PrimitiveType NativeToPrimitiveType<int32>();
+inline PrimitiveType NativeToPrimitiveType<int32>() {
+ return S32;
+}
template <>
-PrimitiveType NativeToPrimitiveType<int64>();
+inline PrimitiveType NativeToPrimitiveType<int64>() {
+ return S64;
+}
// Floating point
template <>
-PrimitiveType NativeToPrimitiveType<float>();
+inline PrimitiveType NativeToPrimitiveType<float>() {
+ return F32;
+}
+
template <>
-PrimitiveType NativeToPrimitiveType<double>();
+inline PrimitiveType NativeToPrimitiveType<double>() {
+ return F64;
+}
+
template <>
-PrimitiveType NativeToPrimitiveType<half>();
+inline PrimitiveType NativeToPrimitiveType<half>() {
+ return F16;
+}
+
template <>
-PrimitiveType NativeToPrimitiveType<bfloat16>();
+inline PrimitiveType NativeToPrimitiveType<bfloat16>() {
+ return BF16;
+}
// Complex
template <>
-PrimitiveType NativeToPrimitiveType<complex64>();
+inline PrimitiveType NativeToPrimitiveType<complex64>() {
+ return C64;
+}
bool IsFloatingPointType(PrimitiveType type);