diff options
author | 2018-03-27 12:09:59 -0700 | |
---|---|---|
committer | 2018-03-27 12:12:24 -0700 | |
commit | 5da1cdcf0032f63c22afb41a460fd44c52ada048 (patch) | |
tree | 3a4b1c8224191cb5bf4f9f08b8ed8f5f07a768a0 /tensorflow/core/ops/array_ops.cc | |
parent | fd77211de17bf053cc8f5a82c8eff1818451120c (diff) |
Improved shape inference for reshape
PiperOrigin-RevId: 190651873
Diffstat (limited to 'tensorflow/core/ops/array_ops.cc')
-rw-r--r-- | tensorflow/core/ops/array_ops.cc | 104 |
1 files changed, 73 insertions, 31 deletions
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc index 39b92464cb..88d2aa3f41 100644 --- a/tensorflow/core/ops/array_ops.cc +++ b/tensorflow/core/ops/array_ops.cc @@ -178,46 +178,88 @@ Status SetOutputShapeForReshape(InferenceContext* c) { c->set_output(0, out); return Status::OK(); } - DimensionHandle num_in_elems = c->NumElements(in); - if (c->FullyDefined(out)) { - DimensionHandle num_out_elems = c->NumElements(out); - if (c->ValueKnown(num_in_elems) && - c->Value(num_in_elems) != c->Value(num_out_elems)) { - return errors::InvalidArgument( - "Cannot reshape a tensor with ", c->DebugString(num_in_elems), - " elements to shape ", c->DebugString(out), " (", - c->DebugString(num_out_elems), " elements)"); - } - c->set_output(0, out); - return Status::OK(); - } - if (c->ValueKnown(num_in_elems)) { + if (c->RankKnown(out) && c->RankKnown(in)) { // We don't know the number of output elements, but we can try to infer // the missing dimension. - int32 unknown_idx = -1; bool too_many_unknown = false; - DimensionHandle known_elems = c->MakeDim(1); - for (int32 i = 0; i < c->Rank(out); ++i) { - DimensionHandle dim = c->Dim(out, i); - if (!c->ValueKnown(dim)) { - if (unknown_idx >= 0) { - too_many_unknown = true; - break; + int32 out_unknown_idx = -1; + + DimensionHandle known_out_elems = c->NumElements(out); + if (!c->ValueKnown(known_out_elems)) { + known_out_elems = c->MakeDim(1); + for (int32 i = 0; i < c->Rank(out); ++i) { + DimensionHandle dim = c->Dim(out, i); + if (!c->ValueKnown(dim)) { + if (out_unknown_idx >= 0) { + too_many_unknown = true; + break; + } + out_unknown_idx = i; + } else { + TF_RETURN_IF_ERROR( + c->Multiply(known_out_elems, dim, &known_out_elems)); } - unknown_idx = i; - } else { - TF_RETURN_IF_ERROR(c->Multiply(known_elems, dim, &known_elems)); } } - if (!too_many_unknown && c->Value(known_elems) != 0) { - DimensionHandle inferred_dim; - TF_RETURN_IF_ERROR(c->Divide(num_in_elems, c->Value(known_elems), - true /* evenly_divisible */, &inferred_dim)); - TF_RETURN_IF_ERROR(c->ReplaceDim(out, unknown_idx, inferred_dim, &out)); + int32 in_unknown_idx = -1; + DimensionHandle known_in_elems = c->NumElements(in); + if (!c->ValueKnown(known_in_elems)) { + known_in_elems = c->MakeDim(1); + for (int32 i = 0; i < c->Rank(in); ++i) { + DimensionHandle dim = c->Dim(in, i); + if (!c->ValueKnown(dim)) { + if (in_unknown_idx >= 0) { + too_many_unknown = true; + break; + } + in_unknown_idx = i; + } else { + TF_RETURN_IF_ERROR(c->Multiply(known_in_elems, dim, &known_in_elems)); + } + } } - } + if (!too_many_unknown) { + if (in_unknown_idx < 0 && out_unknown_idx < 0) { + // Just check that the dimensions match. + if (c->Value(known_in_elems) != c->Value(known_out_elems)) { + return errors::InvalidArgument( + "Cannot reshape a tensor with ", c->DebugString(known_in_elems), + " elements to shape ", c->DebugString(out), " (", + c->DebugString(known_out_elems), " elements)"); + } + } else if (in_unknown_idx < 0 && out_unknown_idx >= 0 && + c->Value(known_out_elems) > 0) { + // Input fully known, infer the one missing output dim + DimensionHandle inferred_dim; + TF_RETURN_IF_ERROR(c->Divide(known_in_elems, c->Value(known_out_elems), + true /* evenly_divisible */, + &inferred_dim)); + TF_RETURN_IF_ERROR( + c->ReplaceDim(out, out_unknown_idx, inferred_dim, &out)); + + } else if (in_unknown_idx >= 0 && out_unknown_idx < 0 && + c->Value(known_in_elems) != 0) { + // Output fully known, infer the one missing input dim + DimensionHandle inferred_dim; + TF_RETURN_IF_ERROR(c->Divide(known_out_elems, c->Value(known_in_elems), + true /* evenly_divisible */, + &inferred_dim)); + DimensionHandle unknown_in_dim = c->Dim(in, in_unknown_idx); + TF_RETURN_IF_ERROR( + c->Merge(unknown_in_dim, inferred_dim, &unknown_in_dim)); + } else if (in_unknown_idx >= 0 && out_unknown_idx >= 0) { + // Exactly one unknown dimension in both input and output. These 2 are + // equal iff the known elements are equal. + if (c->Value(known_in_elems) == c->Value(known_out_elems)) { + DimensionHandle unknown_in_dim = c->Dim(in, in_unknown_idx); + TF_RETURN_IF_ERROR( + c->ReplaceDim(out, out_unknown_idx, unknown_in_dim, &out)); + } + } + } + } c->set_output(0, out); return Status::OK(); } |