aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/cast_op_gpu.cu.cc
diff options
context:
space:
mode:
authorGravatar Igor Babuschkin <igor@babuschk.in>2016-08-12 23:02:16 +0100
committerGravatar Benoit Steiner <benoitsteiner@users.noreply.github.com>2016-08-12 15:02:16 -0700
commitb0aa789423270adfe58496a78cbfd1509f2374f0 (patch)
tree77cd1c686714d78e853f88b83ed41a0f1e2a76fd /tensorflow/core/kernels/cast_op_gpu.cu.cc
parentbf31051225ce53c1c88fd45ee117a49645153770 (diff)
Add complex dtype support to cast (#3718)
Diffstat (limited to 'tensorflow/core/kernels/cast_op_gpu.cu.cc')
-rw-r--r--tensorflow/core/kernels/cast_op_gpu.cu.cc26
1 files changed, 15 insertions, 11 deletions
diff --git a/tensorflow/core/kernels/cast_op_gpu.cu.cc b/tensorflow/core/kernels/cast_op_gpu.cu.cc
index ada442b352..9c9e9e7658 100644
--- a/tensorflow/core/kernels/cast_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/cast_op_gpu.cu.cc
@@ -34,17 +34,19 @@ struct CastFunctor<GPUDevice, O, I> {
};
#define DEFINE(O, I) template struct CastFunctor<GPUDevice, O, I>
-#define DEFINE_ALL_FROM(in_type) \
- DEFINE(in_type, bool); \
- DEFINE(in_type, uint8); \
- DEFINE(in_type, int8); \
- DEFINE(in_type, uint16); \
- DEFINE(in_type, int16); \
- DEFINE(in_type, int32); \
- DEFINE(in_type, int64); \
- DEFINE(in_type, Eigen::half); \
- DEFINE(in_type, float); \
- DEFINE(in_type, double)
+#define DEFINE_ALL_FROM(in_type) \
+ DEFINE(in_type, bool); \
+ DEFINE(in_type, uint8); \
+ DEFINE(in_type, int8); \
+ DEFINE(in_type, uint16); \
+ DEFINE(in_type, int16); \
+ DEFINE(in_type, int32); \
+ DEFINE(in_type, int64); \
+ DEFINE(in_type, Eigen::half); \
+ DEFINE(in_type, float); \
+ DEFINE(in_type, double); \
+ DEFINE(in_type, std::complex<float>); \
+ DEFINE(in_type, std::complex<double>)
DEFINE_ALL_FROM(bool);
DEFINE_ALL_FROM(uint8);
@@ -56,6 +58,8 @@ DEFINE_ALL_FROM(int64);
DEFINE_ALL_FROM(Eigen::half);
DEFINE_ALL_FROM(float);
DEFINE_ALL_FROM(double);
+DEFINE_ALL_FROM(std::complex<float>);
+DEFINE_ALL_FROM(std::complex<double>);
DEFINE(bfloat16, float);
DEFINE(float, bfloat16);