aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops/array_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/ops/array_ops.cc')
-rw-r--r--tensorflow/core/ops/array_ops.cc44
1 files changed, 34 insertions, 10 deletions
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index 14b87f0edf..c5935141f8 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -739,7 +739,7 @@ REGISTER_OP("Diag")
.Attr("T: {float, double, int32, int64, complex64, complex128}")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle in = c->input(0);
- TF_RETURN_IF_ERROR(c->WithRankAtMost(in, 3, &in));
+ TF_RETURN_IF_ERROR(c->WithRankAtLeast(in, 1, &in));
// Output shape is original concatenated with itself.
ShapeHandle out;
TF_RETURN_IF_ERROR(c->Concatenate(in, in, &out));
@@ -767,7 +767,7 @@ tf.diag(diagonal) ==> [[1, 0, 0, 0]
[0, 0, 0, 4]]
```
-diagonal: Rank k tensor where k is at most 3.
+diagonal: Rank k tensor where k is at most 1.
)doc");
// --------------------------------------------------------------------------
@@ -783,9 +783,9 @@ REGISTER_OP("DiagPart")
}
// Rank must be even, and result will have rank <rank/2>.
const int32 rank = c->Rank(in);
- if ((rank % 2) != 0 || rank > 6) {
+ if ((rank % 2) != 0 || rank <= 0) {
return errors::InvalidArgument(
- "Input must have even rank <= 6, input rank is ", rank);
+ "Input must have even and non-zero rank, input rank is ", rank);
}
const int32 mid = rank / 2;
@@ -820,7 +820,7 @@ For example:
tf.diag_part(input) ==> [1, 2, 3, 4]
```
-input: Rank k tensor where k is 2, 4, or 6.
+input: Rank k tensor where k is even and not zero.
diagonal: The extracted diagonal.
)doc");
@@ -1175,7 +1175,7 @@ For example:
# [20, 21, 22, 23]]]]
# tensor 't' shape is [1, 2, 3, 4]
-# 'dims' is [3] or 'dims' is -1
+# 'dims' is [3] or 'dims' is [-1]
reverse(t, dims) ==> [[[[ 3, 2, 1, 0],
[ 7, 6, 5, 4],
[ 11, 10, 9, 8]],
@@ -2283,6 +2283,8 @@ size(t) ==> 12
namespace {
+// This SliceHelper processes the output shape of the `slice`
+// when the tensor of `sizes` is available.
template <typename T>
Status SliceHelper(InferenceContext* c, ShapeHandle begin_value,
const Tensor* sizes_value,
@@ -2308,7 +2310,6 @@ Status SliceHelper(InferenceContext* c, ShapeHandle begin_value,
return Status::OK();
}
-
} // namespace
// --------------------------------------------------------------------------
@@ -2339,9 +2340,10 @@ REGISTER_OP("Slice")
ShapeHandle begin_value;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &begin_value));
- // NOTE(mrry): We can't use `MakeShapeFromShapeTensor` for `sizes` because
- // it might contain -1, which can't be represented (-1 in the ShapeHandle
- // would mean "unknown".
+ // We check the tensor value here and will only use
+ // `MakeShapeFromShapeTensor` when `sizes_value` is null.
+ // The reason is that `sizes`might contain -1, which can't
+ // be represented (-1 in the ShapeHandle would mean "unknown".
const Tensor* sizes_value = c->input_tensor(2);
if (sizes_value != nullptr) {
@@ -2361,6 +2363,28 @@ REGISTER_OP("Slice")
c->set_output(0, c->MakeShape(dims));
return Status::OK();
} else {
+ // In case `sizes` is not available (`sizes_value` is null),
+ // we could try to use `MakeShapeFromShapeTensor` here.
+ // If sizes contain -1, we will simply consider it as `Unknown`.
+ // This is less than ideal but still an improvement of shape inference.
+ // The following is an example that returns [None, 1, None] with this
+ // code path:
+ // z = tf.zeros((1, 2, 3))
+ // m = tf.slice(z, [0, 0, 0], [tf.constant(1) + 0, 1, -1])
+ // m.get_shape().as_list()
+ ShapeHandle sizes_value;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(2, &sizes_value));
+ if (c->RankKnown(sizes_value)) {
+ TF_RETURN_IF_ERROR(
+ c->WithRank(begin_value, c->Rank(sizes_value), &begin_value));
+ std::vector<DimensionHandle> dims;
+ for (int i = 0; i < c->Rank(sizes_value); ++i) {
+ dims.emplace_back(c->Dim(sizes_value, i));
+ }
+ c->set_output(0, c->MakeShape(dims));
+ return Status::OK();
+ }
+
// We might know the rank of the input.
if (c->RankKnown(input)) {
c->set_output(0, c->UnknownShapeOfRank(c->Rank(input)));