aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2018-02-05 14:47:23 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-05 14:51:48 -0800
commit3374d3a6e5b9f77fa4229c41b233f1c0a229216f (patch)
tree35449a6ad0b3d69810ee40aff32af6bc1be97668
parent1c762f70caf7004470cdfa599b4eb7a76e5bcc78 (diff)
Automated g4 rollback of changelist 184573795
PiperOrigin-RevId: 184590080
-rw-r--r--tensorflow/python/kernel_tests/control_flow_ops_py_test.py30
-rw-r--r--tensorflow/python/ops/control_flow_ops.py25
2 files changed, 9 insertions, 46 deletions
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 15ff0ec09b..4fafc36014 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -704,36 +704,6 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.while_loop(c, b, [n], parallel_iterations=20)
self.assertEqual(10000, r.eval())
- def testWhileExternalControlDependencies(self):
- with self.test_session():
- v = variables.Variable(0.0)
- v.initializer.run()
- increment = v.assign_add(1.0)
-
- def body_fn(i):
- with ops.control_dependencies([increment]):
- return i + i
-
- result = control_flow_ops.while_loop(cond=lambda i: i < 1,
- body=body_fn, loop_vars=[1])
- result.eval()
- self.assertAllEqual(v.eval(), 1.0)
-
- def testWhileExternalControlDependenciesNoInput(self):
- with self.test_session():
- v = variables.Variable(0.0)
- v.initializer.run()
- increment = v.assign_add(1.0)
-
- def body_fn(unused_i):
- with ops.control_dependencies([increment]):
- return constant_op.constant(5, name="five")
-
- result = control_flow_ops.while_loop(cond=lambda i: i < 5,
- body=body_fn, loop_vars=[0])
- result.eval()
- self.assertAllEqual(v.eval(), 1.0)
-
def testWhileWithRefs_1(self):
with self.test_session() as sess:
x = variables.Variable(0)._ref() # pylint: disable=protected-access
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index 87ff0aba0a..bcd187d821 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -1631,13 +1631,10 @@ class ControlFlowContext(object):
ctxt = util.GetOutputContext(x)
if ctxt is not None and ctxt.GetWhileContext() == while_ctxt:
internal_control_inputs.append(x)
- external_control_inputs = []
if len(internal_control_inputs) != len(op.control_inputs):
- external_control_inputs = list(set(op.control_inputs)
- - set(internal_control_inputs))
op._remove_all_control_inputs()
op._add_control_inputs(internal_control_inputs)
- return internal_control_inputs, external_control_inputs
+ return internal_control_inputs
# pylint: enable=protected-access
@@ -2435,12 +2432,14 @@ class WhileContext(ControlFlowContext):
def _AddOpInternal(self, op):
"""Add `op` to the current context.
- We move any external control dependencies of the op to the loop pivot, to
- ensure they get executed.
+ In the case that op has only external data inputs, we remove all of its
+ external control inputs so all its inputs are in the same while loop
+ context. This is valid because op now has an Enter input that has all
+ the right control dependency.
"""
if not op.inputs:
# Remove any external control dependency on this op
- control_inputs, external_inputs = self._RemoveExternalControlEdges(op)
+ control_inputs = self._RemoveExternalControlEdges(op)
# Add a control edge from the control pivot to this op.
if not control_inputs:
# pylint: disable=protected-access
@@ -2453,20 +2452,14 @@ class WhileContext(ControlFlowContext):
x = op.inputs[index]
real_x = self.AddValue(x)
if real_x != x:
- op._update_input(index, real_x) # pylint: disable=protected-access
- # Remove any external control dependency on this op and move then to an
- # Enter node.
- _, external_inputs = self._RemoveExternalControlEdges(op)
+ op._update_input(index, real_x)
+ # Remove any external control dependency on this op.
+ self._RemoveExternalControlEdges(op)
# Add a control dependency to prevent loop invariants from
# enabling ops that should not be executed.
self._MaybeAddControlDependency(op)
for x in op.outputs:
self._values.add(x.name)
- if external_inputs:
- # Make the pivot depend on external control inputs
- pred = self._pivot_for_pred.op.inputs[0]
- assert util.IsLoopEnter(pred.op)
- pred.op._add_control_inputs(external_inputs) # pylint: disable=protected-access
if self._outer_context or not util.IsLoopExit(op):
op.graph.prevent_fetching(op)
for x in op.outputs: