aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/map_stage_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/map_stage_op.cc')
-rw-r--r--tensorflow/core/kernels/map_stage_op.cc15
1 files changed, 10 insertions, 5 deletions
diff --git a/tensorflow/core/kernels/map_stage_op.cc b/tensorflow/core/kernels/map_stage_op.cc
index 46eaf3d9e7..0168b57d35 100644
--- a/tensorflow/core/kernels/map_stage_op.cc
+++ b/tensorflow/core/kernels/map_stage_op.cc
@@ -550,12 +550,17 @@ REGISTER_KERNEL_BUILDER(Name("OrderedMapStage")
#endif // GOOGLE_CUDA
#ifdef TENSORFLOW_USE_SYCL
-REGISTER_KERNEL_BUILDER(Name("MapStage").HostMemory("key").Device(DEVICE_SYCL),
+REGISTER_KERNEL_BUILDER(Name("MapStage")
+ .HostMemory("key")
+ .HostMemory("indices")
+ .Device(DEVICE_SYCL),
MapStageOp<false>);
-REGISTER_KERNEL_BUILDER(
- Name("OrderedMapStage").HostMemory("key").Device(DEVICE_SYCL),
- MapStageOp<true>);
-#endif // TENSORFLOW_USE_SYCL
+REGISTER_KERNEL_BUILDER(Name("OrderedMapStage")
+ .HostMemory("key")
+ .HostMemory("indices")
+ .Device(DEVICE_SYCL),
+ MapStageOp<true>);
+#endif // TENSORFLOW_USE_SYCL
template <bool Ordered>
class MapUnstageOp : public OpKernel {