aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/cast_op_impl.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/cast_op_impl.h')
-rw-r--r--tensorflow/core/kernels/cast_op_impl.h29
1 files changed, 29 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/cast_op_impl.h b/tensorflow/core/kernels/cast_op_impl.h
index cb7cc81937..1ee0796ac1 100644
--- a/tensorflow/core/kernels/cast_op_impl.h
+++ b/tensorflow/core/kernels/cast_op_impl.h
@@ -33,6 +33,16 @@ struct CastFunctor<Eigen::ThreadPoolDevice, O, I> {
}
};
+#ifdef TENSORFLOW_USE_SYCL
+template <typename O, typename I>
+struct CastFunctor<Eigen::SyclDevice, O, I> {
+ void operator()(const Eigen::SyclDevice& d, typename TTypes<O>::Flat o,
+ typename TTypes<I>::ConstFlat i) {
+ o.device(d) = i.template cast<O>();
+ }
+};
+#endif // TENSORFLOW_USE_SYCL
+
} // namespace functor
#define CURRY_TYPES3(FN, arg0, arg1) \
@@ -140,6 +150,25 @@ GetGpuCastFromBfloat(DataType dst_dtype);
#endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetSyclCastFromBool(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetSyclCastFromInt32(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetSyclCastFromInt64(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetSyclCastFromFloat(DataType dst_dtype);
+
+std::function<void(OpKernelContext*, const Tensor&, Tensor*)>
+GetSyclCastFromDouble(DataType dst_dtype);
+
+#endif // TENSORFLOW_USE_SYCL
+
} // namespace tensorflow
#endif // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CAST_OP_IMPL_H_
+