aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/cast_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/cast_op.cc')
-rw-r--r--tensorflow/core/kernels/cast_op.cc56
1 files changed, 55 insertions, 1 deletions
diff --git a/tensorflow/core/kernels/cast_op.cc b/tensorflow/core/kernels/cast_op.cc
index ab82c247d6..562934ed63 100644
--- a/tensorflow/core/kernels/cast_op.cc
+++ b/tensorflow/core/kernels/cast_op.cc
@@ -34,6 +34,9 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+#endif // TENSORFLOW_USE_SYCL
#define CURRY_TYPES2(FN, arg0) \
FN(arg0, bool); \
@@ -206,6 +209,52 @@ REGISTER_CAST_GPU(bfloat16, float);
#undef REGISTER_CAST_GPU
#endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+class SyclCastOp : public CastOpBase {
+ public:
+ explicit SyclCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) {
+ OP_REQUIRES_OK(ctx, Prepare());
+ }
+
+ private:
+ Status Prepare() {
+ if (src_dtype_ == dst_dtype_) {
+ work_ = nullptr; // Identity
+ return Status::OK();
+ }
+ if (src_dtype_ == DT_BOOL) {
+ work_ = GetSyclCastFromBool(dst_dtype_);
+ } else if (src_dtype_ == DT_INT32) {
+ work_ = GetSyclCastFromInt32(dst_dtype_);
+ } else if (src_dtype_ == DT_INT64) {
+ work_ = GetSyclCastFromInt64(dst_dtype_);
+ } else if (src_dtype_ == DT_FLOAT) {
+ work_ = GetSyclCastFromFloat(dst_dtype_);
+ } else if (src_dtype_ == DT_DOUBLE) {
+ work_ = GetSyclCastFromDouble(dst_dtype_);
+ }
+
+ return work_ == nullptr ? Unimplemented() : Status::OK();
+ }
+};
+
+#define REGISTER_CAST_SYCL(srctype, dsttype) \
+ REGISTER_KERNEL_BUILDER(Name("Cast") \
+ .TypeConstraint<srctype>("SrcT") \
+ .TypeConstraint<dsttype>("DstT") \
+ .Device(DEVICE_SYCL), \
+ SyclCastOp)
+
+CURRY_TYPES2(REGISTER_CAST_SYCL, bool);
+CURRY_TYPES2(REGISTER_CAST_SYCL, int32);
+CURRY_TYPES2(REGISTER_CAST_SYCL, int64);
+CURRY_TYPES2(REGISTER_CAST_SYCL, float);
+CURRY_TYPES2(REGISTER_CAST_SYCL, double);
+
+#undef REGISTER_CAST_SYCL
+
+#endif // TENSORFLOW_USE_SYCL
+
#undef CURRY_TYPES2
// HostCast differs from Cast in that its input and output are in host memory.
@@ -213,5 +262,10 @@ REGISTER_KERNEL_BUILDER(Name("_HostCast").Device(DEVICE_CPU), CpuCastOp);
REGISTER_KERNEL_BUILDER(
Name("_HostCast").Device(DEVICE_GPU).HostMemory("x").HostMemory("y"),
CpuCastOp);
-
+#ifdef TENSORFLOW_USE_SYCL
+REGISTER_KERNEL_BUILDER(
+ Name("_HostCast").Device(DEVICE_SYCL).HostMemory("x").HostMemory("y"),
+ CpuCastOp);
+#endif // TENSORFLOW_USE_SYCL
} // end namespace tensorflow
+