diff options
author | 2017-12-15 16:41:44 -0800 | |
---|---|---|
committer | 2017-12-15 16:45:23 -0800 | |
commit | dcb0666a2be2c78d1f36984ef45910998f19e50b (patch) | |
tree | 94a065f38a5b426b19bb9ab10b52b56acd56fe8b /tensorflow/core/kernels/slice_op.cc | |
parent | 4f4abcacedcba5430e03320f39205d2f327df2ac (diff) |
add bfloat16 support to some GPU ops: concat, constant, fill, pack, reshape,
slice, split, unpack
PiperOrigin-RevId: 179255814
Diffstat (limited to 'tensorflow/core/kernels/slice_op.cc')
-rw-r--r-- | tensorflow/core/kernels/slice_op.cc | 8 |
1 files changed, 5 insertions, 3 deletions
```diff
diff --git a/tensorflow/core/kernels/slice_op.cc b/tensorflow/core/kernels/slice_op.cc
index d46701749b..a9e31cc336 100644
--- a/tensorflow/core/kernels/slice_op.cc
+++ b/tensorflow/core/kernels/slice_op.cc
@@ -439,7 +439,7 @@ namespace functor {
 DECLARE_CPU_SPEC(T, 7);
 TF_CALL_ALL_TYPES(DECLARE_FOR_N);
-DECLARE_FOR_N(bfloat16);
+TF_CALL_bfloat16(DECLARE_FOR_N);
 #undef DECLARE_FOR_N
 #undef DECLARE_CPU_SPEC
@@ -456,7 +456,7 @@ DECLARE_FOR_N(bfloat16);
 TF_CALL_POD_STRING_TYPES(REGISTER_SLICE);
 TF_CALL_QUANTIZED_TYPES(REGISTER_SLICE);
-REGISTER_SLICE(bfloat16);
+TF_CALL_bfloat16(REGISTER_SLICE);
 #undef REGISTER_SLICE
 #else
 #define REGISTER_SLICE(type) \
@@ -469,7 +469,7 @@ REGISTER_SLICE(bfloat16);
 TF_CALL_POD_STRING_TYPES(REGISTER_SLICE);
 TF_CALL_QUANTIZED_TYPES(REGISTER_SLICE);
-REGISTER_SLICE(bfloat16);
+TF_CALL_bfloat16(REGISTER_SLICE);
 #undef REGISTER_SLICE
 #endif  // INTEL_MKL
@@ -497,6 +497,7 @@ namespace functor {
 TF_CALL_GPU_NUMBER_TYPES(DECLARE_FOR_N);
 TF_CALL_complex64(DECLARE_FOR_N);
 TF_CALL_complex128(DECLARE_FOR_N);
+TF_CALL_bfloat16(DECLARE_FOR_N);
 DECLARE_FOR_N(int32);
 #undef DECLARE_FOR_N
@@ -515,6 +516,7 @@ DECLARE_FOR_N(int32);
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
 TF_CALL_complex64(REGISTER_GPU);
 TF_CALL_complex128(REGISTER_GPU);
+TF_CALL_bfloat16(REGISTER_GPU);
 // A special GPU kernel for int32.
 // TODO(b/25387198): Also enable int32 in device memory. This kernel
```