diff options
Diffstat (limited to 'tensorflow/core/kernels/cast_op_impl.h')
-rw-r--r-- | tensorflow/core/kernels/cast_op_impl.h | 29 |
1 files changed, 29 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/cast_op_impl.h b/tensorflow/core/kernels/cast_op_impl.h index cb7cc81937..1ee0796ac1 100644 --- a/tensorflow/core/kernels/cast_op_impl.h +++ b/tensorflow/core/kernels/cast_op_impl.h @@ -33,6 +33,16 @@ struct CastFunctor<Eigen::ThreadPoolDevice, O, I> { } }; +#ifdef TENSORFLOW_USE_SYCL +template <typename O, typename I> +struct CastFunctor<Eigen::SyclDevice, O, I> { + void operator()(const Eigen::SyclDevice& d, typename TTypes<O>::Flat o, + typename TTypes<I>::ConstFlat i) { + o.device(d) = i.template cast<O>(); + } +}; +#endif // TENSORFLOW_USE_SYCL + } // namespace functor #define CURRY_TYPES3(FN, arg0, arg1) \ @@ -140,6 +150,25 @@ GetGpuCastFromBfloat(DataType dst_dtype); #endif // GOOGLE_CUDA +#ifdef TENSORFLOW_USE_SYCL +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetSyclCastFromBool(DataType dst_dtype); + +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetSyclCastFromInt32(DataType dst_dtype); + +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetSyclCastFromInt64(DataType dst_dtype); + +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetSyclCastFromFloat(DataType dst_dtype); + +std::function<void(OpKernelContext*, const Tensor&, Tensor*)> +GetSyclCastFromDouble(DataType dst_dtype); + +#endif // TENSORFLOW_USE_SYCL + } // namespace tensorflow #endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CAST_OP_IMPL_H_ + |