diff options
Diffstat (limited to 'tensorflow/core/kernels/cast_op_gpu.cu.cc')
-rw-r--r-- | tensorflow/core/kernels/cast_op_gpu.cu.cc | 48 |
1 files changed, 37 insertions, 11 deletions
diff --git a/tensorflow/core/kernels/cast_op_gpu.cu.cc b/tensorflow/core/kernels/cast_op_gpu.cu.cc index 607e7f5efd..036996fca2 100644 --- a/tensorflow/core/kernels/cast_op_gpu.cu.cc +++ b/tensorflow/core/kernels/cast_op_gpu.cu.cc @@ -18,22 +18,19 @@ limitations under the License. #define EIGEN_USE_GPU #include "tensorflow/core/framework/bfloat16.h" +#define SPECIALIZE_FOR_GPUS #include "tensorflow/core/kernels/cast_op.h" +#undef SPECIALIZE_FOR_GPUS namespace tensorflow { namespace functor { typedef Eigen::GpuDevice GPUDevice; -template <typename O, typename I> -struct CastFunctor<GPUDevice, O, I> { - void operator()(const GPUDevice& d, typename TTypes<O>::Flat o, - typename TTypes<I>::ConstFlat i) { - Cast<GPUDevice, O, I>(d, o, i); - } -}; +CAST_FUNCTORS(GPUDevice); #define DEFINE(O, I) template struct CastFunctor<GPUDevice, O, I> + #define DEFINE_ALL_FROM(in_type) \ DEFINE(in_type, bool); \ DEFINE(in_type, uint8); \ @@ -59,14 +56,43 @@ DEFINE_ALL_FROM(int8); DEFINE_ALL_FROM(int16); DEFINE_ALL_FROM(int32); DEFINE_ALL_FROM(int64); -DEFINE_ALL_FROM(Eigen::half); -DEFINE_ALL_FROM(float); DEFINE_ALL_FROM(double); -DEFINE_ALL_FROM(std::complex<float>); DEFINE_ALL_FROM(std::complex<double>); -DEFINE(bfloat16, float); DEFINE(float, bfloat16); +#define DEFINE_ALL_TO_FLOAT(out_type) \ + DEFINE(out_type, bool); \ + DEFINE(out_type, uint8); \ + DEFINE(out_type, uint16); \ + DEFINE(out_type, uint32); \ + DEFINE(out_type, uint64); \ + DEFINE(out_type, int8); \ + DEFINE(out_type, int16); \ + DEFINE(out_type, int32); \ + DEFINE(out_type, int64); \ + DEFINE(out_type, Eigen::half); \ + DEFINE(out_type, float); \ + DEFINE(out_type, std::complex<float>) + +#define DEFINE_ALL_TO_HALF(out_type) \ + DEFINE(out_type, bool); \ + DEFINE(out_type, uint8); \ + DEFINE(out_type, uint16); \ + DEFINE(out_type, uint32); \ + DEFINE(out_type, uint64); \ + DEFINE(out_type, int8); \ + DEFINE(out_type, int16); \ + DEFINE(out_type, int32); \ + DEFINE(out_type, int64); \ + DEFINE(out_type, Eigen::half) + +DEFINE_ALL_TO_HALF(Eigen::half); +DEFINE_ALL_TO_HALF(bfloat16); +DEFINE_ALL_TO_FLOAT(float); +DEFINE_ALL_TO_FLOAT(std::complex<float>); + +#undef DEFINE_ALL_TO_FLOAT +#undef DEFINE_ALL_TO_HALF #undef DEFINE_ALL_FROM #undef DEFINE |