diff options
author | 2015-12-15 17:34:31 -0800 | |
---|---|---|
committer | 2015-12-15 17:34:31 -0800 | |
commit | b8a654f3ccf232375b8b27408d71c5ecfa70ebe8 (patch) | |
tree | 243acb675cf8af9f991c8e635594674a1602283f | |
parent | 881dc225ecb32064681c7bf2229d796565ad7956 (diff) |
Fixed a bug in cond when a branch passes a value through mulitple times, i.e., does something like "lambda: [x, x]".
Change: 110306609
-rw-r--r-- | tensorflow/python/kernel_tests/control_flow_ops_py_test.py | 12 | ||||
-rw-r--r-- | tensorflow/python/ops/control_flow_ops.py | 12 |
2 files changed, 19 insertions, 5 deletions
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index a51a6f9137..3e4874b5cc 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -445,6 +445,17 @@ class ControlFlowTest(tf.test.TestCase): result = r.eval() self.assertAllEqual(np.array([7]), result) + def testCond_7(self): + with self.test_session() as sess: + x = tf.constant(10) + y = tf.constant(200) + pred = tf.less(1, 2) + fn1 = lambda: [tf.add(x, 1), tf.add(x, 2)] + fn2 = lambda: [y, y] + r = control_flow_ops.cond(pred, fn1, fn2) + + self.assertAllEqual([11, 12], sess.run(r)) + def testCondGrad_1(self): with self.test_session(): x = tf.constant(10.0, name="x") @@ -1365,5 +1376,6 @@ class TupleTest(tf.test.TestCase): self.assertEquals(1, var.eval()) + if __name__ == "__main__": tf.test.main() diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index 7bcbf1c0cf..85c681b7b8 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -571,18 +571,20 @@ class CondContext(ControlFlowContext): if not isinstance(r, list) and not isinstance(r, _basetuple): r = [r] for v in r: + real_v = v if isinstance(v, ops.Operation): - v = with_dependencies([v], self._pivot) + real_v = with_dependencies([v], self._pivot) elif v.name not in self._values: self._values.add(v.name) if self._outer_context is not None: - v = self._outer_context.AddValue(v) - v = _SwitchRefOrTensor(v, self._pred)[self._branch] + real_v = self._outer_context.AddValue(v) + real_v = _SwitchRefOrTensor(real_v, self._pred)[self._branch] + self._external_values[v.name] = real_v else: external_v = self._external_values.get(v.name) if external_v is not None: - v = external_v - result.append(v) + real_v = external_v + result.append(real_v) return result |