aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/scatter_nd_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/scatter_nd_op.cc')
-rw-r--r--tensorflow/core/kernels/scatter_nd_op.cc10
1 files changed, 8 insertions, 2 deletions
diff --git a/tensorflow/core/kernels/scatter_nd_op.cc b/tensorflow/core/kernels/scatter_nd_op.cc
index e1fc2ea128..e0194605ce 100644
--- a/tensorflow/core/kernels/scatter_nd_op.cc
+++ b/tensorflow/core/kernels/scatter_nd_op.cc
@@ -277,6 +277,9 @@ TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_ADD_SUB_CPU);
TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_UPDATE_CPU);
TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_CPU);
TF_CALL_string(REGISTER_SCATTER_ND_CPU);
+TF_CALL_bool(REGISTER_SCATTER_ND_ADD_SUB_CPU);
+TF_CALL_bool(REGISTER_SCATTER_ND_UPDATE_CPU);
+TF_CALL_bool(REGISTER_SCATTER_ND_CPU);
// Registers GPU kernels.
#if GOOGLE_CUDA
@@ -309,6 +312,7 @@ TF_CALL_complex128(REGISTER_SCATTER_ND_ALL_GPU);
TF_CALL_int32(REGISTER_SCATTER_ND_ADD_SUB_SYCL);
TF_CALL_int32(REGISTER_SCATTER_ND_UPDATE_SYCL);
+TF_CALL_bool(REGISTER_SCATTER_ND_UPDATE_SYCL);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_ADD_SUB_SYCL);
TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_UPDATE_SYCL);
#undef REGISTER_SCATTER_ND_ADD_SUB_SYCL
@@ -537,11 +541,13 @@ Status DoScatterNd(OpKernelContext* c, const Tensor& indices,
}
}
if (bad_i >= 0) {
+ auto slice_shape = indices.shape();
+ slice_shape.RemoveLastDims(1);
return errors::InvalidArgument(
- "Invalid indices: ", SliceDebugString(indices.shape(), bad_i), " = [",
+ "indices", SliceDebugString(slice_shape, bad_i), " = [",
str_util::Join(
gtl::ArraySlice<Index>(&indices_flat(bad_i, 0), slice_dim), ", "),
- "] does not index into ", shape.DebugString());
+ "] does not index into shape ", shape.DebugString());
}
return Status::OK();
}