diff options
Diffstat (limited to 'tensorflow/core/kernels/control_flow_ops.cc')
-rw-r--r-- | tensorflow/core/kernels/control_flow_ops.cc | 26 |
1 files changed, 18 insertions, 8 deletions
diff --git a/tensorflow/core/kernels/control_flow_ops.cc b/tensorflow/core/kernels/control_flow_ops.cc index e92b11efc6..b01263f288 100644 --- a/tensorflow/core/kernels/control_flow_ops.cc +++ b/tensorflow/core/kernels/control_flow_ops.cc @@ -113,9 +113,12 @@ REGISTER_GPU_HOST_REF_KERNEL(string); #undef REGISTER_GPU_HOST_REF_KERNEL #if TENSORFLOW_USE_SYCL -#define REGISTER_SYCL_KERNEL(type) \ - REGISTER_KERNEL_BUILDER( \ - Name("Switch").Device(DEVICE_SYCL).TypeConstraint<type>("T"), SwitchOp) +#define REGISTER_SYCL_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("Switch") \ + .Device(DEVICE_SYCL) \ + .TypeConstraint<type>("T") \ + .HostMemory("pred"), \ + SwitchOp) REGISTER_SYCL_KERNEL(bool); TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL); #undef REGISTER_SYCL_KERNEL @@ -219,9 +222,12 @@ REGISTER_GPU_REF_KERNEL(bool); #undef REGISTER_GPU_REF_KERNEL #if TENSORFLOW_USE_SYCL -#define REGISTER_SYCL_KERNEL(type) \ - REGISTER_KERNEL_BUILDER( \ - Name("Merge").Device(DEVICE_SYCL).TypeConstraint<type>("T"), MergeOp) +#define REGISTER_SYCL_KERNEL(type) \ + REGISTER_KERNEL_BUILDER(Name("Merge") \ + .Device(DEVICE_SYCL) \ + .TypeConstraint<type>("T") \ + .HostMemory("value_index"), \ + MergeOp) REGISTER_SYCL_KERNEL(bool); TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL); #undef REGISTER_SYCL_KERNEL @@ -418,8 +424,12 @@ REGISTER_GPU_HOST_KERNEL(string); #if TENSORFLOW_USE_SYCL #define REGISTER_SYCL_KERNEL(type) \ - REGISTER_KERNEL_BUILDER( \ - Name("NextIteration").Device(DEVICE_SYCL).TypeConstraint<type>("T"), NextIterationOp) + REGISTER_KERNEL_BUILDER(Name("NextIteration") \ + .Device(DEVICE_SYCL) \ + .HostMemory("data") \ + .HostMemory("output") \ + .TypeConstraint<type>("T"), \ + NextIterationOp) REGISTER_SYCL_KERNEL(bool); TF_CALL_NUMBER_TYPES_NO_INT32(REGISTER_SYCL_KERNEL); #undef REGISTER_SYCL_KERNEL |