diff options
Diffstat (limited to 'tensorflow/core/ops/array_ops.cc')
-rw-r--r-- | tensorflow/core/ops/array_ops.cc | 26 |
1 files changed, 25 insertions, 1 deletions
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index 88d2aa3f41..111670c361 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -794,11 +794,35 @@ REGISTER_OP("ReverseV2") ShapeHandle input = c->input(0); ShapeHandle axis; TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &axis)); - // TODO(aselle): if input(0)'s dimension is known we could validate axis if (c->Rank(input) > 8) { return errors::InvalidArgument( "reverse does not work on tensors with more than 8 dimensions"); } + const Tensor* axis_tensor = c->input_tensor(1); + if (axis_tensor != nullptr && c->RankKnown(input)) { + int32 rank = c->Rank(input); + std::vector<int64> axis_value; + if (axis_tensor->dtype() == DT_INT32) { + axis_value = AsInt64<int32>(axis_tensor, axis_tensor->NumElements()); + } else { + axis_value = AsInt64<int64>(axis_tensor, axis_tensor->NumElements()); + } + std::vector<bool> axes_dense(c->Rank(input), false); + for (int i = 0; i < axis_value.size(); i++) { + int64 canonical_axis = + axis_value[i] < 0 ? rank + axis_value[i] : axis_value[i]; + if (canonical_axis < 0 || canonical_axis >= rank) { + return errors::InvalidArgument("'axis'[", i, "] = ", axis_value[i], + " is out of valid range [", 0, ", ", + rank - 1); + } + if (axes_dense[canonical_axis]) { + return errors::InvalidArgument("axis ", canonical_axis, + " specified more than once."); + } + axes_dense[canonical_axis] = true; + } + } c->set_output(0, input); return Status::OK(); }); |