aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/control_flow_ops.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/control_flow_ops.py')
-rw-r--r--tensorflow/python/ops/control_flow_ops.py57
1 files changed, 44 insertions, 13 deletions
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index fc37805c79..aeac61c005 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -1817,15 +1817,34 @@ class CondContext(ControlFlowContext):
def _AddOpInternal(self, op):
"""Add `op` to the current context."""
if not op.inputs:
- # Remove any external control dependency on this op
+ # If we're in a while loop, remove any control inputs from outside the
+ # loop.
self._RemoveExternalControlEdges(op)
- # pylint: disable=protected-access
- op._add_control_input(self._pivot.op)
- # pylint: enable=protected-access
+
+ if not any(util.OpInContext(input_op, self)
+ for input_op in op.control_inputs):
+ # pylint: disable=protected-access
+ op._add_control_input(self._pivot.op)
+ # pylint: enable=protected-access
else:
+ # Make each input to 'op' available in this CondContext. If an input is
+ # already part of this context there's nothing to do, but if it's
+ # external, AddValue() will handle adding the appropriate Switch node and
+ # other bookkeeping.
for index in range(len(op.inputs)):
x = op.inputs[index]
- real_x = self.AddValue(x)
+ if op.type == "Merge" and x.op.type == "NextIteration":
+ # Edge case: if we're importing a while loop inside this CondContext,
+ # AddValue() will not correctly handle the NextIteration inputs to
+ # Merge node. The problem is that the NextIteration should also be
+ # part of this context, but if we're importing it won't have been
+ # processed and added to the context yet, so AddValue() will try to
+ # add a Switch which results in an invalid graph. Instead, we use the
+ # NextIteration input as-is here, and it will eventually be added to
+ # the context via AddOp().
+ real_x = x
+ else:
+ real_x = self.AddValue(x)
if real_x != x:
# pylint: disable=protected-access
op._update_input(index, real_x)
@@ -2932,7 +2951,8 @@ class WhileContext(ControlFlowContext):
return original_body_result, exit_vars
- def BuildLoop(self, pred, body, loop_vars, shape_invariants):
+ def BuildLoop(self, pred, body, loop_vars, shape_invariants,
+ return_same_structure):
"""Add the loop termination condition and body to the graph."""
# Keep original_loop_vars to identify which are TensorArrays
@@ -2960,7 +2980,11 @@ class WhileContext(ControlFlowContext):
packed_exit_vars = nest.pack_sequence_as(
structure=original_body_result,
flat_sequence=exit_vars_with_tensor_arrays)
- return packed_exit_vars[0] if len(exit_vars) == 1 else packed_exit_vars
+
+ if return_same_structure:
+ return packed_exit_vars
+ else:
+ return packed_exit_vars[0] if len(exit_vars) == 1 else packed_exit_vars
def _FixControlInputsAndContext(self, enters):
graph = ops.get_default_graph()
@@ -3000,7 +3024,8 @@ def while_loop(cond,
back_prop=True,
swap_memory=False,
name=None,
- maximum_iterations=None):
+ maximum_iterations=None,
+ return_same_structure=False):
"""Repeat `body` while the condition `cond` is true.
`cond` is a callable returning a boolean scalar tensor. `body` is a callable
@@ -3076,11 +3101,16 @@ def while_loop(cond,
to run. If provided, the `cond` output is AND-ed with an additional
condition ensuring the number of iterations executed is no greater than
`maximum_iterations`.
+ return_same_structure: If True, output has same structure as `loop_vars`. If
+ eager execution is enabled, this is ignored (and always treated as True).
Returns:
- The output tensors for the loop variables after the loop. When the length
- of `loop_vars` is 1 this is a Tensor, TensorArray or IndexedSlice and when
- the length of `loop_vars` is greater than 1 it returns a list.
+ The output tensors for the loop variables after the loop.
+ If `return_same_structure` is True, the return value has the same
+ structure as `loop_vars`.
+ If `return_same_structure` is False, the return value is a Tensor,
+ TensorArray or IndexedSlice if the length of `loop_vars` is 1, or a list
+ otherwise.
Raises:
TypeError: if `cond` or `body` is not callable.
@@ -3135,7 +3165,7 @@ def while_loop(cond,
happen is that the thread updating `x` can never get ahead of the
counter thread because the thread incrementing `x` depends on the value
of the counter.
-
+
```python
import tensorflow as tf
@@ -3217,7 +3247,8 @@ def while_loop(cond,
# be encapsulated in the root context.
if loop_context.outer_context is None:
ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, loop_context)
- result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants)
+ result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants,
+ return_same_structure)
if maximum_iterations is not None:
return result[1]
else: