aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/inplace_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/inplace_ops.cc')
-rw-r--r--  tensorflow/core/kernels/inplace_ops.cc  62
1 files changed, 59 insertions, 3 deletions
diff --git a/tensorflow/core/kernels/inplace_ops.cc b/tensorflow/core/kernels/inplace_ops.cc
index 4433b9eea9..67bec7d50e 100644
--- a/tensorflow/core/kernels/inplace_ops.cc
+++ b/tensorflow/core/kernels/inplace_ops.cc
@@ -25,11 +25,14 @@ limitations under the License.
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SyclDevice;
+#endif // TENSORFLOW_USE_SYCL
namespace functor {
-template <typename T>
-Status DoParallelConcatUpdate(const CPUDevice& d, const Tensor& value,
+template <typename Device, typename T>
+Status DoParallelConcatUpdate(const Device& d, const Tensor& value,
int32 loc, Tensor* output) {
auto Tvalue = value.flat_outer_dims<T>();
auto Toutput = output->flat_outer_dims<T>();
@@ -46,7 +49,7 @@ Status DoParallelConcat(const CPUDevice& d, const Tensor& value, int32 loc,
switch (value.dtype()) {
#define CASE(type) \
case DataTypeToEnum<type>::value: \
- return DoParallelConcatUpdate<type>(d, value, loc, output);
+ return DoParallelConcatUpdate<CPUDevice, type>(d, value, loc, output);
TF_CALL_NUMBER_TYPES(CASE);
TF_CALL_string(CASE);
#undef CASE
@@ -55,6 +58,23 @@ Status DoParallelConcat(const CPUDevice& d, const Tensor& value, int32 loc,
}
}
+#ifdef TENSORFLOW_USE_SYCL
+template <>
+Status DoParallelConcat(const SyclDevice& d, const Tensor& value, int32 loc,
+ Tensor* output) {
+ CHECK_EQ(value.dtype(), output->dtype());
+ switch (value.dtype()) {
+#define CASE(type) \
+ case DataTypeToEnum<type>::value: \
+ return DoParallelConcatUpdate<SyclDevice, type>(d, value, loc, output);
+ TF_CALL_GPU_NUMBER_TYPES_NO_HALF(CASE);
+#undef CASE
+ default:
+ return errors::InvalidArgument("Unsupported data type: ", value.dtype());
+ }
+}
+#endif // TENSORFLOW_USE_SYCL
+
} // end namespace functor
namespace {
@@ -152,6 +172,42 @@ TF_CALL_POD_STRING_TYPES(REGISTER_EMPTY)
TF_CALL_POD_STRING_TYPES(REGISTER_PARALLEL_CONCAT);
#undef REGISTER_PARALLEL_CONCAT
+#ifdef TENSORFLOW_USE_SYCL
+#define REGISTER_EMPTY(type) \
+ REGISTER_KERNEL_BUILDER(Name("_ParallelConcatStart") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<type>("dtype"), \
+ ParallelConcatStart<SyclDevice, type>);
+TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_EMPTY)
+#undef REGISTER_EMPTY
+
+#define REGISTER_PARALLEL_CONCAT(type) \
+ REGISTER_KERNEL_BUILDER( \
+ Name("ParallelConcat").Device(DEVICE_SYCL).TypeConstraint<type>("T"), \
+ FailureKernel);
+TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_PARALLEL_CONCAT);
+#undef REGISTER_PARALLEL_CONCAT
+
+#define REGISTER(type) \
+ REGISTER_KERNEL_BUILDER(Name("_ParallelConcatUpdate") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<type>("T"), \
+ ParallelConcatUpdate<SyclDevice>);
+TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER)
+#undef REGISTER
+
+// Register versions that operate on int32 data on the CPU even though the op
+// has been placed on the SYCL device.
+
+REGISTER_KERNEL_BUILDER(Name("_ParallelConcatUpdate")
+ .Device(DEVICE_SYCL)
+ .HostMemory("value")
+ .HostMemory("update")
+ .HostMemory("output")
+ .TypeConstraint<int32>("T"),
+ ParallelConcatUpdate<CPUDevice>);
+#endif // TENSORFLOW_USE_SYCL
+
#if GOOGLE_CUDA
typedef Eigen::GpuDevice GPUDevice;