diff options
Diffstat (limited to 'tensorflow/core/kernels/inplace_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/inplace_ops.cc | 62 |
1 files changed, 59 insertions, 3 deletions
diff --git a/tensorflow/core/kernels/inplace_ops.cc b/tensorflow/core/kernels/inplace_ops.cc index 4433b9eea9..67bec7d50e 100644 --- a/tensorflow/core/kernels/inplace_ops.cc +++ b/tensorflow/core/kernels/inplace_ops.cc @@ -25,11 +25,14 @@ limitations under the License. namespace tensorflow { typedef Eigen::ThreadPoolDevice CPUDevice; +#ifdef TENSORFLOW_USE_SYCL +typedef Eigen::SyclDevice SyclDevice; +#endif // TENSORFLOW_USE_SYCL namespace functor { -template <typename T> -Status DoParallelConcatUpdate(const CPUDevice& d, const Tensor& value, +template <typename Device, typename T> +Status DoParallelConcatUpdate(const Device& d, const Tensor& value, int32 loc, Tensor* output) { auto Tvalue = value.flat_outer_dims<T>(); auto Toutput = output->flat_outer_dims<T>(); @@ -46,7 +49,7 @@ Status DoParallelConcat(const CPUDevice& d, const Tensor& value, int32 loc, switch (value.dtype()) { #define CASE(type) \ case DataTypeToEnum<type>::value: \ - return DoParallelConcatUpdate<type>(d, value, loc, output); + return DoParallelConcatUpdate<CPUDevice, type>(d, value, loc, output); TF_CALL_NUMBER_TYPES(CASE); TF_CALL_string(CASE); #undef CASE @@ -55,6 +58,23 @@ Status DoParallelConcat(const CPUDevice& d, const Tensor& value, int32 loc, } } +#ifdef TENSORFLOW_USE_SYCL +template <> +Status DoParallelConcat(const SyclDevice& d, const Tensor& value, int32 loc, + Tensor* output) { + CHECK_EQ(value.dtype(), output->dtype()); + switch (value.dtype()) { +#define CASE(type) \ + case DataTypeToEnum<type>::value: \ + return DoParallelConcatUpdate<SyclDevice, type>(d, value, loc, output); + TF_CALL_GPU_NUMBER_TYPES_NO_HALF(CASE); +#undef CASE + default: + return errors::InvalidArgument("Unsupported data type: ", value.dtype()); + } +} +#endif // TENSORFLOW_USE_SYCL + } // end namespace functor namespace { @@ -152,6 +172,42 @@ TF_CALL_POD_STRING_TYPES(REGISTER_EMPTY) TF_CALL_POD_STRING_TYPES(REGISTER_PARALLEL_CONCAT); #undef REGISTER_PARALLEL_CONCAT +#ifdef TENSORFLOW_USE_SYCL
+#define REGISTER_EMPTY(type) \ + REGISTER_KERNEL_BUILDER(Name("_ParallelConcatStart") \ + .Device(DEVICE_SYCL) \ + .TypeConstraint<type>("dtype"), \ + ParallelConcatStart<SyclDevice, type>); +TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_EMPTY) +#undef REGISTER_EMPTY + +#define REGISTER_PARALLEL_CONCAT(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("ParallelConcat").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \ + FailureKernel); +TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_PARALLEL_CONCAT); +#undef REGISTER_PARALLEL_CONCAT + +#define REGISTER(type) \ + REGISTER_KERNEL_BUILDER(Name("_ParallelConcatUpdate") \ + .Device(DEVICE_SYCL) \ + .TypeConstraint<type>("T"), \ + ParallelConcatUpdate<SyclDevice>); +TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER) +#undef REGISTER + +// Register versions that operate on int32 data on the CPU even though the op +// has been placed on the SYCL + +REGISTER_KERNEL_BUILDER(Name("_ParallelConcatUpdate") + .Device(DEVICE_SYCL) + .HostMemory("value") + .HostMemory("update") + .HostMemory("output") + .TypeConstraint<int32>("T"), + ParallelConcatUpdate<CPUDevice>); +#endif // TENSORFLOW_USE_SYCL + #if GOOGLE_CUDA typedef Eigen::GpuDevice GPUDevice;