diff options
author | Benoit Steiner <bsteiner@google.com> | 2017-11-07 15:15:17 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-11-07 15:19:00 -0800 |
commit | 8c88be0da425ad686f76e03f21d2947dedea5123 (patch) | |
tree | 0d24133c1a1f99207e9d4c8fcae40bf531fcc3ab /tensorflow/core/framework/shape_inference.cc | |
parent | 1d60b4a67cfe2838565135312bf6d5e9b47d60a3 (diff) |
Track merged shapes and dimensions more accurately.
PiperOrigin-RevId: 174920827
Diffstat (limited to 'tensorflow/core/framework/shape_inference.cc')
-rw-r--r-- | tensorflow/core/framework/shape_inference.cc | 26 |
1 files changed, 23 insertions, 3 deletions
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc index ffa235d15c..5d6bf559bb 100644 --- a/tensorflow/core/framework/shape_inference.cc +++ b/tensorflow/core/framework/shape_inference.cc @@ -418,11 +418,16 @@ void InferenceContext::Relax(DimensionHandle d0, DimensionHandle d1, Status InferenceContext::Merge(DimensionHandle d0, DimensionHandle d1, DimensionHandle* out) { - if (d0.SameHandle(d1) || !ValueKnown(d1)) { + if (d0.SameHandle(d1)) { + *out = d0; + return Status::OK(); + } else if (!ValueKnown(d1)) { *out = d0; + merged_dims_.emplace_back(d0, d1); return Status::OK(); } else if (!ValueKnown(d0)) { *out = d1; + merged_dims_.emplace_back(d0, d1); return Status::OK(); } else if (Value(d0) == Value(d1)) { *out = d0; @@ -502,11 +507,16 @@ void InferenceContext::Relax(ShapeHandle s0, ShapeHandle s1, ShapeHandle* out) { Status InferenceContext::Merge(ShapeHandle s0, ShapeHandle s1, ShapeHandle* out) { - if (s0.SameHandle(s1) || !RankKnown(s1)) { + if (s0.SameHandle(s1)) { + *out = s0; + return Status::OK(); + } else if (!RankKnown(s1)) { *out = s0; + merged_shapes_.emplace_back(s0, s1); return Status::OK(); } else if (!RankKnown(s0)) { *out = s1; + merged_shapes_.emplace_back(s0, s1); return Status::OK(); } @@ -539,6 +549,9 @@ Status InferenceContext::Merge(ShapeHandle s0, ShapeHandle s1, Value(d0), " and ", Value(d1)); } } + + merged_shapes_.emplace_back(s0, s1); + if (return_s0 || return_s1) { *out = return_s0 ? s0 : s1; return Status::OK(); @@ -550,7 +563,14 @@ Status InferenceContext::Merge(ShapeHandle s0, ShapeHandle s1, // Invariant for merge was checked earlier, so CHECK is ok. TF_CHECK_OK(Merge(Dim(s0, i), Dim(s1, i), &dims[i])); } - return ReturnCreatedShape(dims, out); + + Status s = ReturnCreatedShape(dims, out); + if (s.ok()) { + // Merge the new shape with s0. Since s0 and s1 are merged, this implies + // that s1 and out are also merged. + merged_shapes_.emplace_back(s0, *out); + } + return s; } Status InferenceContext::Subshape(ShapeHandle s, int64 start, |