aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/dynamic_stitch_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/dynamic_stitch_op.cc')
-rw-r--r--tensorflow/core/kernels/dynamic_stitch_op.cc27
1 files changed, 13 insertions, 14 deletions
diff --git a/tensorflow/core/kernels/dynamic_stitch_op.cc b/tensorflow/core/kernels/dynamic_stitch_op.cc
index 08ae787c86..135d635514 100644
--- a/tensorflow/core/kernels/dynamic_stitch_op.cc
+++ b/tensorflow/core/kernels/dynamic_stitch_op.cc
@@ -165,20 +165,6 @@ class DynamicStitchOp : public OpKernel {
TF_CALL_POD_STRING_TYPES(REGISTER_DYNAMIC_STITCH);
#undef REGISTER_DYNAMIC_STITCH
-#ifdef TENSORFLOW_USE_SYCL
-#define REGISTER_DYNAMIC_STITCH_SYCL(type) \
- REGISTER_KERNEL_BUILDER(Name("DynamicStitch") \
- .Device(DEVICE_SYCL) \
- .TypeConstraint<type>("T") \
- .HostMemory("indices") \
- .HostMemory("data") \
- .HostMemory("merged"), \
- DynamicStitchOp<type>)
-
-TF_CALL_ALL_TYPES(REGISTER_DYNAMIC_STITCH_SYCL);
-#undef REGISTER_DYNAMIC_STITCH_SYCL
-#endif // TENSORFLOW_USE_SYCL
-
#if GOOGLE_CUDA
#define REGISTER_DYNAMIC_STITCH_GPU(type) \
REGISTER_KERNEL_BUILDER(Name("DynamicStitch") \
@@ -194,4 +180,17 @@ TF_CALL_POD_STRING_TYPES(REGISTER_DYNAMIC_STITCH_GPU);
#endif // GOOGLE_CUDA
+#ifdef TENSORFLOW_USE_SYCL
+#define REGISTER_DYNAMIC_STITCH_SYCL(type) \
+ REGISTER_KERNEL_BUILDER(Name("DynamicStitch") \
+ .Device(DEVICE_SYCL) \
+ .TypeConstraint<type>("T") \
+ .HostMemory("indices") \
+ .HostMemory("data") \
+ .HostMemory("merged"), \
+ DynamicStitchOp<type>)
+
+TF_CALL_POD_STRING_TYPES(REGISTER_DYNAMIC_STITCH_SYCL);
+#undef REGISTER_DYNAMIC_STITCH_SYCL
+#endif // TENSORFLOW_USE_SYCL
} // namespace tensorflow