diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-07-25 08:23:57 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-25 08:27:36 -0700 |
commit | b3771feab49e2122164737a860341727d08c2d8c (patch) | |
tree | 5fb440041db26ef96eb14e7491cb67fe06e7c3d4 /tensorflow/core/kernels/cast_op_impl_uint32.cc | |
parent | be3d22844025e42e177a21479f3ae73bc5351c1f (diff) |
This change started with an intention of adding an attribute to cast ops to decide
whether bfloat16 casts should use truncation or rounding.
This is a preparatory change before we switch the default float ==> bfloat16 cast
to use rounding instead of truncation. The attribute added can then be specified
on casts that rely on the truncation, e.g., the TensorFlow send/receive operations.
It later emerged that the choice of doing truncation is useful more generally.
Therefore, this change lets all relevant casts use the new attribute to select
truncation instead of rounding.
PiperOrigin-RevId: 205996367
Diffstat (limited to 'tensorflow/core/kernels/cast_op_impl_uint32.cc')
-rw-r--r-- | tensorflow/core/kernels/cast_op_impl_uint32.cc | 9 |
1 file changed, 3 insertions, 6 deletions
```diff
diff --git a/tensorflow/core/kernels/cast_op_impl_uint32.cc b/tensorflow/core/kernels/cast_op_impl_uint32.cc
index d1a854d98b..86f5961bcc 100644
--- a/tensorflow/core/kernels/cast_op_impl_uint32.cc
+++ b/tensorflow/core/kernels/cast_op_impl_uint32.cc
@@ -20,15 +20,13 @@ namespace tensorflow {
 typedef Eigen::ThreadPoolDevice CPUDevice;
 typedef Eigen::GpuDevice GPUDevice;
 
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetCpuCastFromUint32(DataType dst_dtype) {
+CastFunctorType GetCpuCastFromUint32(DataType dst_dtype) {
   CURRY_TYPES3(CAST_CASE, CPUDevice, uint32);
   return nullptr;
 }
 
 #if GOOGLE_CUDA
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetGpuCastFromUint32(DataType dst_dtype) {
+CastFunctorType GetGpuCastFromUint32(DataType dst_dtype) {
   CURRY_TYPES3_NO_BF16(CAST_CASE, GPUDevice, uint32);
   return nullptr;
 }
@@ -36,8 +34,7 @@ GetGpuCastFromUint32(DataType dst_dtype) {
 
 #ifdef TENSORFLOW_USE_SYCL
 typedef Eigen::SyclDevice SYCLDevice;
-std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
-GetSyclCastFromUint32(DataType dst_dtype) {
+CastFunctorType GetSyclCastFromUint32(DataType dst_dtype) {
   CURRY_TYPES3_NO_HALF(CAST_CASE, SYCLDevice, uint32);
   return nullptr;
 }
```