aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Yuan Yu <yuanbyu@google.com>2016-12-16 14:08:39 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-12-16 14:32:35 -0800
commit1b9116522931eb445c03d4a4653a21441a613f94 (patch)
treec15dbbec2bc4337d8c90aa80516d55d1e01d8358
parentd20ddfbe18b2b1c8f749d8b6ea4d2e60da25b3a3 (diff)
Make switch not fetchable. This is to address a problem introduced in CL/142214059.
Change: 142296727
-rw-r--r--tensorflow/python/kernel_tests/control_flow_ops_py_test.py20
-rw-r--r--tensorflow/python/ops/control_flow_ops.py1
2 files changed, 21 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index fd3dc9f44e..c1f3f34869 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -83,6 +83,16 @@ def check_consumers(graph):
return True
+def all_fetchables():
+ tensor_names = []
+ graph = ops.get_default_graph()
+ for op in graph.get_operations():
+ for t in op.outputs:
+ if graph.is_fetchable(t):
+ tensor_names.append(t.name)
+ return tensor_names
+
+
def opt_cfg():
return config_pb2.ConfigProto(
allow_soft_placement=True,
@@ -301,6 +311,16 @@ class ControlFlowTest(test.TestCase):
with self.assertRaisesRegexp(TypeError, "must not be a Python bool"):
_ = control_flow_ops.cond(False, fn1, fn2)
+ def testFetchables(self):
+ with self.test_session() as sess:
+ x = array_ops.placeholder(dtypes.float32)
+ control_flow_ops.cond(constant_op.constant(True),
+ lambda: x + 2,
+ lambda: x + 0)
+ tensor_names = all_fetchables()
+ for name in tensor_names:
+ sess.run(name, feed_dict={x: 3})
+
def testCondIndexedSlices(self):
with self.test_session():
values = constant_op.constant(10)
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index 4a2c76979a..df125a8cc4 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -1597,6 +1597,7 @@ class CondContext(ControlFlowContext):
self._values.add(result.name)
with ops.control_dependencies(None):
result = _SwitchRefOrTensor(result, self._pred)[self._branch]
+ result.op.graph.prevent_fetching(result.op)
# pylint: disable=protected-access
result.op._set_control_flow_context(self)
# pylint: enable=protected-access