aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/shape_inference.cc
diff options
context:
space:
mode:
authorGravatar Benoit Steiner <bsteiner@google.com>2017-11-07 15:15:17 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-07 15:19:00 -0800
commit8c88be0da425ad686f76e03f21d2947dedea5123 (patch)
tree0d24133c1a1f99207e9d4c8fcae40bf531fcc3ab /tensorflow/core/framework/shape_inference.cc
parent1d60b4a67cfe2838565135312bf6d5e9b47d60a3 (diff)
Track merged shapes and dimensions more accurately.
PiperOrigin-RevId: 174920827
Diffstat (limited to 'tensorflow/core/framework/shape_inference.cc')
-rw-r--r--tensorflow/core/framework/shape_inference.cc26
1 files changed, 23 insertions, 3 deletions
diff --git a/tensorflow/core/framework/shape_inference.cc b/tensorflow/core/framework/shape_inference.cc
index ffa235d15c..5d6bf559bb 100644
--- a/tensorflow/core/framework/shape_inference.cc
+++ b/tensorflow/core/framework/shape_inference.cc
@@ -418,11 +418,16 @@ void InferenceContext::Relax(DimensionHandle d0, DimensionHandle d1,
Status InferenceContext::Merge(DimensionHandle d0, DimensionHandle d1,
DimensionHandle* out) {
- if (d0.SameHandle(d1) || !ValueKnown(d1)) {
+ if (d0.SameHandle(d1)) {
+ *out = d0;
+ return Status::OK();
+ } else if (!ValueKnown(d1)) {
*out = d0;
+ merged_dims_.emplace_back(d0, d1);
return Status::OK();
} else if (!ValueKnown(d0)) {
*out = d1;
+ merged_dims_.emplace_back(d0, d1);
return Status::OK();
} else if (Value(d0) == Value(d1)) {
*out = d0;
@@ -502,11 +507,16 @@ void InferenceContext::Relax(ShapeHandle s0, ShapeHandle s1, ShapeHandle* out) {
Status InferenceContext::Merge(ShapeHandle s0, ShapeHandle s1,
ShapeHandle* out) {
- if (s0.SameHandle(s1) || !RankKnown(s1)) {
+ if (s0.SameHandle(s1)) {
+ *out = s0;
+ return Status::OK();
+ } else if (!RankKnown(s1)) {
*out = s0;
+ merged_shapes_.emplace_back(s0, s1);
return Status::OK();
} else if (!RankKnown(s0)) {
*out = s1;
+ merged_shapes_.emplace_back(s0, s1);
return Status::OK();
}
@@ -539,6 +549,9 @@ Status InferenceContext::Merge(ShapeHandle s0, ShapeHandle s1,
Value(d0), " and ", Value(d1));
}
}
+
+ merged_shapes_.emplace_back(s0, s1);
+
if (return_s0 || return_s1) {
*out = return_s0 ? s0 : s1;
return Status::OK();
@@ -550,7 +563,14 @@ Status InferenceContext::Merge(ShapeHandle s0, ShapeHandle s1,
// Invariant for merge was checked earlier, so CHECK is ok.
TF_CHECK_OK(Merge(Dim(s0, i), Dim(s1, i), &dims[i]));
}
- return ReturnCreatedShape(dims, out);
+
+ Status s = ReturnCreatedShape(dims, out);
+ if (s.ok()) {
+ // Merge the new shape with s0. Since s0 and s1 are merged, this implies
+ // that s1 and out are also merged.
+ merged_shapes_.emplace_back(s0, *out);
+ }
+ return s;
}
Status InferenceContext::Subshape(ShapeHandle s, int64 start,