aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops/array_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/ops/array_ops.cc')
-rw-r--r--tensorflow/core/ops/array_ops.cc26
1 files changed, 25 insertions, 1 deletions
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index 88d2aa3f41..111670c361 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -794,11 +794,35 @@ REGISTER_OP("ReverseV2")
ShapeHandle input = c->input(0);
ShapeHandle axis;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &axis));
- // TODO(aselle): if input(0)'s dimension is known we could validate axis
if (c->Rank(input) > 8) {
return errors::InvalidArgument(
"reverse does not work on tensors with more than 8 dimensions");
}
+ const Tensor* axis_tensor = c->input_tensor(1);
+ if (axis_tensor != nullptr && c->RankKnown(input)) {
+ int32 rank = c->Rank(input);
+ std::vector<int64> axis_value;
+ if (axis_tensor->dtype() == DT_INT32) {
+ axis_value = AsInt64<int32>(axis_tensor, axis_tensor->NumElements());
+ } else {
+ axis_value = AsInt64<int64>(axis_tensor, axis_tensor->NumElements());
+ }
+ std::vector<bool> axes_dense(c->Rank(input), false);
+ for (int i = 0; i < axis_value.size(); i++) {
+ int64 canonical_axis =
+ axis_value[i] < 0 ? rank + axis_value[i] : axis_value[i];
+ if (canonical_axis < 0 || canonical_axis >= rank) {
+ return errors::InvalidArgument("'axis'[", i, "] = ", axis_value[i],
+ " is out of valid range [", 0, ", ",
+ rank - 1);
+ }
+ if (axes_dense[canonical_axis]) {
+ return errors::InvalidArgument("axis ", canonical_axis,
+ " specified more than once.");
+ }
+ axes_dense[canonical_axis] = true;
+ }
+ }
c->set_output(0, input);
return Status::OK();
});