aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/unique_op.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/unique_op.cc')
-rw-r--r--tensorflow/core/kernels/unique_op.cc14
1 files changed, 12 insertions, 2 deletions
diff --git a/tensorflow/core/kernels/unique_op.cc b/tensorflow/core/kernels/unique_op.cc
index e64b27b572..0ef8724b10 100644
--- a/tensorflow/core/kernels/unique_op.cc
+++ b/tensorflow/core/kernels/unique_op.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/hash/hash.h"
@@ -63,8 +64,17 @@ class UniqueOp : public OpKernel {
OP_REQUIRES(context, TensorShapeUtils::IsVector(input.shape()),
errors::InvalidArgument("unique expects a 1D vector."));
} else {
- auto axis_vec = axis_tensor.vec<int64>();
- axis = axis_vec(0);
+ OP_REQUIRES(context,
+ (axis_tensor.dtype() == DT_INT32 ||
+ axis_tensor.dtype() == DT_INT64),
+ errors::InvalidArgument(
+ "axis tensor should be int32 or int64, but got ",
+ axis_tensor.dtype()));
+ if (axis_tensor.dtype() == DT_INT32) {
+ axis = internal::SubtleMustCopy(axis_tensor.scalar<int32>()());
+ } else {
+ axis = internal::SubtleMustCopy(axis_tensor.scalar<int64>()());
+ }
axis = axis < 0 ? axis + input.dims() : axis;
OP_REQUIRES(context, 0 <= axis && axis < input.dims(),
errors::InvalidArgument("axis has to be between [0, ",