aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Yuan Yu <yuanbyu@google.com>2016-05-11 18:59:21 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-05-11 20:02:07 -0700
commitd03631a27a0a3bc3eb23faa630b60c8e9826e1be (patch)
treea4f0680910dea774c392e57e51b6c5ecf9403fc0
parentd388992ccfd6c8d750f0c7d67189bd4f2b6490de (diff)
Fix a bug in shape inference of while loop.
Change: 122120526
-rw-r--r--tensorflow/python/kernel_tests/control_flow_ops_py_test.py13
-rw-r--r--tensorflow/python/ops/control_flow_ops.py4
2 files changed, 15 insertions, 2 deletions
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index e587ccb55f..f9dd82b331 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -590,6 +590,19 @@ class ControlFlowTest(tf.test.TestCase):
self._testWhile_Gpu_1(use_gpu=False)
self._testWhile_Gpu_1(use_gpu=True)
+ def testWhileShape(self):
+ with self.test_session():
+ i = tf.constant(0)
+ m = tf.ones([2, 2])
+ c = lambda i, j: tf.less(i, 2)
+ def _b(i, j):
+ new_i = tf.add(i, 1)
+ new_j = tf.tile(j, [2, 2])
+ return [new_i, new_j]
+ r = tf.while_loop(c, _b, [i, m])
+ r = r[1] * tf.ones([8, 8])
+ self.assertAllEqual(np.ones((8, 8)), r.eval())
+
def _testNestedWhile_1(self, use_gpu):
with self.test_session(use_gpu=use_gpu):
n = tf.constant(0)
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index 2b7fee4efb..9d55acd1fc 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -1588,8 +1588,8 @@ class WhileContext(ControlFlowContext):
self._loop_exits = exit_vars
for m_var, n_var, e_var in zip(merge_vars, next_vars, exit_vars):
- if m_var.get_shape().is_compatible_with(n_var.get_shape()):
- e_var.set_shape(m_var.get_shape().merge_with(n_var.get_shape()))
+ if not m_var.get_shape() == n_var.get_shape():
+ e_var._shape = tensor_shape.unknown_shape()
# Exit the loop.
self.ExitResult(exit_vars)