diff options
Diffstat (limited to 'tensorflow/python/ops/control_flow_ops.py')
-rw-r--r-- | tensorflow/python/ops/control_flow_ops.py | 57 |
1 files changed, 44 insertions, 13 deletions
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index fc37805c79..aeac61c005 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -1817,15 +1817,34 @@ class CondContext(ControlFlowContext): def _AddOpInternal(self, op): """Add `op` to the current context.""" if not op.inputs: - # Remove any external control dependency on this op + # If we're in a while loop, remove any control inputs from outside the + # loop. self._RemoveExternalControlEdges(op) - # pylint: disable=protected-access - op._add_control_input(self._pivot.op) - # pylint: enable=protected-access + + if not any(util.OpInContext(input_op, self) + for input_op in op.control_inputs): + # pylint: disable=protected-access + op._add_control_input(self._pivot.op) + # pylint: enable=protected-access else: + # Make each input to 'op' available in this CondContext. If an input is + # already part of this context there's nothing to do, but if it's + # external, AddValue() will handle adding the appropriate Switch node and + # other bookkeeping. for index in range(len(op.inputs)): x = op.inputs[index] - real_x = self.AddValue(x) + if op.type == "Merge" and x.op.type == "NextIteration": + # Edge case: if we're importing a while loop inside this CondContext, + # AddValue() will not correctly handle the NextIteration inputs to + # Merge node. The problem is that the NextIteration should also be + # part of this context, but if we're importing it won't have been + # processed and added to the context yet, so AddValue() will try to + # add a Switch which results in an invalid graph. Instead, we use the + # NextIteration input as-is here, and it will eventually be added to + # the context via AddOp(). + real_x = x + else: + real_x = self.AddValue(x) if real_x != x: # pylint: disable=protected-access op._update_input(index, real_x) @@ -2932,7 +2951,8 @@ class WhileContext(ControlFlowContext): return original_body_result, exit_vars - def BuildLoop(self, pred, body, loop_vars, shape_invariants): + def BuildLoop(self, pred, body, loop_vars, shape_invariants, + return_same_structure): """Add the loop termination condition and body to the graph.""" # Keep original_loop_vars to identify which are TensorArrays @@ -2960,7 +2980,11 @@ class WhileContext(ControlFlowContext): packed_exit_vars = nest.pack_sequence_as( structure=original_body_result, flat_sequence=exit_vars_with_tensor_arrays) - return packed_exit_vars[0] if len(exit_vars) == 1 else packed_exit_vars + + if return_same_structure: + return packed_exit_vars + else: + return packed_exit_vars[0] if len(exit_vars) == 1 else packed_exit_vars def _FixControlInputsAndContext(self, enters): graph = ops.get_default_graph() @@ -3000,7 +3024,8 @@ def while_loop(cond, back_prop=True, swap_memory=False, name=None, - maximum_iterations=None): + maximum_iterations=None, + return_same_structure=False): """Repeat `body` while the condition `cond` is true. `cond` is a callable returning a boolean scalar tensor. `body` is a callable @@ -3076,11 +3101,16 @@ def while_loop(cond, to run. If provided, the `cond` output is AND-ed with an additional condition ensuring the number of iterations executed is no greater than `maximum_iterations`. + return_same_structure: If True, output has same structure as `loop_vars`. If + eager execution is enabled, this is ignored (and always treated as True). Returns: - The output tensors for the loop variables after the loop. When the length - of `loop_vars` is 1 this is a Tensor, TensorArray or IndexedSlice and when - the length of `loop_vars` is greater than 1 it returns a list. + The output tensors for the loop variables after the loop. + If `return_same_structure` is True, the return value has the same + structure as `loop_vars`. + If `return_same_structure` is False, the return value is a Tensor, + TensorArray or IndexedSlice if the length of `loop_vars` is 1, or a list + otherwise. Raises: TypeError: if `cond` or `body` is not callable. @@ -3135,7 +3165,7 @@ def while_loop(cond, happen is that the thread updating `x` can never get ahead of the counter thread because the thread incrementing `x` depends on the value of the counter. - + ```python import tensorflow as tf @@ -3217,7 +3247,8 @@ def while_loop(cond, # be encapsulated in the root context. if loop_context.outer_context is None: ops.add_to_collection(ops.GraphKeys.WHILE_CONTEXT, loop_context) - result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants) + result = loop_context.BuildLoop(cond, body, loop_vars, shape_invariants, + return_same_structure) if maximum_iterations is not None: return result[1] else: |