aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/eager
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2018-10-05 16:32:30 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-05 16:40:27 -0700
commit1daaf0fabee1c59af00e14f358d08ac9f5390b9f (patch)
treeb89043c3399e12982ab99c216dada58a8aedcc5d /tensorflow/python/eager
parent12443341c1cf1c96fa187ca08dee2f2a9b9f618b (diff)
Orders non-resource-affecting stateful ops in defuns.
PiperOrigin-RevId: 215985679
Diffstat (limited to 'tensorflow/python/eager')
-rw-r--r--tensorflow/python/eager/function.py7
1 files changed, 7 insertions, 0 deletions
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 2750461fb2..f06148b5d2 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -1906,8 +1906,10 @@ class AutomaticControlDependencies(object):
last_op_using_resource_tensor[inp] = op
ops_which_must_run = set([op])
continue
+ found_resource = False
for inp in op.inputs:
if inp.dtype == dtypes_module.resource:
+ found_resource = True
# Deal with switches, finally.
if inp.op.type == "Switch":
self._process_switch(inp.op, ops_which_must_run,
@@ -1922,6 +1924,11 @@ class AutomaticControlDependencies(object):
if inp in merge_for_resource:
merge_for_resource[inp]._add_control_input(op) # pylint: disable=protected-access
last_op_using_resource_tensor[inp] = op
+ if (op.op_def.is_stateful and not found_resource
+ and op._control_flow_context is None): # pylint: disable=protected-access
+ if None in last_op_using_resource_tensor:
+ op._add_control_input(last_op_using_resource_tensor[None]) # pylint: disable=protected-access
+ last_op_using_resource_tensor[None] = op
control_inputs = [c for c in control_inputs
if c._control_flow_context is op._control_flow_context] # pylint: disable=protected-access
op._add_control_inputs(control_inputs) # pylint: disable=protected-access