aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python
diff options
context:
space:
mode:
authorGravatar Alexandre Passos <apassos@google.com>2018-10-05 16:32:30 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-05 16:40:27 -0700
commit1daaf0fabee1c59af00e14f358d08ac9f5390b9f (patch)
treeb89043c3399e12982ab99c216dada58a8aedcc5d /tensorflow/python
parent12443341c1cf1c96fa187ca08dee2f2a9b9f618b (diff)
Orders non-resource-affecting stateful ops in defuns.
PiperOrigin-RevId: 215985679
Diffstat (limited to 'tensorflow/python')
-rw-r--r--tensorflow/python/eager/function.py7
-rw-r--r--tensorflow/python/kernel_tests/logging_ops_test.py13
2 files changed, 20 insertions, 0 deletions
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 2750461fb2..f06148b5d2 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -1906,8 +1906,10 @@ class AutomaticControlDependencies(object):
last_op_using_resource_tensor[inp] = op
ops_which_must_run = set([op])
continue
+ found_resource = False
for inp in op.inputs:
if inp.dtype == dtypes_module.resource:
+ found_resource = True
# Deal with switches, finally.
if inp.op.type == "Switch":
self._process_switch(inp.op, ops_which_must_run,
@@ -1922,6 +1924,11 @@ class AutomaticControlDependencies(object):
if inp in merge_for_resource:
merge_for_resource[inp]._add_control_input(op) # pylint: disable=protected-access
last_op_using_resource_tensor[inp] = op
+ if (op.op_def.is_stateful and not found_resource
+ and op._control_flow_context is None): # pylint: disable=protected-access
+ if None in last_op_using_resource_tensor:
+ op._add_control_input(last_op_using_resource_tensor[None]) # pylint: disable=protected-access
+ last_op_using_resource_tensor[None] = op
control_inputs = [c for c in control_inputs
if c._control_flow_context is op._control_flow_context] # pylint: disable=protected-access
op._add_control_inputs(control_inputs) # pylint: disable=protected-access
diff --git a/tensorflow/python/kernel_tests/logging_ops_test.py b/tensorflow/python/kernel_tests/logging_ops_test.py
index 4beddd00bb..2f19ecc0e6 100644
--- a/tensorflow/python/kernel_tests/logging_ops_test.py
+++ b/tensorflow/python/kernel_tests/logging_ops_test.py
@@ -306,6 +306,19 @@ class PrintV2Test(test.TestCase):
logging_ops.print_v2(tensor)
self.assertTrue((expected + "\n") in printed.contents())
+ def testPrintsOrderedInDefun(self):
+ with context.eager_mode():
+
+ @function.defun
+ def prints():
+ logging_ops.print_v2("A")
+ logging_ops.print_v2("B")
+ logging_ops.print_v2("C")
+
+ with self.captureWritesToStream(sys.stderr) as printed:
+ prints()
+ self.assertTrue(("A\nB\nC\n") in printed.contents())
+
@test_util.run_in_graph_and_eager_modes()
def testPrintInDefunWithoutExplicitEvalOfPrint(self):
@function.defun