diff options
Diffstat (limited to 'tensorflow/core/kernels/unique_op.cc')
-rw-r--r-- | tensorflow/core/kernels/unique_op.cc | 14 |
1 files changed, 12 insertions, 2 deletions
diff --git a/tensorflow/core/kernels/unique_op.cc b/tensorflow/core/kernels/unique_op.cc index e64b27b572..0ef8724b10 100644 --- a/tensorflow/core/kernels/unique_op.cc +++ b/tensorflow/core/kernels/unique_op.cc @@ -21,6 +21,7 @@ limitations under the License. #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/hash/hash.h" @@ -63,8 +64,17 @@ class UniqueOp : public OpKernel { OP_REQUIRES(context, TensorShapeUtils::IsVector(input.shape()), errors::InvalidArgument("unique expects a 1D vector.")); } else { - auto axis_vec = axis_tensor.vec<int64>(); - axis = axis_vec(0); + OP_REQUIRES(context, + (axis_tensor.dtype() == DT_INT32 || + axis_tensor.dtype() == DT_INT64), + errors::InvalidArgument( + "axis tensor should be int32 or int64, but got ", + axis_tensor.dtype())); + if (axis_tensor.dtype() == DT_INT32) { + axis = internal::SubtleMustCopy(axis_tensor.scalar<int32>()()); + } else { + axis = internal::SubtleMustCopy(axis_tensor.scalar<int64>()()); + } axis = axis < 0 ? axis + input.dims() : axis; OP_REQUIRES(context, 0 <= axis && axis < input.dims(), errors::InvalidArgument("axis has to be between [0, ", |