diff options
Diffstat (limited to 'tensorflow/core/ops/array_ops.cc')
-rw-r--r-- | tensorflow/core/ops/array_ops.cc | 44 |
1 files changed, 34 insertions, 10 deletions
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index 14b87f0edf..c5935141f8 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -739,7 +739,7 @@ REGISTER_OP("Diag") .Attr("T: {float, double, int32, int64, complex64, complex128}") .SetShapeFn([](InferenceContext* c) { ShapeHandle in = c->input(0); - TF_RETURN_IF_ERROR(c->WithRankAtMost(in, 3, &in)); + TF_RETURN_IF_ERROR(c->WithRankAtLeast(in, 1, &in)); // Output shape is original concatenated with itself. ShapeHandle out; TF_RETURN_IF_ERROR(c->Concatenate(in, in, &out)); @@ -767,7 +767,7 @@ tf.diag(diagonal) ==> [[1, 0, 0, 0] [0, 0, 0, 4]] ``` -diagonal: Rank k tensor where k is at most 3. +diagonal: Rank k tensor where k is at most 1. )doc"); // -------------------------------------------------------------------------- @@ -783,9 +783,9 @@ REGISTER_OP("DiagPart") } // Rank must be even, and result will have rank <rank/2>. const int32 rank = c->Rank(in); - if ((rank % 2) != 0 || rank > 6) { + if ((rank % 2) != 0 || rank <= 0) { return errors::InvalidArgument( - "Input must have even rank <= 6, input rank is ", rank); + "Input must have even and non-zero rank, input rank is ", rank); } const int32 mid = rank / 2; @@ -820,7 +820,7 @@ For example: tf.diag_part(input) ==> [1, 2, 3, 4] ``` -input: Rank k tensor where k is 2, 4, or 6. +input: Rank k tensor where k is even and not zero. diagonal: The extracted diagonal. )doc"); @@ -1175,7 +1175,7 @@ For example: # [20, 21, 22, 23]]]] # tensor 't' shape is [1, 2, 3, 4] -# 'dims' is [3] or 'dims' is -1 +# 'dims' is [3] or 'dims' is [-1] reverse(t, dims) ==> [[[[ 3, 2, 1, 0], [ 7, 6, 5, 4], [ 11, 10, 9, 8]], @@ -2283,6 +2283,8 @@ size(t) ==> 12 namespace { +// This SliceHelper processes the output shape of the `slice` +// when the tensor of `sizes` is available. template <typename T> Status SliceHelper(InferenceContext* c, ShapeHandle begin_value, const Tensor* sizes_value, @@ -2308,7 +2310,6 @@ Status SliceHelper(InferenceContext* c, ShapeHandle begin_value, return Status::OK(); } - } // namespace // -------------------------------------------------------------------------- @@ -2339,9 +2340,10 @@ REGISTER_OP("Slice") ShapeHandle begin_value; TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &begin_value)); - // NOTE(mrry): We can't use `MakeShapeFromShapeTensor` for `sizes` because - // it might contain -1, which can't be represented (-1 in the ShapeHandle - // would mean "unknown". + // We check the tensor value here and will only use + // `MakeShapeFromShapeTensor` when `sizes_value` is null. + // The reason is that `sizes`might contain -1, which can't + // be represented (-1 in the ShapeHandle would mean "unknown". const Tensor* sizes_value = c->input_tensor(2); if (sizes_value != nullptr) { @@ -2361,6 +2363,28 @@ REGISTER_OP("Slice") c->set_output(0, c->MakeShape(dims)); return Status::OK(); } else { + // In case `sizes` is not available (`sizes_value` is null), + // we could try to use `MakeShapeFromShapeTensor` here. + // If sizes contain -1, we will simply consider it as `Unknown`. + // This is less than ideal but still an improvement of shape inference. + // The following is an example that returns [None, 1, None] with this + // code path: + // z = tf.zeros((1, 2, 3)) + // m = tf.slice(z, [0, 0, 0], [tf.constant(1) + 0, 1, -1]) + // m.get_shape().as_list() + ShapeHandle sizes_value; + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &sizes_value)); + if (c->RankKnown(sizes_value)) { + TF_RETURN_IF_ERROR( + c->WithRank(begin_value, c->Rank(sizes_value), &begin_value)); + std::vector<DimensionHandle> dims; + for (int i = 0; i < c->Rank(sizes_value); ++i) { + dims.emplace_back(c->Dim(sizes_value, i)); + } + c->set_output(0, c->MakeShape(dims)); + return Status::OK(); + } + // We might know the rank of the input. if (c->RankKnown(input)) { c->set_output(0, c->UnknownShapeOfRank(c->Rank(input))); |