aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/python/kernel_tests/control_flow_ops_py_test.py9
-rw-r--r--tensorflow/python/ops/control_flow_ops.py16
2 files changed, 13 insertions, 12 deletions
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index e27eb00818..209411cf51 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -1135,11 +1135,10 @@ class ControlFlowTest(test.TestCase):
with self.assertRaisesRegexp(
ValueError,
- r"The shape for while_1/Merge_1:0 is not an invariant for the loop. "
- r"It enters the loop with shape \(2, 2\), but has shape \(4, 2\) "
- r"after one iteration. Provide shape invariants using either the "
- r"`shape_invariants` argument of tf.while_loop or set_shape\(\) on "
- r"the loop variables."):
+ r"Input tensor 'ones:0' enters the loop with shape \(2, 2\), but has "
+ r"shape \(4, 2\) after one iteration. To allow the shape to vary "
+ r"across iterations, use the `shape_invariants` argument of "
+ r"tf.while_loop to specify a less-specific shape."):
r = control_flow_ops.while_loop(c, b, [i, m])
def testWhileShapeInferenceSparseTensor(self):
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index a1bfe450c8..f1e068d514 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -622,14 +622,16 @@ def _EnforceShapeInvariant(merge_var, next_var):
m_shape = merge_var.get_shape()
n_shape = next_var.get_shape()
if not _ShapeLessThanOrEqual(n_shape, m_shape):
- # TODO(skyewm): get original loop input that caused the shape error and
- # report its name instead of the merge node's.
+ enter = merge_var.op.inputs[0].op
+ assert util.IsLoopEnter(enter)
+ input_t = enter.inputs[0]
+ assert input_t.shape == m_shape
raise ValueError(
- "The shape for %s is not an invariant for the loop. It enters "
- "the loop with shape %s, but has shape %s after one iteration. "
- "Provide shape invariants using either the `shape_invariants` "
- "argument of tf.while_loop or set_shape() on the loop variables." %
- (merge_var.name, m_shape, n_shape))
+ "Input tensor '%s' enters the loop with shape %s, but has shape %s "
+ "after one iteration. To allow the shape to vary across iterations, "
+ "use the `shape_invariants` argument of tf.while_loop to specify a "
+ "less-specific shape." %
+ (input_t.name, input_t.shape, n_shape))
else:
if not isinstance(var, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
raise TypeError("Type %s not supported" % type(var))