diff options
author | 2016-06-01 15:04:34 -0800 | |
---|---|---|
committer | 2016-06-01 16:17:25 -0700 | |
commit | f134f05ff95bbcb6ab205b67ceff1743126d9465 (patch) | |
tree | 667fb90a8b28e5e846894f6ae64da5c90cd84d0f /tensorflow/python | |
parent | 3f07cb2b18b7ce45773d467a52336c35ff5b91dd (diff) |
Prevents the fetching of internal nodes defined in a while loop or cond branch.
Fixes #2575.
Change: 123805910
Diffstat (limited to 'tensorflow/python')
-rw-r--r-- | tensorflow/python/client/session.py | 7 | ||||
-rw-r--r-- | tensorflow/python/framework/ops.py | 13 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/control_flow_ops_py_test.py | 16 | ||||
-rw-r--r-- | tensorflow/python/ops/control_flow_ops.py | 4 |
4 files changed, 40 insertions, 0 deletions
diff --git a/tensorflow/python/client/session.py b/tensorflow/python/client/session.py index 943fb70b13..5e07672b2e 100644 --- a/tensorflow/python/client/session.py +++ b/tensorflow/python/client/session.py @@ -492,6 +492,11 @@ class BaseSession(SessionInterface): return self._do_call(_setup_fn, self._session, feed_list, unique_fetches, target_list) + def _assert_fetchable(self, op): + if not self.graph.is_fetchable(op): + raise ValueError( + 'Operation %r has been marked as not fetchable.' % op.name) + def _process_fetches(self, fetches): """Validate and process fetches.""" def _fetch_fn(fetch): @@ -520,8 +525,10 @@ class BaseSession(SessionInterface): allow_operation=True) fetch_name = compat.as_bytes(fetch_t.name) if isinstance(fetch_t, ops.Operation): + self._assert_fetchable(fetch_t) target_list.append(fetch_name) else: + self._assert_fetchable(fetch_t.op) subfetch_names.append(fetch_name) # Remember the fetch if it is for a tensor handle. if (isinstance(fetch_t, ops.Tensor) and diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 257fe8cf97..95d9507aa2 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -1970,6 +1970,8 @@ class Graph(object): self._colocation_stack = [] # Set of tensors that are dangerous to feed! self._unfeedable_tensors = set() + # Set of operations that are dangerous to fetch! + self._unfetchable_ops = set() # A map of tensor handle placeholder to tensor dtype. self._handle_feeders = {} # A map from tensor handle to its read op. @@ -3267,6 +3269,17 @@ class Graph(object): """Returns `True` if and only if `tensor` is feedable.""" return tensor not in self._unfeedable_tensors + def prevent_fetching(self, op): + """Marks the given `op` as unfetchable in this graph.""" + self._unfetchable_ops.add(op) + + def is_fetchable(self, tensor_or_op): + """Returns `True` if and only if `tensor_or_op` is fetchable.""" + if isinstance(tensor_or_op, Tensor): + return tensor_or_op.op not in self._unfetchable_ops + else: + return tensor_or_op not in self._unfetchable_ops + def device(device_name_or_function): """Wrapper for `Graph.device()` using the default graph. diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index 6921ab2aa6..93ffc60394 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -1606,6 +1606,22 @@ class ControlFlowTest(tf.test.TestCase): s = control_flow_ops.ref_select(index, [p1, p2]) self.assertEqual(None, s.get_shape()) + def testRunLoopTensor(self): + with self.test_session() as sess: + tensor_list = [] + def condition(t): + return t < tf.constant(5) + def body(_): + tensor_list.append(tf.constant(5)) + return tf.constant(10) + result = tf.while_loop(condition, body, [tf.constant(4)]) + self.assertEqual(10, sess.run(result)) + + # Ensure that we cannot run a tensor that escapes the loop body + # accidentally. + with self.assertRaises(ValueError): + sess.run(tensor_list[0]) + class TupleTest(tf.test.TestCase): diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index e00fec07c1..2bb89a89cf 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -1222,6 +1222,8 @@ class CondContext(ControlFlowContext): op._update_input(index, x) for x in op.outputs: self._values.add(x.name) + if self._outer_context or op.type not in {"Exit", "RefExit"}: + op.graph.prevent_fetching(op) def BuildCondBranch(self, fn): """Add the subgraph defined by fn() to the graph.""" @@ -1503,6 +1505,8 @@ class WhileContext(ControlFlowContext): self._MaybeAddControlDependency(op) for x in op.outputs: self._values.add(x.name) + if self._outer_context or op.type not in {"Exit", "RefExit"}: + op.graph.prevent_fetching(op) def _MaybeAddControlDependency(self, op): """Add a control input to the op if it only depends on loop invariants.""" |