aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/slice_op.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-12-15 16:41:44 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-15 16:45:23 -0800
commitdcb0666a2be2c78d1f36984ef45910998f19e50b (patch)
tree94a065f38a5b426b19bb9ab10b52b56acd56fe8b /tensorflow/core/kernels/slice_op.cc
parent4f4abcacedcba5430e03320f39205d2f327df2ac (diff)
add bfloat16 support to some GPU ops: concat, constant, fill, pack, reshape,
slice, split, unpack PiperOrigin-RevId: 179255814
Diffstat (limited to 'tensorflow/core/kernels/slice_op.cc')
-rw-r--r--tensorflow/core/kernels/slice_op.cc8
1 files changed, 5 insertions, 3 deletions
diff --git a/tensorflow/core/kernels/slice_op.cc b/tensorflow/core/kernels/slice_op.cc
index d46701749b..a9e31cc336 100644
--- a/tensorflow/core/kernels/slice_op.cc
+++ b/tensorflow/core/kernels/slice_op.cc
@@ -439,7 +439,7 @@ namespace functor {
DECLARE_CPU_SPEC(T, 7);
TF_CALL_ALL_TYPES(DECLARE_FOR_N);
-DECLARE_FOR_N(bfloat16);
+TF_CALL_bfloat16(DECLARE_FOR_N);
#undef DECLARE_FOR_N
#undef DECLARE_CPU_SPEC
@@ -456,7 +456,7 @@ DECLARE_FOR_N(bfloat16);
TF_CALL_POD_STRING_TYPES(REGISTER_SLICE);
TF_CALL_QUANTIZED_TYPES(REGISTER_SLICE);
-REGISTER_SLICE(bfloat16);
+TF_CALL_bfloat16(REGISTER_SLICE);
#undef REGISTER_SLICE
#else
#define REGISTER_SLICE(type) \
@@ -469,7 +469,7 @@ REGISTER_SLICE(bfloat16);
TF_CALL_POD_STRING_TYPES(REGISTER_SLICE);
TF_CALL_QUANTIZED_TYPES(REGISTER_SLICE);
-REGISTER_SLICE(bfloat16);
+TF_CALL_bfloat16(REGISTER_SLICE);
#undef REGISTER_SLICE
#endif // INTEL_MKL
@@ -497,6 +497,7 @@ namespace functor {
TF_CALL_GPU_NUMBER_TYPES(DECLARE_FOR_N);
TF_CALL_complex64(DECLARE_FOR_N);
TF_CALL_complex128(DECLARE_FOR_N);
+TF_CALL_bfloat16(DECLARE_FOR_N);
DECLARE_FOR_N(int32);
#undef DECLARE_FOR_N
@@ -515,6 +516,7 @@ DECLARE_FOR_N(int32);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
TF_CALL_complex64(REGISTER_GPU);
TF_CALL_complex128(REGISTER_GPU);
+TF_CALL_bfloat16(REGISTER_GPU);
// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel