diff options
Diffstat (limited to 'tensorflow/core/kernels/dynamic_stitch_op.cc')
-rw-r--r-- | tensorflow/core/kernels/dynamic_stitch_op.cc | 27 |
1 files changed, 13 insertions, 14 deletions
diff --git a/tensorflow/core/kernels/dynamic_stitch_op.cc b/tensorflow/core/kernels/dynamic_stitch_op.cc index 08ae787c86..135d635514 100644 --- a/tensorflow/core/kernels/dynamic_stitch_op.cc +++ b/tensorflow/core/kernels/dynamic_stitch_op.cc @@ -165,20 +165,6 @@ class DynamicStitchOp : public OpKernel { TF_CALL_POD_STRING_TYPES(REGISTER_DYNAMIC_STITCH); #undef REGISTER_DYNAMIC_STITCH -#ifdef TENSORFLOW_USE_SYCL -#define REGISTER_DYNAMIC_STITCH_SYCL(type) \ - REGISTER_KERNEL_BUILDER(Name("DynamicStitch") \ - .Device(DEVICE_SYCL) \ - .TypeConstraint<type>("T") \ - .HostMemory("indices") \ - .HostMemory("data") \ - .HostMemory("merged"), \ - DynamicStitchOp<type>) - -TF_CALL_ALL_TYPES(REGISTER_DYNAMIC_STITCH_SYCL); -#undef REGISTER_DYNAMIC_STITCH_SYCL -#endif // TENSORFLOW_USE_SYCL - #if GOOGLE_CUDA #define REGISTER_DYNAMIC_STITCH_GPU(type) \ REGISTER_KERNEL_BUILDER(Name("DynamicStitch") \ @@ -194,4 +180,17 @@ TF_CALL_POD_STRING_TYPES(REGISTER_DYNAMIC_STITCH_GPU); #endif // GOOGLE_CUDA +#ifdef TENSORFLOW_USE_SYCL +#define REGISTER_DYNAMIC_STITCH_SYCL(type) \ + REGISTER_KERNEL_BUILDER(Name("DynamicStitch") \ + .Device(DEVICE_SYCL) \ + .TypeConstraint<type>("T") \ + .HostMemory("indices") \ + .HostMemory("data") \ + .HostMemory("merged"), \ + DynamicStitchOp<type>) + +TF_CALL_POD_STRING_TYPES(REGISTER_DYNAMIC_STITCH_SYCL); +#undef REGISTER_DYNAMIC_STITCH_SYCL +#endif // TENSORFLOW_USE_SYCL } // namespace tensorflow |