diff options
-rw-r--r-- | tensorflow/python/kernel_tests/control_flow_ops_py_test.py | 9 | ||||
-rw-r--r-- | tensorflow/python/ops/control_flow_ops.py | 16 |
2 files changed, 13 insertions, 12 deletions
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index e27eb00818..209411cf51 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -1135,11 +1135,10 @@ class ControlFlowTest(test.TestCase): with self.assertRaisesRegexp( ValueError, - r"The shape for while_1/Merge_1:0 is not an invariant for the loop. " - r"It enters the loop with shape \(2, 2\), but has shape \(4, 2\) " - r"after one iteration. Provide shape invariants using either the " - r"`shape_invariants` argument of tf.while_loop or set_shape\(\) on " - r"the loop variables."): + r"Input tensor 'ones:0' enters the loop with shape \(2, 2\), but has " + r"shape \(4, 2\) after one iteration. To allow the shape to vary " + r"across iterations, use the `shape_invariants` argument of " + r"tf.while_loop to specify a less-specific shape."): r = control_flow_ops.while_loop(c, b, [i, m]) def testWhileShapeInferenceSparseTensor(self): diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index a1bfe450c8..f1e068d514 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -622,14 +622,16 @@ def _EnforceShapeInvariant(merge_var, next_var): m_shape = merge_var.get_shape() n_shape = next_var.get_shape() if not _ShapeLessThanOrEqual(n_shape, m_shape): - # TODO(skyewm): get original loop input that caused the shape error and - # report its name instead of the merge node's. + enter = merge_var.op.inputs[0].op + assert util.IsLoopEnter(enter) + input_t = enter.inputs[0] + assert input_t.shape == m_shape raise ValueError( - "The shape for %s is not an invariant for the loop. It enters " - "the loop with shape %s, but has shape %s after one iteration. " - "Provide shape invariants using either the `shape_invariants` " - "argument of tf.while_loop or set_shape() on the loop variables." % - (merge_var.name, m_shape, n_shape)) + "Input tensor '%s' enters the loop with shape %s, but has shape %s " + "after one iteration. To allow the shape to vary across iterations, " + "use the `shape_invariants` argument of tf.while_loop to specify a " + "less-specific shape." % + (input_t.name, input_t.shape, n_shape)) else: if not isinstance(var, (ops.IndexedSlices, sparse_tensor.SparseTensor)): raise TypeError("Type %s not supported" % type(var)) |