aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/argmax_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/argmax_op.cc')
-rw-r--r--tensorflow/core/kernels/argmax_op.cc16
1 files changed, 9 insertions, 7 deletions
diff --git a/tensorflow/core/kernels/argmax_op.cc b/tensorflow/core/kernels/argmax_op.cc
index d78f6a7ff1..071d0e684a 100644
--- a/tensorflow/core/kernels/argmax_op.cc
+++ b/tensorflow/core/kernels/argmax_op.cc
@@ -57,19 +57,21 @@ class ArgOp : public OpKernel {
const int32 dim = internal::SubtleMustCopy(dimension.scalar<int32>()());
const int input_dims = input.dims();
- OP_REQUIRES(context, dim >= 0, errors::InvalidArgument("dim must be >= 0"));
- OP_REQUIRES(context, dim < input_dims,
- errors::InvalidArgument("Minimum tensor rank: ", dim + 1,
- " but got: ", input_dims));
+ int axis = dim < 0 ? dim + input_dims : dim;
+
+ OP_REQUIRES(context, axis >= 0 && axis < input_dims,
+ errors::InvalidArgument(
+ "Expected dimension in the range [", -input_dims, ", ",
+ input_dims, "), but got ", dim));
OP_REQUIRES(
- context, input.dim_size(dim) > 0,
+ context, input.dim_size(axis) > 0,
errors::InvalidArgument("Reduction axis ", dim, " is empty in shape ",
input.shape().DebugString()));
TensorShape output_shape;
const TensorShape& input_shape = input.shape();
for (int d = 0; d < input_dims - 1; ++d) {
- output_shape.AddDim(input_shape.dim_size((d < dim) ? d : d + 1));
+ output_shape.AddDim(input_shape.dim_size((d < axis) ? d : d + 1));
}
Tensor* output = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output));
@@ -77,7 +79,7 @@ class ArgOp : public OpKernel {
#define HANDLE_DIM(NDIM) \
case NDIM: \
ArgFunctor::Reduce##NDIM(context->eigen_device<Device>(), \
- input.tensor<T, NDIM>(), dim, \
+ input.tensor<T, NDIM>(), axis, \
output->tensor<int64, NDIM - 1>()); \
break;