aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/constant_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/constant_op.cc')
-rw-r--r--tensorflow/core/kernels/constant_op.cc37
1 files changed, 34 insertions, 3 deletions
diff --git a/tensorflow/core/kernels/constant_op.cc b/tensorflow/core/kernels/constant_op.cc
index 362abd4a1f..1ae290ec4b 100644
--- a/tensorflow/core/kernels/constant_op.cc
+++ b/tensorflow/core/kernels/constant_op.cc
@@ -16,9 +16,6 @@ limitations under the License.
// See docs in ../ops/array_ops.cc.
#define EIGEN_USE_THREADS
-#if TENSORFLOW_USE_SYCL
-#define EIGEN_USE_SYCL
-#endif
#include "tensorflow/core/kernels/constant_op.h"
@@ -116,6 +113,9 @@ REGISTER_KERNEL_BUILDER(Name("Const")
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;  // Device alias, declared alongside CPUDevice and GPUDevice.
+#endif // TENSORFLOW_USE_SYCL
namespace functor {
@@ -128,6 +128,17 @@ struct FillFunctor<CPUDevice, T> {
}
};
+#ifdef TENSORFLOW_USE_SYCL
+// Partial specialization of FillFunctor for Device = SYCLDevice: assigns the scalar value in() to all of `out`, evaluated on device `d`.
+template <typename T>
+struct FillFunctor<SYCLDevice, T> {
+ void operator()(const SYCLDevice& d, typename TTypes<T>::Flat out,
+ typename TTypes<T>::ConstScalar in) {
+ To32Bit(out).device(d) = To32Bit(out).constant(in());  // NOTE(review): To32Bit is defined outside this hunk; presumably it selects 32-bit indexing -- confirm.
+ }
+};
+#endif // TENSORFLOW_USE_SYCL
+
} // end namespace functor
template <typename Device, typename T>
@@ -172,6 +183,17 @@ TF_CALL_ALL_TYPES(REGISTER_CPU_KERNEL);
REGISTER_KERNEL(CPU, quint8);
#undef REGISTER_CPU_KERNEL
+#ifdef TENSORFLOW_USE_SYCL
+REGISTER_KERNEL(SYCL, float);  // float "Fill" via the REGISTER_KERNEL macro (defined outside this hunk).
+REGISTER_KERNEL_BUILDER(Name("Fill")
+ .Device(DEVICE_SYCL)
+ .TypeConstraint<int32>("T")
+ .HostMemory("dims")
+ .HostMemory("value")
+ .HostMemory("output"),
+ FillOp<CPUDevice, int32>);  // int32 "Fill" on DEVICE_SYCL: dims, value and output are all in host memory, so the CPUDevice implementation is used.
+#endif // TENSORFLOW_USE_SYCL
+
#if GOOGLE_CUDA
REGISTER_KERNEL(GPU, Eigen::half);
REGISTER_KERNEL(GPU, float);
@@ -220,6 +242,15 @@ class ZerosLikeOp : public OpKernel {
TF_CALL_POD_STRING_TYPES(REGISTER_CPU);
#undef REGISTER_CPU
+#ifdef TENSORFLOW_USE_SYCL
+REGISTER_KERNEL(float, SYCL);  // Here REGISTER_KERNEL takes (type, device), the reverse of the order used by the Fill registrations.
+REGISTER_KERNEL_BUILDER(Name("ZerosLike")
+ .Device(DEVICE_SYCL)
+ .TypeConstraint<int32>("T")
+ .HostMemory("y"),
+ ZerosLikeOp<CPUDevice, int32>);  // int32 "ZerosLike" on DEVICE_SYCL: output "y" is in host memory and the CPUDevice implementation is used.
+#endif // TENSORFLOW_USE_SYCL
+
#if GOOGLE_CUDA
REGISTER_KERNEL(bool, GPU);
REGISTER_KERNEL(Eigen::half, GPU);