aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/dynamic_stitch_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/dynamic_stitch_op.cc')
-rw-r--r--tensorflow/core/kernels/dynamic_stitch_op.cc14
1 files changed, 14 insertions, 0 deletions
diff --git a/tensorflow/core/kernels/dynamic_stitch_op.cc b/tensorflow/core/kernels/dynamic_stitch_op.cc
index ae883ea535..bff1914682 100644
--- a/tensorflow/core/kernels/dynamic_stitch_op.cc
+++ b/tensorflow/core/kernels/dynamic_stitch_op.cc
@@ -165,6 +165,20 @@ class DynamicStitchOp : public OpKernel {
TF_CALL_POD_STRING_TYPES(REGISTER_DYNAMIC_STITCH);
#undef REGISTER_DYNAMIC_STITCH
+#ifdef TENSORFLOW_USE_SYCL
+#define REGISTER_DYNAMIC_STITCH_SYCL(type) \
+ REGISTER_KERNEL_BUILDER(Name("DynamicStitch") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<type>("T") \
+ .HostMemory("indices") \
+ .HostMemory("data") \
+ .HostMemory("merged"), \
+ DynamicStitchOp<type>)
+
+TF_CALL_ALL_TYPES(REGISTER_DYNAMIC_STITCH_SYCL);
+#undef REGISTER_DYNAMIC_STITCH_SYCL
+#endif // TENSORFLOW_USE_SYCL
+
#if GOOGLE_CUDA
#define REGISTER_DYNAMIC_STITCH_GPU(type) \
REGISTER_KERNEL_BUILDER(Name("DynamicStitch") \