aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops/array_ops.cc
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <bsteiner@google.com>2018-03-27 12:09:59 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-27 12:12:24 -0700
commit5da1cdcf0032f63c22afb41a460fd44c52ada048 (patch)
tree3a4b1c8224191cb5bf4f9f08b8ed8f5f07a768a0 /tensorflow/core/ops/array_ops.cc
parentfd77211de17bf053cc8f5a82c8eff1818451120c (diff)
Improved shape inference for reshape
PiperOrigin-RevId: 190651873
Diffstat (limited to 'tensorflow/core/ops/array_ops.cc')
-rw-r--r--tensorflow/core/ops/array_ops.cc104
1 files changed, 73 insertions, 31 deletions
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index 39b92464cb..88d2aa3f41 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -178,46 +178,88 @@ Status SetOutputShapeForReshape(InferenceContext* c) {
c->set_output(0, out);
return Status::OK();
}
- DimensionHandle num_in_elems = c->NumElements(in);
- if (c->FullyDefined(out)) {
- DimensionHandle num_out_elems = c->NumElements(out);
- if (c->ValueKnown(num_in_elems) &&
- c->Value(num_in_elems) != c->Value(num_out_elems)) {
- return errors::InvalidArgument(
- "Cannot reshape a tensor with ", c->DebugString(num_in_elems),
- " elements to shape ", c->DebugString(out), " (",
- c->DebugString(num_out_elems), " elements)");
- }
- c->set_output(0, out);
- return Status::OK();
- }
- if (c->ValueKnown(num_in_elems)) {
+ if (c->RankKnown(out) && c->RankKnown(in)) {
// We don't know the number of output elements, but we can try to infer
// the missing dimension.
- int32 unknown_idx = -1;
bool too_many_unknown = false;
- DimensionHandle known_elems = c->MakeDim(1);
- for (int32 i = 0; i < c->Rank(out); ++i) {
- DimensionHandle dim = c->Dim(out, i);
- if (!c->ValueKnown(dim)) {
- if (unknown_idx >= 0) {
- too_many_unknown = true;
- break;
+ int32 out_unknown_idx = -1;
+
+ DimensionHandle known_out_elems = c->NumElements(out);
+ if (!c->ValueKnown(known_out_elems)) {
+ known_out_elems = c->MakeDim(1);
+ for (int32 i = 0; i < c->Rank(out); ++i) {
+ DimensionHandle dim = c->Dim(out, i);
+ if (!c->ValueKnown(dim)) {
+ if (out_unknown_idx >= 0) {
+ too_many_unknown = true;
+ break;
+ }
+ out_unknown_idx = i;
+ } else {
+ TF_RETURN_IF_ERROR(
+ c->Multiply(known_out_elems, dim, &known_out_elems));
}
- unknown_idx = i;
- } else {
- TF_RETURN_IF_ERROR(c->Multiply(known_elems, dim, &known_elems));
}
}
- if (!too_many_unknown && c->Value(known_elems) != 0) {
- DimensionHandle inferred_dim;
- TF_RETURN_IF_ERROR(c->Divide(num_in_elems, c->Value(known_elems),
- true /* evenly_divisible */, &inferred_dim));
- TF_RETURN_IF_ERROR(c->ReplaceDim(out, unknown_idx, inferred_dim, &out));
+ int32 in_unknown_idx = -1;
+ DimensionHandle known_in_elems = c->NumElements(in);
+ if (!c->ValueKnown(known_in_elems)) {
+ known_in_elems = c->MakeDim(1);
+ for (int32 i = 0; i < c->Rank(in); ++i) {
+ DimensionHandle dim = c->Dim(in, i);
+ if (!c->ValueKnown(dim)) {
+ if (in_unknown_idx >= 0) {
+ too_many_unknown = true;
+ break;
+ }
+ in_unknown_idx = i;
+ } else {
+ TF_RETURN_IF_ERROR(c->Multiply(known_in_elems, dim, &known_in_elems));
+ }
+ }
}
- }
+ if (!too_many_unknown) {
+ if (in_unknown_idx < 0 && out_unknown_idx < 0) {
+ // Just check that the dimensions match.
+ if (c->Value(known_in_elems) != c->Value(known_out_elems)) {
+ return errors::InvalidArgument(
+ "Cannot reshape a tensor with ", c->DebugString(known_in_elems),
+ " elements to shape ", c->DebugString(out), " (",
+ c->DebugString(known_out_elems), " elements)");
+ }
+ } else if (in_unknown_idx < 0 && out_unknown_idx >= 0 &&
+ c->Value(known_out_elems) > 0) {
+ // Input fully known, infer the one missing output dim
+ DimensionHandle inferred_dim;
+ TF_RETURN_IF_ERROR(c->Divide(known_in_elems, c->Value(known_out_elems),
+ true /* evenly_divisible */,
+ &inferred_dim));
+ TF_RETURN_IF_ERROR(
+ c->ReplaceDim(out, out_unknown_idx, inferred_dim, &out));
+
+ } else if (in_unknown_idx >= 0 && out_unknown_idx < 0 &&
+ c->Value(known_in_elems) != 0) {
+ // Output fully known, infer the one missing input dim
+ DimensionHandle inferred_dim;
+ TF_RETURN_IF_ERROR(c->Divide(known_out_elems, c->Value(known_in_elems),
+ true /* evenly_divisible */,
+ &inferred_dim));
+ DimensionHandle unknown_in_dim = c->Dim(in, in_unknown_idx);
+ TF_RETURN_IF_ERROR(
+ c->Merge(unknown_in_dim, inferred_dim, &unknown_in_dim));
+ } else if (in_unknown_idx >= 0 && out_unknown_idx >= 0) {
+ // Exactly one unknown dimension in both input and output. These 2 are
+ // equal iff the known elements are equal.
+ if (c->Value(known_in_elems) == c->Value(known_out_elems)) {
+ DimensionHandle unknown_in_dim = c->Dim(in, in_unknown_idx);
+ TF_RETURN_IF_ERROR(
+ c->ReplaceDim(out, out_unknown_idx, unknown_in_dim, &out));
+ }
+ }
+ }
+ }
c->set_output(0, out);
return Status::OK();
}