aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <nobody@tensorflow.org>2015-12-15 17:34:31 -0800
committerGravatar Vijay Vasudevan <vrv@google.com>2015-12-15 17:34:31 -0800
commitb8a654f3ccf232375b8b27408d71c5ecfa70ebe8 (patch)
tree243acb675cf8af9f991c8e635594674a1602283f
parent881dc225ecb32064681c7bf2229d796565ad7956 (diff)
Fixed a bug in cond when a branch passes a value through mulitple times, i.e., does something like "lambda: [x, x]".
Change: 110306609
-rw-r--r--tensorflow/python/kernel_tests/control_flow_ops_py_test.py12
-rw-r--r--tensorflow/python/ops/control_flow_ops.py12
2 files changed, 19 insertions, 5 deletions
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index a51a6f9137..3e4874b5cc 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -445,6 +445,17 @@ class ControlFlowTest(tf.test.TestCase):
result = r.eval()
self.assertAllEqual(np.array([7]), result)
+ def testCond_7(self):
+ with self.test_session() as sess:
+ x = tf.constant(10)
+ y = tf.constant(200)
+ pred = tf.less(1, 2)
+ fn1 = lambda: [tf.add(x, 1), tf.add(x, 2)]
+ fn2 = lambda: [y, y]
+ r = control_flow_ops.cond(pred, fn1, fn2)
+
+ self.assertAllEqual([11, 12], sess.run(r))
+
def testCondGrad_1(self):
with self.test_session():
x = tf.constant(10.0, name="x")
@@ -1365,5 +1376,6 @@ class TupleTest(tf.test.TestCase):
self.assertEquals(1, var.eval())
+
if __name__ == "__main__":
tf.test.main()
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index 7bcbf1c0cf..85c681b7b8 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -571,18 +571,20 @@ class CondContext(ControlFlowContext):
if not isinstance(r, list) and not isinstance(r, _basetuple):
r = [r]
for v in r:
+ real_v = v
if isinstance(v, ops.Operation):
- v = with_dependencies([v], self._pivot)
+ real_v = with_dependencies([v], self._pivot)
elif v.name not in self._values:
self._values.add(v.name)
if self._outer_context is not None:
- v = self._outer_context.AddValue(v)
- v = _SwitchRefOrTensor(v, self._pred)[self._branch]
+ real_v = self._outer_context.AddValue(v)
+ real_v = _SwitchRefOrTensor(real_v, self._pred)[self._branch]
+ self._external_values[v.name] = real_v
else:
external_v = self._external_values.get(v.name)
if external_v is not None:
- v = external_v
- result.append(v)
+ real_v = external_v
+ result.append(real_v)
return result