diff options
author | 2016-05-11 18:59:21 -0800 | |
---|---|---|
committer | 2016-05-11 20:02:07 -0700 | |
commit | d03631a27a0a3bc3eb23faa630b60c8e9826e1be (patch) | |
tree | a4f0680910dea774c392e57e51b6c5ecf9403fc0 | |
parent | d388992ccfd6c8d750f0c7d67189bd4f2b6490de (diff) |
Fix a bug in shape inference of while loop.
Change: 122120526
-rw-r--r-- | tensorflow/python/kernel_tests/control_flow_ops_py_test.py | 13 | ||||
-rw-r--r-- | tensorflow/python/ops/control_flow_ops.py | 4 |
2 files changed, 15 insertions, 2 deletions
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index e587ccb55f..f9dd82b331 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -590,6 +590,19 @@ class ControlFlowTest(tf.test.TestCase): self._testWhile_Gpu_1(use_gpu=False) self._testWhile_Gpu_1(use_gpu=True) + def testWhileShape(self): + with self.test_session(): + i = tf.constant(0) + m = tf.ones([2, 2]) + c = lambda i, j: tf.less(i, 2) + def _b(i, j): + new_i = tf.add(i, 1) + new_j = tf.tile(j, [2, 2]) + return [new_i, new_j] + r = tf.while_loop(c, _b, [i, m]) + r = r[1] * tf.ones([8, 8]) + self.assertAllEqual(np.ones((8, 8)), r.eval()) + def _testNestedWhile_1(self, use_gpu): with self.test_session(use_gpu=use_gpu): n = tf.constant(0) diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index 2b7fee4efb..9d55acd1fc 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -1588,8 +1588,8 @@ class WhileContext(ControlFlowContext): self._loop_exits = exit_vars for m_var, n_var, e_var in zip(merge_vars, next_vars, exit_vars): - if m_var.get_shape().is_compatible_with(n_var.get_shape()): - e_var.set_shape(m_var.get_shape().merge_with(n_var.get_shape())) + if not m_var.get_shape() == n_var.get_shape(): + e_var._shape = tensor_shape.unknown_shape() # Exit the loop. self.ExitResult(exit_vars) |