diff options
Diffstat (limited to 'tensorflow/core/kernels/constant_op.cc')
-rw-r--r-- | tensorflow/core/kernels/constant_op.cc | 37 |
1 file changed, 34 insertions, 3 deletions
diff --git a/tensorflow/core/kernels/constant_op.cc b/tensorflow/core/kernels/constant_op.cc
index 362abd4a1f..1ae290ec4b 100644
--- a/tensorflow/core/kernels/constant_op.cc
+++ b/tensorflow/core/kernels/constant_op.cc
@@ -16,9 +16,6 @@ limitations under the License.
 // See docs in ../ops/array_ops.cc.
 
 #define EIGEN_USE_THREADS
-#if TENSORFLOW_USE_SYCL
-#define EIGEN_USE_SYCL
-#endif
 
 #include "tensorflow/core/kernels/constant_op.h"
 
@@ -116,6 +113,9 @@ REGISTER_KERNEL_BUILDER(Name("Const")
 
 typedef Eigen::ThreadPoolDevice CPUDevice;
 typedef Eigen::GpuDevice GPUDevice;
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+#endif //TENSORFLOW_USE_SYCL
 
 namespace functor {
 
@@ -128,6 +128,17 @@ struct FillFunctor<CPUDevice, T> {
   }
 };
 
+#ifdef TENSORFLOW_USE_SYCL
+// Partial specialization of FillFunctor<Device=SYCLDevice, T>.
+template <typename T>
+struct FillFunctor<SYCLDevice, T> {
+  void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out,
+                  typename TTypes<T>::ConstScalar in) {
+    To32Bit(out).device(d) = To32Bit(out).constant(in());
+  }
+};
+#endif // TENSORFLOW_USE_SYCL
+
 } // end namespace functor
 
 template <typename Device, typename T>
@@ -172,6 +183,17 @@ TF_CALL_ALL_TYPES(REGISTER_CPU_KERNEL);
 REGISTER_KERNEL(CPU, quint8);
 #undef REGISTER_CPU_KERNEL
 
+#ifdef TENSORFLOW_USE_SYCL
+REGISTER_KERNEL(SYCL, float)
+REGISTER_KERNEL_BUILDER(Name("Fill")
+                            .Device(DEVICE_SYCL)
+                            .TypeConstraint<int32>("T")
+                            .HostMemory("dims")
+                            .HostMemory("value")
+                            .HostMemory("output"),
+                        FillOp<CPUDevice, int32>);
+#endif // TENSORFLOW_USE_SYCL
+
 #if GOOGLE_CUDA
 REGISTER_KERNEL(GPU, Eigen::half);
 REGISTER_KERNEL(GPU, float);
@@ -220,6 +242,15 @@ class ZerosLikeOp : public OpKernel {
 TF_CALL_POD_STRING_TYPES(REGISTER_CPU);
 #undef REGISTER_CPU
 
+#ifdef TENSORFLOW_USE_SYCL
+REGISTER_KERNEL(float, SYCL);
+REGISTER_KERNEL_BUILDER(Name("ZerosLike")
+                            .Device(DEVICE_SYCL)
+                            .TypeConstraint<int32>("T")
+                            .HostMemory("y"),
+                        ZerosLikeOp<CPUDevice, int32>);
+#endif // TENSORFLOW_USE_SYCL
+
 #if GOOGLE_CUDA
 REGISTER_KERNEL(bool, GPU);
 REGISTER_KERNEL(Eigen::half, GPU);