diff options
author | 2016-12-22 15:38:30 -0800 | |
---|---|---|
committer | 2016-12-22 15:48:41 -0800 | |
commit | bed8383c27a0a7225e6fc7ff59a2cd6388fb4d09 (patch) | |
tree | b70cfc88f95f318195f8610ffb960e98604348d1 /tensorflow/core/kernels/scatter_op.cc | |
parent | 1e5bd8cdd62033d1f7ea928fcbec521bb48bb1f5 (diff) |
Merge changes from github.
Change: 142805270
Diffstat (limited to 'tensorflow/core/kernels/scatter_op.cc')
-rw-r--r-- | tensorflow/core/kernels/scatter_op.cc | 17 |
1 files changed, 17 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/scatter_op.cc b/tensorflow/core/kernels/scatter_op.cc index 604f753db1..827eb7dbca 100644 --- a/tensorflow/core/kernels/scatter_op.cc +++ b/tensorflow/core/kernels/scatter_op.cc @@ -27,6 +27,9 @@ namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::GpuDevice GPUDevice; +#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SYCLDevice; +#endif // TENSORFLOW_USE_SYCL // Check whether updates.shape = indices.shape + params.shape[1:] static bool ValidShapes(const Tensor& params, const Tensor& updates, @@ -170,6 +173,20 @@ TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_UPDATE_GPU); #endif // GOOGLE_CUDA +// Registers GPU kernels. +#if TENSORFLOW_USE_SYCL +#define REGISTER_SCATTER_ARITHEMTIC_SYCL(type) \ + REGISTER_SCATTER_ARITHEMTIC(type, SYCL); + +#define REGISTER_SCATTER_UPDATE_SYCL(type) REGISTER_SCATTER_UPDATE(type, SYCL); + +REGISTER_SCATTER_ARITHEMTIC_SYCL(float); +REGISTER_SCATTER_UPDATE_SYCL(float); + +#undef REGISTER_SCATTER_ARITHEMTIC_SYCL +#undef REGISTER_SCATTER_UPDATE_SYCL +#endif // TENSORFLOW_USE_SYCL + #undef REGISTER_SCATTER_ARITHEMTIC #undef REGISTER_SCATTER_ARITHEMTIC_CPU #undef REGISTER_SCATTER_ARITHEMTIC_GPU |