diff options
 tensorflow/python/BUILD                 |   1 +
 tensorflow/python/eager/context.py      |  40 ++++++
 tensorflow/python/framework/ops.py      |  61 +++++++
 tensorflow/python/framework/ops_test.py | 191 ++++++++++++++++++
 4 files changed, 293 insertions(+), 0 deletions(-)
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index e77fba4a4c..45383eda99 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -1108,6 +1108,7 @@ py_test( ":variables", "//tensorflow/core:protos_all_py", "//tensorflow/python/eager:context", + "//tensorflow/python/eager:function", ], ) diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py index 415416cfae..8aec242f1d 100644 --- a/tensorflow/python/eager/context.py +++ b/tensorflow/python/eager/context.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import collections import contextlib import copy import random @@ -62,6 +63,41 @@ class _EagerContext(threading.local): self.scalar_cache = {} +ContextStackEntry = collections.namedtuple( + "ContextStackEntry", ["is_building_function", "enter_context_fn"]) + + +class ContextStack(threading.local): + """A thread-local stack of context switches.""" + + def __init__(self): + super(ContextStack, self).__init__() + self.stack = [] + + def push(self, is_building_function, enter_context_fn): + """Push metadata about a context switch onto the stack. + + A context switch can take one of two forms: installing a graph as the + default graph, or entering the eager context. + + Args: + is_building_function: (bool.) Whether the context is building a function. + enter_context_fn: (function.) A callable that executes the context switch. + For example, `graph.as_default` or `eager_mode`. + """ + + self.stack.append( + ContextStackEntry(is_building_function, enter_context_fn)) + + def pop(self): + """Pop the stack.""" + + self.stack.pop() + + +context_stack = ContextStack() + + # TODO(agarwal): rename to EagerContext / EagerRuntime ? # TODO(agarwal): consider keeping the corresponding Graph here. 
class Context(object): @@ -183,10 +219,14 @@ class Context(object): ctx = self._eager_context old_mode = ctx.mode ctx.mode = mode + if mode == EAGER_MODE: + context_stack.push(False, eager_mode) try: yield finally: ctx.mode = old_mode + if mode == EAGER_MODE: + context_stack.pop() def in_graph_mode(self): """Returns True if current thread is in GRAPH mode.""" diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py index 836f09fba8..947a9e49cc 100644 --- a/tensorflow/python/framework/ops.py +++ b/tensorflow/python/framework/ops.py @@ -4850,10 +4850,71 @@ class _DefaultGraphStack(_DefaultStack): # pylint: disable=protected-access super(_DefaultGraphStack, self).reset() self._global_default_graph = None + @tf_contextlib.contextmanager + def get_controller(self, default): + try: + context.context_stack.push(default.building_function, default.as_default) + with super(_DefaultGraphStack, self).get_controller(default) as g: + yield g + finally: + context.context_stack.pop() + _default_graph_stack = _DefaultGraphStack() +# pylint: disable=g-doc-return-or-yield,line-too-long +@tf_contextlib.contextmanager +def init_scope(): + """A context manager that lifts ops out of control-flow scopes and function-building graphs. + + There is often a need to lift variable initialization ops out of control-flow + scopes, function-building graphs, and gradient tapes. Entering an + `init_scope` is a mechanism for satisfying these desiderata. In particular, + entering an `init_scope` has three effects: + + (1) All control dependencies are cleared the moment the scope is entered; + this is equivalent to entering the context manager returned from + `control_dependencies(None)`, which has the side-effect of exiting + control-flow scopes like `tf.cond` and `tf.while_loop`. + + (2) All operations that are created while the scope is active are lifted + into the lowest context on the `context_stack` that is not building a + graph function. 
Here, a context is defined as either a graph or an eager + context. Every context switch, i.e., every installation of a graph as + the default graph and every switch into eager mode, is logged in a + thread-local stack called the `context_stack`; the log entry for a + context switch is popped from the stack when the context is exited. + Entering an `init_scope` is equivalent to crawling up the + `context_stack`, finding the first context that is not building a graph + function, and entering it. + + (3) The gradient tape is paused while the scope is active. + """ +# pylint: enable=g-doc-return-or-yield,line-too-long + + outer_context = None + if not context.context_stack.stack: + # This is correct because of an invariant: the stack is + # empty if and only if eager execution has not been enabled. + outer_context = get_default_graph().as_default + else: + for stack_entry in reversed(context.context_stack.stack): + if not stack_entry.is_building_function: + outer_context = stack_entry.enter_context_fn + break + + if outer_context is None: + raise AssertionError("All graphs are building functions, and no " + "eager context was previously active.") + + try: + with outer_context(), control_dependencies(None), tape.stop_recording(): + yield + finally: + pass + + def enable_eager_execution(config=None, device_policy=None): """Enables, for the rest of the lifetime of this program, eager execution. 
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py index f04f0cc56d..92d42c1807 100644 --- a/tensorflow/python/framework/ops_test.py +++ b/tensorflow/python/framework/ops_test.py @@ -26,6 +26,7 @@ from tensorflow.core.framework import types_pb2 from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session from tensorflow.python.eager import context +from tensorflow.python.eager import function as eager_function from tensorflow.python.framework import common_shapes from tensorflow.python.framework import constant_op from tensorflow.python.framework import device as pydev @@ -43,6 +44,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import resources from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables @@ -1868,6 +1870,195 @@ class OpScopeTest(test_util.TensorFlowTestCase): self._testGraphElements([a, variable, b]) +class InitScopeTest(test_util.TensorFlowTestCase): + + def testClearsControlDependencies(self): + g = ops.Graph() + a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) + a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) + a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) + a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) + + with g.as_default(): + with g.control_dependencies([a_1]): + with g.control_dependencies([a_2]): + with ops.init_scope(): + with g.control_dependencies([a_3]): + with g.control_dependencies([a_4]): + # deps [a_3, a_4] + b_3_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) + # deps = [a_3] + b_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) + # deps back to None + b_none = _apply_op(g, "FloatOutput", [], [dtypes.float32]) + # deps back to [a_1, a_2] + b_1_2 = 
_apply_op(g, "FloatOutput", [], [dtypes.float32]) + # deps back to [a_1] + b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) + with ops.init_scope(): + # deps are None again + b_none2 = _apply_op(g, "FloatOutput", [], [dtypes.float32]) + + self.assertItemsEqual([a_3.op, a_4.op], b_3_4.op.control_inputs) + self.assertItemsEqual([a_3.op], b_3.op.control_inputs) + self.assertItemsEqual([], b_none.op.control_inputs) + self.assertItemsEqual([a_1.op, a_2.op], b_1_2.op.control_inputs) + self.assertItemsEqual([a_1.op], b_1.op.control_inputs) + self.assertItemsEqual([], b_none2.op.control_inputs) + + def testLiftsOpsFromFunctions(self): + g0 = ops.Graph() + g1 = ops.Graph() + g1._building_function = True # pylint: disable=protected-access + g2 = ops.Graph() + g2._building_function = True # pylint: disable=protected-access + + with g0.as_default(): + with g1.as_default(): + with g2.as_default(): + with ops.init_scope(): + _ = constant_op.constant(1.0) + + self.assertEqual(len(g2.get_operations()), 0) + self.assertEqual(len(g1.get_operations()), 0) + self.assertEqual(len(g0.get_operations()), 1) + + def testComposes(self): + g0 = ops.Graph() + g1 = ops.Graph() + g1._building_function = True # pylint: disable=protected-access + g2 = ops.Graph() + g2._building_function = True # pylint: disable=protected-access + g3 = ops.Graph() + g3._building_function = False # pylint: disable=protected-access + + with g0.as_default(): + with g1.as_default(): + with ops.init_scope(): + # This op should be lifted into g0. + _ = constant_op.constant(1.0) + self.assertIs(g0, ops.get_default_graph()) + self.assertEqual(len(g2.get_operations()), 0) + self.assertEqual(len(g1.get_operations()), 0) + self.assertEqual(len(g0.get_operations()), 1) + with g2.as_default(): + with ops.init_scope(): + # This op should be lifted into g0. 
+ _ = constant_op.constant(1.0) + self.assertIs(g0, ops.get_default_graph()) + with g3.as_default(): + with ops.init_scope(): + # This op should be lifted into g3, because g3 is not building a + # function. + _ = constant_op.constant(1.0) + self.assertIs(g3, ops.get_default_graph()) + + self.assertEqual(len(g3.get_operations()), 1) + self.assertEqual(len(g2.get_operations()), 0) + self.assertEqual(len(g1.get_operations()), 0) + self.assertEqual(len(g0.get_operations()), 2) + + def testEscapesToEagerContext(self): + g = ops.Graph() + g._building_function = True # pylint: disable=protected-access + with context.eager_mode(): + with context.graph_mode(): + with g.as_default(): + with ops.init_scope(): + # Because g is building a function, init_scope should + # escape out to the eager context. + self.assertTrue(context.in_eager_mode()) + # g should be reinstated as the default graph, and the + # graph context should be re-entered. + self.assertIs(g, ops.get_default_graph()) + self.assertTrue(context.in_graph_mode()) + + def testAllGraphsBuildingFunctionsRaisesError(self): + g = ops.Graph() + g._building_function = True # pylint: disable=protected-access + with g.as_default(): + with self.assertRaises(AssertionError): + with ops.init_scope(): + pass + + def testStaysInEagerWhenOnlyEagerContextActive(self): + with context.eager_mode(): + with ops.init_scope(): + self.assertTrue(context.eager_mode()) + self.assertTrue(context.eager_mode()) + + def testEscapesDefunWhenInEagerMode(self): + + def function_with_variables(): + with ops.init_scope(): + v = resource_variable_ops.ResourceVariable(3) + return v.assign_add(1) + + with context.eager_mode(): + # Each invocation of function_with_variables recreates a variable. 
+ self.assertEqual(4, int(function_with_variables())) + self.assertEqual(4, int(function_with_variables())) + + compiled = eager_function.defun(function_with_variables) + # The init_scope in function_with_variables lifts the variable out + # of the graph function constructed by defun; hence, + # compiled now appears to be stateful. + self.assertEqual(4, int(compiled())) + self.assertEqual(5, int(compiled())) + + def testEscapesDefunWhenInGraphMode(self): + def function_with_variables(name): + with ops.init_scope(): + _ = variable_scope.get_variable(name, shape=(1,)) + + g = ops.Graph() + with g.as_default(): + with self.test_session(): + # First ensure that graphs that are not building functions are + # not escaped. + function_with_variables("foo") + with self.assertRaisesRegexp(ValueError, + r"Variable foo already exists.*"): + # This will fail because reuse is not set to True. + function_with_variables("foo") + + compiled = eager_function.defun(function_with_variables) + compiled("bar") + self.assertEqual( + len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)), 2) + + # The second call to `compiled` should not create variables: the + # init_scope has lifted the variable creation code out of the defun. + compiled("bar") + self.assertEqual( + len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)), 2) + + def testEscapesNestedDefun(self): + + def inner_function(): + with ops.init_scope(): + v = resource_variable_ops.ResourceVariable(1) + return v.assign_add(2) + + def outer_function(inner=None): + with ops.init_scope(): + v0 = resource_variable_ops.ResourceVariable(0) + return v0.assign_add(1) + inner() + + with context.eager_mode(): + # Each invocation of outer_function recreates variables. 
+ self.assertEqual(4, int(outer_function(inner=inner_function))) + self.assertEqual(4, int(outer_function(inner=inner_function))) + + compiled_inner = eager_function.defun(inner_function) + compiled_outer = eager_function.defun(outer_function) + # The init_scope lifts variables out of the graph functions + # constructed by defun; hence, compiled_outer should now appear to be + # stateful. + self.assertEqual(4, int(compiled_outer(inner=compiled_inner))) + self.assertEqual(7, int(compiled_outer(inner=compiled_inner))) + + @test_util.with_c_api class GraphTest(test_util.TensorFlowTestCase): |