diff options
author | 2016-12-16 14:08:39 -0800 | |
---|---|---|
committer | 2016-12-16 14:32:35 -0800 | |
commit | 1b9116522931eb445c03d4a4653a21441a613f94 (patch) | |
tree | c15dbbec2bc4337d8c90aa80516d55d1e01d8358 | |
parent | d20ddfbe18b2b1c8f749d8b6ea4d2e60da25b3a3 (diff) |
Make switch not fetchable. This is to address a problem introduced in CL/142214059.
Change: 142296727
-rw-r--r-- | tensorflow/python/kernel_tests/control_flow_ops_py_test.py | 20 | ||||
-rw-r--r-- | tensorflow/python/ops/control_flow_ops.py | 1 |
2 files changed, 21 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index fd3dc9f44e..c1f3f34869 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -83,6 +83,16 @@ def check_consumers(graph): return True +def all_fetchables(): + tensor_names = [] + graph = ops.get_default_graph() + for op in graph.get_operations(): + for t in op.outputs: + if graph.is_fetchable(t): + tensor_names.append(t.name) + return tensor_names + + def opt_cfg(): return config_pb2.ConfigProto( allow_soft_placement=True, @@ -301,6 +311,16 @@ class ControlFlowTest(test.TestCase): with self.assertRaisesRegexp(TypeError, "must not be a Python bool"): _ = control_flow_ops.cond(False, fn1, fn2) + def testFetchables(self): + with self.test_session() as sess: + x = array_ops.placeholder(dtypes.float32) + control_flow_ops.cond(constant_op.constant(True), + lambda: x + 2, + lambda: x + 0) + tensor_names = all_fetchables() + for name in tensor_names: + sess.run(name, feed_dict={x: 3}) + def testCondIndexedSlices(self): with self.test_session(): values = constant_op.constant(10) diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index 4a2c76979a..df125a8cc4 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -1597,6 +1597,7 @@ class CondContext(ControlFlowContext): self._values.add(result.name) with ops.control_dependencies(None): result = _SwitchRefOrTensor(result, self._pred)[self._branch] + result.op.graph.prevent_fetching(result.op) # pylint: disable=protected-access result.op._set_control_flow_context(self) # pylint: enable=protected-access |