diff options
Diffstat (limited to 'tensorflow/core/kernels/scatter_nd_op.cc')
-rw-r--r-- | tensorflow/core/kernels/scatter_nd_op.cc | 10 |
1 files changed, 8 insertions, 2 deletions
diff --git a/tensorflow/core/kernels/scatter_nd_op.cc b/tensorflow/core/kernels/scatter_nd_op.cc index e1fc2ea128..e0194605ce 100644 --- a/tensorflow/core/kernels/scatter_nd_op.cc +++ b/tensorflow/core/kernels/scatter_nd_op.cc @@ -277,6 +277,9 @@ TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_ADD_SUB_CPU); TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_UPDATE_CPU); TF_CALL_NUMBER_TYPES(REGISTER_SCATTER_ND_CPU); TF_CALL_string(REGISTER_SCATTER_ND_CPU); +TF_CALL_bool(REGISTER_SCATTER_ND_ADD_SUB_CPU); +TF_CALL_bool(REGISTER_SCATTER_ND_UPDATE_CPU); +TF_CALL_bool(REGISTER_SCATTER_ND_CPU); // Registers GPU kernels. #if GOOGLE_CUDA @@ -309,6 +312,7 @@ TF_CALL_complex128(REGISTER_SCATTER_ND_ALL_GPU); TF_CALL_int32(REGISTER_SCATTER_ND_ADD_SUB_SYCL); TF_CALL_int32(REGISTER_SCATTER_ND_UPDATE_SYCL); +TF_CALL_bool(REGISTER_SCATTER_ND_UPDATE_SYCL); TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_ADD_SUB_SYCL); TF_CALL_GPU_NUMBER_TYPES_NO_HALF(REGISTER_SCATTER_ND_UPDATE_SYCL); #undef REGISTER_SCATTER_ND_ADD_SUB_SYCL @@ -537,11 +541,13 @@ Status DoScatterNd(OpKernelContext* c, const Tensor& indices, } } if (bad_i >= 0) { + auto slice_shape = indices.shape(); + slice_shape.RemoveLastDims(1); return errors::InvalidArgument( - "Invalid indices: ", SliceDebugString(indices.shape(), bad_i), " = [", + "indices", SliceDebugString(slice_shape, bad_i), " = [", str_util::Join( gtl::ArraySlice<Index>(&indices_flat(bad_i, 0), slice_dim), ", "), - "] does not index into ", shape.DebugString()); + "] does not index into shape ", shape.DebugString()); } return Status::OK(); } |