diff options
author | Alexandre Passos <apassos@google.com> | 2018-10-05 16:32:30 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-05 16:40:27 -0700 |
commit | 1daaf0fabee1c59af00e14f358d08ac9f5390b9f (patch) | |
tree | b89043c3399e12982ab99c216dada58a8aedcc5d /tensorflow/python | |
parent | 12443341c1cf1c96fa187ca08dee2f2a9b9f618b (diff) |
Orders non-resource-affecting stateful ops in defuns.
PiperOrigin-RevId: 215985679
Diffstat (limited to 'tensorflow/python')
-rw-r--r-- | tensorflow/python/eager/function.py | 7 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/logging_ops_test.py | 13 |
2 files changed, 20 insertions, 0 deletions
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index 2750461fb2..f06148b5d2 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -1906,8 +1906,10 @@ class AutomaticControlDependencies(object): last_op_using_resource_tensor[inp] = op ops_which_must_run = set([op]) continue + found_resource = False for inp in op.inputs: if inp.dtype == dtypes_module.resource: + found_resource = True # Deal with switches, finally. if inp.op.type == "Switch": self._process_switch(inp.op, ops_which_must_run, @@ -1922,6 +1924,11 @@ class AutomaticControlDependencies(object): if inp in merge_for_resource: merge_for_resource[inp]._add_control_input(op) # pylint: disable=protected-access last_op_using_resource_tensor[inp] = op + if (op.op_def.is_stateful and not found_resource + and op._control_flow_context is None): # pylint: disable=protected-access + if None in last_op_using_resource_tensor: + op._add_control_input(last_op_using_resource_tensor[None]) # pylint: disable=protected-access + last_op_using_resource_tensor[None] = op control_inputs = [c for c in control_inputs if c._control_flow_context is op._control_flow_context] # pylint: disable=protected-access op._add_control_inputs(control_inputs) # pylint: disable=protected-access diff --git a/tensorflow/python/kernel_tests/logging_ops_test.py b/tensorflow/python/kernel_tests/logging_ops_test.py index 4beddd00bb..2f19ecc0e6 100644 --- a/tensorflow/python/kernel_tests/logging_ops_test.py +++ b/tensorflow/python/kernel_tests/logging_ops_test.py @@ -306,6 +306,19 @@ class PrintV2Test(test.TestCase): logging_ops.print_v2(tensor) self.assertTrue((expected + "\n") in printed.contents()) + def testPrintsOrderedInDefun(self): + with context.eager_mode(): + + @function.defun + def prints(): + logging_ops.print_v2("A") + logging_ops.print_v2("B") + logging_ops.print_v2("C") + + with self.captureWritesToStream(sys.stderr) as printed: + prints() + self.assertTrue(("A\nB\nC\n") in printed.contents()) + @test_util.run_in_graph_and_eager_modes() def testPrintInDefunWithoutExplicitEvalOfPrint(self): @function.defun |