diff options
Diffstat (limited to 'tensorflow/core/kernels/argmax_op.cc')
-rw-r--r-- | tensorflow/core/kernels/argmax_op.cc | 16 |
1 files changed, 9 insertions, 7 deletions
diff --git a/tensorflow/core/kernels/argmax_op.cc b/tensorflow/core/kernels/argmax_op.cc index d78f6a7ff1..071d0e684a 100644 --- a/tensorflow/core/kernels/argmax_op.cc +++ b/tensorflow/core/kernels/argmax_op.cc @@ -57,19 +57,21 @@ class ArgOp : public OpKernel { const int32 dim = internal::SubtleMustCopy(dimension.scalar<int32>()()); const int input_dims = input.dims(); - OP_REQUIRES(context, dim >= 0, errors::InvalidArgument("dim must be >= 0")); - OP_REQUIRES(context, dim < input_dims, - errors::InvalidArgument("Minimum tensor rank: ", dim + 1, - " but got: ", input_dims)); + int axis = dim < 0 ? dim + input_dims : dim; + + OP_REQUIRES(context, axis >= 0 && axis < input_dims, + errors::InvalidArgument( + "Expected dimension in the range [", -input_dims, ", ", + input_dims, "), but got ", dim)); OP_REQUIRES( - context, input.dim_size(dim) > 0, + context, input.dim_size(axis) > 0, errors::InvalidArgument("Reduction axis ", dim, " is empty in shape ", input.shape().DebugString())); TensorShape output_shape; const TensorShape& input_shape = input.shape(); for (int d = 0; d < input_dims - 1; ++d) { - output_shape.AddDim(input_shape.dim_size((d < dim) ? d : d + 1)); + output_shape.AddDim(input_shape.dim_size((d < axis) ? d : d + 1)); } Tensor* output = nullptr; OP_REQUIRES_OK(context, context->allocate_output(0, output_shape, &output)); @@ -77,7 +79,7 @@ class ArgOp : public OpKernel { #define HANDLE_DIM(NDIM) \ case NDIM: \ ArgFunctor::Reduce##NDIM(context->eigen_device<Device>(), \ - input.tensor<T, NDIM>(), dim, \ + input.tensor<T, NDIM>(), axis, \ output->tensor<int64, NDIM - 1>()); \ break; |