aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2016-06-01 15:04:34 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-06-01 16:17:25 -0700
commitf134f05ff95bbcb6ab205b67ceff1743126d9465 (patch)
tree667fb90a8b28e5e846894f6ae64da5c90cd84d0f /tensorflow/python
parent3f07cb2b18b7ce45773d467a52336c35ff5b91dd (diff)
Prevents the fetching of internal nodes defined in a while loop or cond branch.
Fixes #2575. Change: 123805910
Diffstat (limited to 'tensorflow/python')
-rw-r--r--tensorflow/python/client/session.py7
-rw-r--r--tensorflow/python/framework/ops.py13
-rw-r--r--tensorflow/python/kernel_tests/control_flow_ops_py_test.py16
-rw-r--r--tensorflow/python/ops/control_flow_ops.py4
4 files changed, 40 insertions, 0 deletions
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py
index 943fb70b13..5e07672b2e 100644
--- a/tensorflow/python/client/session.py
+++ b/tensorflow/python/client/session.py
@@ -492,6 +492,11 @@ class BaseSession(SessionInterface):
return self._do_call(_setup_fn, self._session, feed_list, unique_fetches,
target_list)
+ def _assert_fetchable(self, op):
+ if not self.graph.is_fetchable(op):
+ raise ValueError(
+ 'Operation %r has been marked as not fetchable.' % op.name)
+
def _process_fetches(self, fetches):
"""Validate and process fetches."""
def _fetch_fn(fetch):
@@ -520,8 +525,10 @@ class BaseSession(SessionInterface):
allow_operation=True)
fetch_name = compat.as_bytes(fetch_t.name)
if isinstance(fetch_t, ops.Operation):
+ self._assert_fetchable(fetch_t)
target_list.append(fetch_name)
else:
+ self._assert_fetchable(fetch_t.op)
subfetch_names.append(fetch_name)
# Remember the fetch if it is for a tensor handle.
if (isinstance(fetch_t, ops.Tensor) and
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 257fe8cf97..95d9507aa2 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -1970,6 +1970,8 @@ class Graph(object):
self._colocation_stack = []
# Set of tensors that are dangerous to feed!
self._unfeedable_tensors = set()
+ # Set of operations that are dangerous to fetch!
+ self._unfetchable_ops = set()
# A map of tensor handle placeholder to tensor dtype.
self._handle_feeders = {}
# A map from tensor handle to its read op.
@@ -3267,6 +3269,17 @@ class Graph(object):
"""Returns `True` if and only if `tensor` is feedable."""
return tensor not in self._unfeedable_tensors
+ def prevent_fetching(self, op):
+ """Marks the given `op` as unfetchable in this graph."""
+ self._unfetchable_ops.add(op)
+
+ def is_fetchable(self, tensor_or_op):
+ """Returns `True` if and only if `tensor_or_op` is fetchable."""
+ if isinstance(tensor_or_op, Tensor):
+ return tensor_or_op.op not in self._unfetchable_ops
+ else:
+ return tensor_or_op not in self._unfetchable_ops
+
def device(device_name_or_function):
"""Wrapper for `Graph.device()` using the default graph.
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 6921ab2aa6..93ffc60394 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -1606,6 +1606,22 @@ class ControlFlowTest(tf.test.TestCase):
s = control_flow_ops.ref_select(index, [p1, p2])
self.assertEqual(None, s.get_shape())
+ def testRunLoopTensor(self):
+ with self.test_session() as sess:
+ tensor_list = []
+ def condition(t):
+ return t < tf.constant(5)
+ def body(_):
+ tensor_list.append(tf.constant(5))
+ return tf.constant(10)
+ result = tf.while_loop(condition, body, [tf.constant(4)])
+ self.assertEqual(10, sess.run(result))
+
+ # Ensure that we cannot run a tensor that escapes the loop body
+ # accidentally.
+ with self.assertRaises(ValueError):
+ sess.run(tensor_list[0])
+
class TupleTest(tf.test.TestCase):
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index e00fec07c1..2bb89a89cf 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -1222,6 +1222,8 @@ class CondContext(ControlFlowContext):
op._update_input(index, x)
for x in op.outputs:
self._values.add(x.name)
+ if self._outer_context or op.type not in {"Exit", "RefExit"}:
+ op.graph.prevent_fetching(op)
def BuildCondBranch(self, fn):
"""Add the subgraph defined by fn() to the graph."""
@@ -1503,6 +1505,8 @@ class WhileContext(ControlFlowContext):
self._MaybeAddControlDependency(op)
for x in op.outputs:
self._values.add(x.name)
+ if self._outer_context or op.type not in {"Exit", "RefExit"}:
+ op.graph.prevent_fetching(op)
def _MaybeAddControlDependency(self, op):
"""Add a control input to the op if it only depends on loop invariants."""