diff options
Diffstat (limited to 'tensorflow/core/kernels/cast_op.cc')
-rw-r--r-- | tensorflow/core/kernels/cast_op.cc | 56 |
1 files changed, 55 insertions, 1 deletions
diff --git a/tensorflow/core/kernels/cast_op.cc b/tensorflow/core/kernels/cast_op.cc index ab82c247d6..562934ed63 100644 --- a/tensorflow/core/kernels/cast_op.cc +++ b/tensorflow/core/kernels/cast_op.cc @@ -34,6 +34,9 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; +#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SYCLDevice; +#endif // TENSORFLOW_USE_SYCL #define CURRY_TYPES2(FN, arg0) \ FN(arg0, bool); \ @@ -206,6 +209,52 @@ REGISTER_CAST_GPU(bfloat16, float); #undef REGISTER_CAST_GPU #endif // GOOGLE_CUDA +#ifdef TENSORFLOW_USE_SYCL +class SyclCastOp : public CastOpBase { + public: + explicit SyclCastOp(OpKernelConstruction* ctx) : CastOpBase(ctx) { + OP_REQUIRES_OK(ctx, Prepare()); + } + + private: + Status Prepare() { + if (src_dtype_ == dst_dtype_) { + work_ = nullptr; // Identity + return Status::OK(); + } + if (src_dtype_ == DT_BOOL) { + work_ = GetSyclCastFromBool(dst_dtype_); + } else if (src_dtype_ == DT_INT32) { + work_ = GetSyclCastFromInt32(dst_dtype_); + } else if (src_dtype_ == DT_INT64) { + work_ = GetSyclCastFromInt64(dst_dtype_); + } else if (src_dtype_ == DT_FLOAT) { + work_ = GetSyclCastFromFloat(dst_dtype_); + } else if (src_dtype_ == DT_DOUBLE) { + work_ = GetSyclCastFromDouble(dst_dtype_); + } + + return work_ == nullptr ? Unimplemented() : Status::OK(); + } +}; + +#define REGISTER_CAST_SYCL(srctype, dsttype) \ + REGISTER_KERNEL_BUILDER(Name("Cast") \ + .TypeConstraint<srctype>("SrcT") \ + .TypeConstraint<dsttype>("DstT") \ + .Device(DEVICE_SYCL), \ + SyclCastOp) + +CURRY_TYPES2(REGISTER_CAST_SYCL, bool); +CURRY_TYPES2(REGISTER_CAST_SYCL, int32); +CURRY_TYPES2(REGISTER_CAST_SYCL, int64); +CURRY_TYPES2(REGISTER_CAST_SYCL, float); +CURRY_TYPES2(REGISTER_CAST_SYCL, double); + +#undef REGISTER_CAST_SYCL + +#endif // TENSORFLOW_USE_SYCL + #undef CURRY_TYPES2 // HostCast differs from Cast in that its input and output are in host memory. @@ -213,5 +262,10 @@ REGISTER_KERNEL_BUILDER(Name("_HostCast").Device(DEVICE_CPU), CpuCastOp); REGISTER_KERNEL_BUILDER( Name("_HostCast").Device(DEVICE_GPU).HostMemory("x").HostMemory("y"), CpuCastOp); - +#ifdef TENSORFLOW_USE_SYCL +REGISTER_KERNEL_BUILDER( + Name("_HostCast").Device(DEVICE_SYCL).HostMemory("x").HostMemory("y"), + CpuCastOp); +#endif // TENSORFLOW_USE_SYCL } // end namespace tensorflow + |