about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/core/kernels/concat_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/concat_op.cc')
-rw-r--r-- tensorflow/core/kernels/concat_op.cc 50
1 files changed, 50 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/concat_op.cc b/tensorflow/core/kernels/concat_op.cc
index e6dae5fa7e..9628a7efa4 100644
--- a/tensorflow/core/kernels/concat_op.cc
+++ b/tensorflow/core/kernels/concat_op.cc
@@ -35,6 +35,9 @@ typedef Eigen::ThreadPoolDevice CPUDevice;
#if GOOGLE_CUDA
typedef Eigen::GpuDevice GPUDevice;
#endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+typedef Eigen::SyclDevice SYCLDevice;  // Device tag for the SYCL code paths and kernel registrations in this file.
+#endif // TENSORFLOW_USE_SYCL
enum AxisArgumentName { NAME_IS_AXIS, NAME_IS_CONCAT_DIM };
@@ -134,6 +137,12 @@ class ConcatBaseOp : public OpKernel {
return;
}
#endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+ if (std::is_same<Device, SYCLDevice>::value) {
+ ConcatSYCL<T>(c->eigen_sycl_device(), inputs_flat, &output_flat);
+ return;
+ }
+#endif // TENSORFLOW_USE_SYCL
ConcatCPU<T>(c->device(), inputs_flat, &output_flat);
}
}
@@ -207,6 +216,39 @@ REGISTER_KERNEL_BUILDER(Name("ConcatV2")
#endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL  // Concat / ConcatV2 kernel registrations for DEVICE_SYCL.
+#define REGISTER_SYCL(type) \
+ REGISTER_KERNEL_BUILDER(Name("Concat") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<type>("T") \
+ .HostMemory("concat_dim"), \
+ ConcatOp<SYCLDevice, type>) \
+ REGISTER_KERNEL_BUILDER(Name("ConcatV2") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<int32>("Tidx") \
+ .HostMemory("axis"), \
+ ConcatV2Op<SYCLDevice, type>)
+// REGISTER_SYCL(type) registers both Concat and ConcatV2 for `type` with the dimension argument declared HostMemory.
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_SYCL);  // NOTE(review): reuses the GPU type list (macro defined elsewhere) -- confirm every type in it is supported on SYCL.
+REGISTER_KERNEL_BUILDER(Name("Concat")  // int32 on DEVICE_SYCL is bound to the CPUDevice implementation.
+ .Device(DEVICE_SYCL)
+ .TypeConstraint<int32>("T")
+ .HostMemory("concat_dim")
+ .HostMemory("values")  // Unlike the REGISTER_SYCL kernels, the data inputs ...
+ .HostMemory("output"),  // ... and the output are also declared HostMemory.
+ ConcatOp<CPUDevice, int32>);
+REGISTER_KERNEL_BUILDER(Name("ConcatV2")  // Same int32 arrangement for ConcatV2 (dimension argument is "axis").
+ .Device(DEVICE_SYCL)
+ .TypeConstraint<int32>("T")
+ .TypeConstraint<int32>("Tidx")
+ .HostMemory("values")
+ .HostMemory("axis")
+ .HostMemory("output"),
+ ConcatV2Op<CPUDevice, int32>);
+#undef REGISTER_SYCL
+#endif // TENSORFLOW_USE_SYCL
+
class ConcatOffsetOp : public OpKernel {
public:
explicit ConcatOffsetOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
@@ -293,4 +335,12 @@ REGISTER_KERNEL_BUILDER(Name("ConcatOffset")
.HostMemory("offset"),
ConcatOffsetOp);
+#ifdef TENSORFLOW_USE_SYCL  // Also expose ConcatOffset on DEVICE_SYCL.
+REGISTER_KERNEL_BUILDER(Name("ConcatOffset")  // Binds the same device-independent ConcatOffsetOp class.
+ .Device(DEVICE_SYCL)
+ .HostMemory("concat_dim")  // All three arguments named here are declared HostMemory.
+ .HostMemory("shape")
+ .HostMemory("offset"),
+ ConcatOffsetOp);
+#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow