diff options
author | 2016-08-12 23:02:16 +0100 | |
---|---|---|
committer | 2016-08-12 15:02:16 -0700 | |
commit | b0aa789423270adfe58496a78cbfd1509f2374f0 (patch) | |
tree | 77cd1c686714d78e853f88b83ed41a0f1e2a76fd /tensorflow/core/kernels/cast_op_gpu.cu.cc | |
parent | bf31051225ce53c1c88fd45ee117a49645153770 (diff) |
Add complex dtype support to cast (#3718)
Diffstat (limited to 'tensorflow/core/kernels/cast_op_gpu.cu.cc')
-rw-r--r-- | tensorflow/core/kernels/cast_op_gpu.cu.cc | 26 |
1 files changed, 15 insertions, 11 deletions
diff --git a/tensorflow/core/kernels/cast_op_gpu.cu.cc b/tensorflow/core/kernels/cast_op_gpu.cu.cc index ada442b352..9c9e9e7658 100644 --- a/tensorflow/core/kernels/cast_op_gpu.cu.cc +++ b/tensorflow/core/kernels/cast_op_gpu.cu.cc @@ -34,17 +34,19 @@ struct CastFunctor<GPUDevice, O, I> { }; #define DEFINE(O, I) template struct CastFunctor<GPUDevice, O, I> -#define DEFINE_ALL_FROM(in_type) \ - DEFINE(in_type, bool); \ - DEFINE(in_type, uint8); \ - DEFINE(in_type, int8); \ - DEFINE(in_type, uint16); \ - DEFINE(in_type, int16); \ - DEFINE(in_type, int32); \ - DEFINE(in_type, int64); \ - DEFINE(in_type, Eigen::half); \ - DEFINE(in_type, float); \ - DEFINE(in_type, double) +#define DEFINE_ALL_FROM(in_type) \ + DEFINE(in_type, bool); \ + DEFINE(in_type, uint8); \ + DEFINE(in_type, int8); \ + DEFINE(in_type, uint16); \ + DEFINE(in_type, int16); \ + DEFINE(in_type, int32); \ + DEFINE(in_type, int64); \ + DEFINE(in_type, Eigen::half); \ + DEFINE(in_type, float); \ + DEFINE(in_type, double); \ + DEFINE(in_type, std::complex<float>); \ + DEFINE(in_type, std::complex<double>) DEFINE_ALL_FROM(bool); DEFINE_ALL_FROM(uint8); @@ -56,6 +58,8 @@ DEFINE_ALL_FROM(int64); DEFINE_ALL_FROM(Eigen::half); DEFINE_ALL_FROM(float); DEFINE_ALL_FROM(double); +DEFINE_ALL_FROM(std::complex<float>); +DEFINE_ALL_FROM(std::complex<double>); DEFINE(bfloat16, float); DEFINE(float, bfloat16); |