From 3374d3a6e5b9f77fa4229c41b233f1c0a229216f Mon Sep 17 00:00:00 2001 From: Alexandre Passos Date: Mon, 5 Feb 2018 14:47:23 -0800 Subject: Automated g4 rollback of changelist 184573795 PiperOrigin-RevId: 184590080 --- .../kernel_tests/control_flow_ops_py_test.py | 30 ---------------------- tensorflow/python/ops/control_flow_ops.py | 25 +++++++----------- 2 files changed, 9 insertions(+), 46 deletions(-) diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py index 15ff0ec09b..4fafc36014 100644 --- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py +++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py @@ -704,36 +704,6 @@ class ControlFlowTest(test.TestCase): r = control_flow_ops.while_loop(c, b, [n], parallel_iterations=20) self.assertEqual(10000, r.eval()) - def testWhileExternalControlDependencies(self): - with self.test_session(): - v = variables.Variable(0.0) - v.initializer.run() - increment = v.assign_add(1.0) - - def body_fn(i): - with ops.control_dependencies([increment]): - return i + i - - result = control_flow_ops.while_loop(cond=lambda i: i < 1, - body=body_fn, loop_vars=[1]) - result.eval() - self.assertAllEqual(v.eval(), 1.0) - - def testWhileExternalControlDependenciesNoInput(self): - with self.test_session(): - v = variables.Variable(0.0) - v.initializer.run() - increment = v.assign_add(1.0) - - def body_fn(unused_i): - with ops.control_dependencies([increment]): - return constant_op.constant(5, name="five") - - result = control_flow_ops.while_loop(cond=lambda i: i < 5, - body=body_fn, loop_vars=[0]) - result.eval() - self.assertAllEqual(v.eval(), 1.0) - def testWhileWithRefs_1(self): with self.test_session() as sess: x = variables.Variable(0)._ref() # pylint: disable=protected-access diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py index 87ff0aba0a..bcd187d821 100644 --- a/tensorflow/python/ops/control_flow_ops.py +++ b/tensorflow/python/ops/control_flow_ops.py @@ -1631,13 +1631,10 @@ class ControlFlowContext(object): ctxt = util.GetOutputContext(x) if ctxt is not None and ctxt.GetWhileContext() == while_ctxt: internal_control_inputs.append(x) - external_control_inputs = [] if len(internal_control_inputs) != len(op.control_inputs): - external_control_inputs = list(set(op.control_inputs) - - set(internal_control_inputs)) op._remove_all_control_inputs() op._add_control_inputs(internal_control_inputs) - return internal_control_inputs, external_control_inputs + return internal_control_inputs # pylint: enable=protected-access @@ -2435,12 +2432,14 @@ class WhileContext(ControlFlowContext): def _AddOpInternal(self, op): """Add `op` to the current context. - We move any external control dependencies of the op to the loop pivot, to - ensure they get executed. + In the case that op has only external data inputs, we remove all of its + external control inputs so all its inputs are in the same while loop + context. This is valid because op now has an Enter input that has all + the right control dependency. """ if not op.inputs: # Remove any external control dependency on this op - control_inputs, external_inputs = self._RemoveExternalControlEdges(op) + control_inputs = self._RemoveExternalControlEdges(op) # Add a control edge from the control pivot to this op. if not control_inputs: # pylint: disable=protected-access @@ -2453,20 +2452,14 @@ class WhileContext(ControlFlowContext): x = op.inputs[index] real_x = self.AddValue(x) if real_x != x: - op._update_input(index, real_x) # pylint: disable=protected-access - # Remove any external control dependency on this op and move then to an - # Enter node. - _, external_inputs = self._RemoveExternalControlEdges(op) + op._update_input(index, real_x) + # Remove any external control dependency on this op. + self._RemoveExternalControlEdges(op) # Add a control dependency to prevent loop invariants from # enabling ops that should not be executed. self._MaybeAddControlDependency(op) for x in op.outputs: self._values.add(x.name) - if external_inputs: - # Make the pivot depend on external control inputs - pred = self._pivot_for_pred.op.inputs[0] - assert util.IsLoopEnter(pred.op) - pred.op._add_control_inputs(external_inputs) # pylint: disable=protected-access if self._outer_context or not util.IsLoopExit(op): op.graph.prevent_fetching(op) for x in op.outputs: -- cgit v1.2.3