aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/cast_op_gpu.cu.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-25 08:23:57 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-25 08:27:36 -0700
commitb3771feab49e2122164737a860341727d08c2d8c (patch)
tree5fb440041db26ef96eb14e7491cb67fe06e7c3d4 /tensorflow/core/kernels/cast_op_gpu.cu.cc
parentbe3d22844025e42e177a21479f3ae73bc5351c1f (diff)
This change started with an intention of adding an attribute to cast ops to decide
whether bfloat16 casts should use truncation or rounding. This is a preparatory change before we switch the default float ==> bfloat16 cast to use rounding instead of truncation. The attribute added can then be specified on casts that rely on the truncation, e.g., the TensorFlow send/receive operations. It later emerged that the choice of doing truncation is useful more generally. Therefore, this change allows the new attribute to be used by all relevant casts to use truncation instead of rounding. PiperOrigin-RevId: 205996367
Diffstat (limited to 'tensorflow/core/kernels/cast_op_gpu.cu.cc')
-rw-r--r--tensorflow/core/kernels/cast_op_gpu.cu.cc48
1 files changed, 37 insertions, 11 deletions
diff --git a/tensorflow/core/kernels/cast_op_gpu.cu.cc b/tensorflow/core/kernels/cast_op_gpu.cu.cc
index 607e7f5efd..036996fca2 100644
--- a/tensorflow/core/kernels/cast_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/cast_op_gpu.cu.cc
@@ -18,22 +18,19 @@ limitations under the License.
#define EIGEN_USE_GPU
#include "tensorflow/core/framework/bfloat16.h"
+#define SPECIALIZE_FOR_GPUS
#include "tensorflow/core/kernels/cast_op.h"
+#undef SPECIALIZE_FOR_GPUS
namespace tensorflow {
namespace functor {
typedef Eigen::GpuDevice GPUDevice;
-template <typename O, typename I>
-struct CastFunctor<GPUDevice, O, I> {
- void operator()(const GPUDevice& d, typename TTypes<O>::Flat o,
- typename TTypes<I>::ConstFlat i) {
- Cast<GPUDevice, O, I>(d, o, i);
- }
-};
+CAST_FUNCTORS(GPUDevice);
#define DEFINE(O, I) template struct CastFunctor<GPUDevice, O, I>
+
#define DEFINE_ALL_FROM(in_type) \
DEFINE(in_type, bool); \
DEFINE(in_type, uint8); \
@@ -59,14 +56,43 @@ DEFINE_ALL_FROM(int8);
DEFINE_ALL_FROM(int16);
DEFINE_ALL_FROM(int32);
DEFINE_ALL_FROM(int64);
-DEFINE_ALL_FROM(Eigen::half);
-DEFINE_ALL_FROM(float);
DEFINE_ALL_FROM(double);
-DEFINE_ALL_FROM(std::complex<float>);
DEFINE_ALL_FROM(std::complex<double>);
-DEFINE(bfloat16, float);
DEFINE(float, bfloat16);
+#define DEFINE_ALL_TO_FLOAT(out_type) \
+ DEFINE(out_type, bool); \
+ DEFINE(out_type, uint8); \
+ DEFINE(out_type, uint16); \
+ DEFINE(out_type, uint32); \
+ DEFINE(out_type, uint64); \
+ DEFINE(out_type, int8); \
+ DEFINE(out_type, int16); \
+ DEFINE(out_type, int32); \
+ DEFINE(out_type, int64); \
+ DEFINE(out_type, Eigen::half); \
+ DEFINE(out_type, float); \
+ DEFINE(out_type, std::complex<float>)
+
+#define DEFINE_ALL_TO_HALF(out_type) \
+ DEFINE(out_type, bool); \
+ DEFINE(out_type, uint8); \
+ DEFINE(out_type, uint16); \
+ DEFINE(out_type, uint32); \
+ DEFINE(out_type, uint64); \
+ DEFINE(out_type, int8); \
+ DEFINE(out_type, int16); \
+ DEFINE(out_type, int32); \
+ DEFINE(out_type, int64); \
+ DEFINE(out_type, Eigen::half)
+
+DEFINE_ALL_TO_HALF(Eigen::half);
+DEFINE_ALL_TO_HALF(bfloat16);
+DEFINE_ALL_TO_FLOAT(float);
+DEFINE_ALL_TO_FLOAT(std::complex<float>);
+
+#undef DEFINE_ALL_TO_FLOAT
+#undef DEFINE_ALL_TO_HALF
#undef DEFINE_ALL_FROM
#undef DEFINE