aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/scatter_op.cc
diff options
context:
space:
mode:
authorGravatar Jonathan Hseu <jhseu@google.com>2016-12-22 15:38:30 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-12-22 15:48:41 -0800
commitbed8383c27a0a7225e6fc7ff59a2cd6388fb4d09 (patch)
treeb70cfc88f95f318195f8610ffb960e98604348d1 /tensorflow/core/kernels/scatter_op.cc
parent1e5bd8cdd62033d1f7ea928fcbec521bb48bb1f5 (diff)
Merge changes from github.
Change: 142805270
Diffstat (limited to 'tensorflow/core/kernels/scatter_op.cc')
-rw-r--r--tensorflow/core/kernels/scatter_op.cc17
1 files changed, 17 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/scatter_op.cc b/tensorflow/core/kernels/scatter_op.cc
index 604f753db1..827eb7dbca 100644
--- a/tensorflow/core/kernels/scatter_op.cc
+++ b/tensorflow/core/kernels/scatter_op.cc
@@ -27,6 +27,9 @@ namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;
+#endif // TENSORFLOW_USE_SYCL
// Check whether updates.shape = indices.shape + params.shape[1:]
static bool ValidShapes(const Tensor& params, const Tensor& updates,
@@ -170,6 +173,20 @@ TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_UPDATE_GPU);
#endif // GOOGLE_CUDA
+// Registers GPU kernels.
+#if TENSORFLOW_USE_SYCL
+#define REGISTER_SCATTER_ARITHEMTIC_SYCL(type) \
+ REGISTER_SCATTER_ARITHEMTIC(type, SYCL);
+
+#define REGISTER_SCATTER_UPDATE_SYCL(type) REGISTER_SCATTER_UPDATE(type, SYCL);
+
+REGISTER_SCATTER_ARITHEMTIC_SYCL(float);
+REGISTER_SCATTER_UPDATE_SYCL(float);
+
+#undef REGISTER_SCATTER_ARITHEMTIC_SYCL
+#undef REGISTER_SCATTER_UPDATE_SYCL
+#endif // TENSORFLOW_USE_SYCL
+
#undef REGISTER_SCATTER_ARITHEMTIC
#undef REGISTER_SCATTER_ARITHEMTIC_CPU
#undef REGISTER_SCATTER_ARITHEMTIC_GPU