author     Akshay Agrawal <akshayka@google.com>             2017-12-14 15:05:12 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>  2017-12-14 15:08:16 -0800
commit     dea51b668ba9858c914a7fcb0fb6fdc3df132d72 (patch)
tree       f51f7bd93f2a13880d2a87abf51708f7470a3afc
parent     e92f85e78573fbf88accdf4b76535b0c70e7f674 (diff)
Add `init_scope`, a scope for wrapping variable creation and initialization.
There is often a need to lift variable initialization ops out of control-flow
contexts, graphs that are building functions, and gradient tapes. Entering an
`init_scope` is a mechanism for satisfying these desiderata. In particular,
entering an `init_scope` has three effects:

(1) All control dependencies are cleared the moment the scope is entered; this
    is equivalent to entering the context manager returned from
    `control_dependencies(None)`, which is how we exit control-flow contexts
    like `tf.while_loop` and `tf.cond`.

(2) All operations that are created while the scope is active are lifted into
    the lowest context on the `context_stack` that is not building a graph
    function. Here, a context is defined as either a graph or an eager
    context. Every context switch, i.e., every installation of a graph as the
    default graph and every switch into eager mode, is logged in a
    thread-local stack called the `context_stack`; the log entry for a context
    switch is popped from the stack when the context is exited. Entering an
    `init_scope`, with respect to (2), is equivalent to crawling up the
    `context_stack`, finding the first context that is not building a graph
    function, and entering it.

(3) The gradient tape is paused while the scope is active.

PiperOrigin-RevId: 179104270
-rw-r--r--  tensorflow/python/BUILD                  |   1
-rw-r--r--  tensorflow/python/eager/context.py       |  40
-rw-r--r--  tensorflow/python/framework/ops.py       |  61
-rw-r--r--  tensorflow/python/framework/ops_test.py  | 191
4 files changed, 293 insertions(+), 0 deletions(-)
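For illustration, a minimal usage sketch in eager mode, mirroring
`testEscapesDefunWhenInEagerMode` from the tests below (assumes a TensorFlow
build that includes this commit):

  from tensorflow.python.eager import context
  from tensorflow.python.eager import function as eager_function
  from tensorflow.python.framework import ops
  from tensorflow.python.ops import resource_variable_ops

  def function_with_variables():
    # Without init_scope, the variable would be created inside the
    # function-building graph constructed by defun and recreated on every
    # call; init_scope lifts its creation into the outer eager context.
    with ops.init_scope():
      v = resource_variable_ops.ResourceVariable(3)
    return v.assign_add(1)

  with context.eager_mode():
    compiled = eager_function.defun(function_with_variables)
    print(int(compiled()))  # 4: the variable is created once, outside the defun
    print(int(compiled()))  # 5: the lifted variable persists across calls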
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index e77fba4a4c..45383eda99 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -1108,6 +1108,7 @@ py_test(
":variables",
"//tensorflow/core:protos_all_py",
"//tensorflow/python/eager:context",
+ "//tensorflow/python/eager:function",
],
)
diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py
index 415416cfae..8aec242f1d 100644
--- a/tensorflow/python/eager/context.py
+++ b/tensorflow/python/eager/context.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import collections
import contextlib
import copy
import random
@@ -62,6 +63,41 @@ class _EagerContext(threading.local):
self.scalar_cache = {}
+ContextStackEntry = collections.namedtuple(
+ "ContextStackEntry", ["is_building_function", "enter_context_fn"])
+
+
+class ContextStack(threading.local):
+ """A thread-local stack of context switches."""
+
+ def __init__(self):
+ super(ContextStack, self).__init__()
+ self.stack = []
+
+ def push(self, is_building_function, enter_context_fn):
+ """Push metadata about a context switch onto the stack.
+
+ A context switch can take one of two forms: installing a graph as the
+ default graph, or entering the eager context.
+
+ Args:
+ is_building_function: (bool.) Whether the context is building a function.
+ enter_context_fn: (function.) A callable that executes the context switch.
+ For example, `graph.as_default` or `eager_mode`.
+ """
+
+ self.stack.append(
+ ContextStackEntry(is_building_function, enter_context_fn))
+
+ def pop(self):
+ """Pop the stack."""
+
+ self.stack.pop()
+
+
+context_stack = ContextStack()
+
+
# TODO(agarwal): rename to EagerContext / EagerRuntime ?
# TODO(agarwal): consider keeping the corresponding Graph here.
class Context(object):
@@ -183,10 +219,14 @@ class Context(object):
ctx = self._eager_context
old_mode = ctx.mode
ctx.mode = mode
+ if mode == EAGER_MODE:
+ context_stack.push(False, eager_mode)
try:
yield
finally:
ctx.mode = old_mode
+ if mode == EAGER_MODE:
+ context_stack.pop()
def in_graph_mode(self):
"""Returns True if current thread is in GRAPH mode."""
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 836f09fba8..947a9e49cc 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -4850,10 +4850,71 @@ class _DefaultGraphStack(_DefaultStack): # pylint: disable=protected-access
super(_DefaultGraphStack, self).reset()
self._global_default_graph = None
+ @tf_contextlib.contextmanager
+ def get_controller(self, default):
+ try:
+ context.context_stack.push(default.building_function, default.as_default)
+ with super(_DefaultGraphStack, self).get_controller(default) as g:
+ yield g
+ finally:
+ context.context_stack.pop()
+
_default_graph_stack = _DefaultGraphStack()
+# pylint: disable=g-doc-return-or-yield,line-too-long
+@tf_contextlib.contextmanager
+def init_scope():
+ """A context manager that lifts ops out of control-flow scopes and function-building graphs.
+
+ There is often a need to lift variable initialization ops out of control-flow
+ scopes, function-building graphs, and gradient tapes. Entering an
+ `init_scope` is a mechanism for satisfying these desiderata. In particular,
+ entering an `init_scope` has three effects:
+
+ (1) All control dependencies are cleared the moment the scope is entered;
+ this is equivalent to entering the context manager returned from
+ `control_dependencies(None)`, which has the side-effect of exiting
+ control-flow scopes like `tf.cond` and `tf.while_loop`.
+
+ (2) All operations that are created while the scope is active are lifted
+ into the lowest context on the `context_stack` that is not building a
+ graph function. Here, a context is defined as either a graph or an eager
+ context. Every context switch, i.e., every installation of a graph as
+ the default graph and every switch into eager mode, is logged in a
+ thread-local stack called the `context_stack`; the log entry for a
+ context switch is popped from the stack when the context is exited.
+ Entering an `init_scope` is equivalent to crawling up the
+ `context_stack`, finding the first context that is not building a graph
+ function, and entering it.
+
+ (3) The gradient tape is paused while the scope is active.
+ """
+# pylint: enable=g-doc-return-or-yield,line-too-long
+
+ outer_context = None
+ if not context.context_stack.stack:
+ # This is correct because of an invariant: the stack is
+ # empty if and only if eager execution has not been enabled.
+ outer_context = get_default_graph().as_default
+ else:
+ for stack_entry in reversed(context.context_stack.stack):
+ if not stack_entry.is_building_function:
+ outer_context = stack_entry.enter_context_fn
+ break
+
+ if outer_context is None:
+ raise AssertionError("All graphs are building functions, and no "
+ "eager context was previously active.")
+
+ try:
+ with outer_context(), control_dependencies(None), tape.stop_recording():
+ yield
+ finally:
+ pass
+
+
def enable_eager_execution(config=None, device_policy=None):
"""Enables, for the rest of the lifetime of this program, eager execution.
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
index f04f0cc56d..92d42c1807 100644
--- a/tensorflow/python/framework/ops_test.py
+++ b/tensorflow/python/framework/ops_test.py
@@ -26,6 +26,7 @@ from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
from tensorflow.python.eager import context
+from tensorflow.python.eager import function as eager_function
from tensorflow.python.framework import common_shapes
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device as pydev
@@ -43,6 +44,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import resources
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
@@ -1868,6 +1870,195 @@ class OpScopeTest(test_util.TensorFlowTestCase):
self._testGraphElements([a, variable, b])
+class InitScopeTest(test_util.TensorFlowTestCase):
+
+ def testClearsControlDependencies(self):
+ g = ops.Graph()
+ a_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
+ a_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
+ a_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
+ a_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
+
+ with g.as_default():
+ with g.control_dependencies([a_1]):
+ with g.control_dependencies([a_2]):
+ with ops.init_scope():
+ with g.control_dependencies([a_3]):
+ with g.control_dependencies([a_4]):
+ # deps = [a_3, a_4]
+ b_3_4 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
+ # deps = [a_3]
+ b_3 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
+ # deps back to None
+ b_none = _apply_op(g, "FloatOutput", [], [dtypes.float32])
+ # deps back to [a_1, a_2]
+ b_1_2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
+ # deps back to [a_1]
+ b_1 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
+ with ops.init_scope():
+ # deps are None again
+ b_none2 = _apply_op(g, "FloatOutput", [], [dtypes.float32])
+
+ self.assertItemsEqual([a_3.op, a_4.op], b_3_4.op.control_inputs)
+ self.assertItemsEqual([a_3.op], b_3.op.control_inputs)
+ self.assertItemsEqual([], b_none.op.control_inputs)
+ self.assertItemsEqual([a_1.op, a_2.op], b_1_2.op.control_inputs)
+ self.assertItemsEqual([a_1.op], b_1.op.control_inputs)
+ self.assertItemsEqual([], b_none2.op.control_inputs)
+
+ def testLiftsOpsFromFunctions(self):
+ g0 = ops.Graph()
+ g1 = ops.Graph()
+ g1._building_function = True # pylint: disable=protected-access
+ g2 = ops.Graph()
+ g2._building_function = True # pylint: disable=protected-access
+
+ with g0.as_default():
+ with g1.as_default():
+ with g2.as_default():
+ with ops.init_scope():
+ _ = constant_op.constant(1.0)
+
+ self.assertEqual(len(g2.get_operations()), 0)
+ self.assertEqual(len(g1.get_operations()), 0)
+ self.assertEqual(len(g0.get_operations()), 1)
+
+ def testComposes(self):
+ g0 = ops.Graph()
+ g1 = ops.Graph()
+ g1._building_function = True # pylint: disable=protected-access
+ g2 = ops.Graph()
+ g2._building_function = True # pylint: disable=protected-access
+ g3 = ops.Graph()
+ g3._building_function = False # pylint: disable=protected-access
+
+ with g0.as_default():
+ with g1.as_default():
+ with ops.init_scope():
+ # This op should be lifted into g0.
+ _ = constant_op.constant(1.0)
+ self.assertIs(g0, ops.get_default_graph())
+ self.assertEqual(len(g2.get_operations()), 0)
+ self.assertEqual(len(g1.get_operations()), 0)
+ self.assertEqual(len(g0.get_operations()), 1)
+ with g2.as_default():
+ with ops.init_scope():
+ # This op should be lifted into g0.
+ _ = constant_op.constant(1.0)
+ self.assertIs(g0, ops.get_default_graph())
+ with g3.as_default():
+ with ops.init_scope():
+ # This op should be lifted into g3, because g3 is not building a
+ # function.
+ _ = constant_op.constant(1.0)
+ self.assertIs(g3, ops.get_default_graph())
+
+ self.assertEqual(len(g3.get_operations()), 1)
+ self.assertEqual(len(g2.get_operations()), 0)
+ self.assertEqual(len(g1.get_operations()), 0)
+ self.assertEqual(len(g0.get_operations()), 2)
+
+ def testEscapesToEagerContext(self):
+ g = ops.Graph()
+ g._building_function = True # pylint: disable=protected-access
+ with context.eager_mode():
+ with context.graph_mode():
+ with g.as_default():
+ with ops.init_scope():
+ # Because g is building a function, init_scope should
+ # escape out to the eager context.
+ self.assertTrue(context.in_eager_mode())
+ # g should be reinstated as the default graph, and the
+ # graph context should be re-entered.
+ self.assertIs(g, ops.get_default_graph())
+ self.assertTrue(context.in_graph_mode())
+
+ def testAllGraphsBuildingFunctionsRaisesError(self):
+ g = ops.Graph()
+ g._building_function = True # pylint: disable=protected-access
+ with g.as_default():
+ with self.assertRaises(AssertionError):
+ with ops.init_scope():
+ pass
+
+ def testStaysInEagerWhenOnlyEagerContextActive(self):
+ with context.eager_mode():
+ with ops.init_scope():
+ self.assertTrue(context.in_eager_mode())
+ self.assertTrue(context.in_eager_mode())
+
+ def testEscapesDefunWhenInEagerMode(self):
+
+ def function_with_variables():
+ with ops.init_scope():
+ v = resource_variable_ops.ResourceVariable(3)
+ return v.assign_add(1)
+
+ with context.eager_mode():
+ # Each invocation of function_with_variables recreates a variable.
+ self.assertEqual(4, int(function_with_variables()))
+ self.assertEqual(4, int(function_with_variables()))
+
+ compiled = eager_function.defun(function_with_variables)
+ # The init_scope in function_with_variables lifts the variable out
+ # of the graph function constructed by defun; hence,
+ # compiled now appears to be stateful.
+ self.assertEqual(4, int(compiled()))
+ self.assertEqual(5, int(compiled()))
+
+ def testEscapesDefunWhenInGraphMode(self):
+ def function_with_variables(name):
+ with ops.init_scope():
+ _ = variable_scope.get_variable(name, shape=(1,))
+
+ g = ops.Graph()
+ with g.as_default():
+ with self.test_session():
+ # First ensure that graphs that are not building functions are
+ # not escaped.
+ function_with_variables("foo")
+ with self.assertRaisesRegexp(ValueError,
+ r"Variable foo already exists.*"):
+ # This will fail because reuse is not set to True.
+ function_with_variables("foo")
+
+ compiled = eager_function.defun(function_with_variables)
+ compiled("bar")
+ self.assertEqual(
+ len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)), 2)
+
+ # The second call to `compiled` should not create variables: the
+ # init_scope has lifted the variable creation code out of the defun.
+ compiled("bar")
+ self.assertEqual(
+ len(ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)), 2)
+
+ def testEscapesNestedDefun(self):
+
+ def inner_function():
+ with ops.init_scope():
+ v = resource_variable_ops.ResourceVariable(1)
+ return v.assign_add(2)
+
+ def outer_function(inner=None):
+ with ops.init_scope():
+ v0 = resource_variable_ops.ResourceVariable(0)
+ return v0.assign_add(1) + inner()
+
+ with context.eager_mode():
+ # Each invocation of outer_function recreates variables.
+ self.assertEqual(4, int(outer_function(inner=inner_function)))
+ self.assertEqual(4, int(outer_function(inner=inner_function)))
+
+ compiled_inner = eager_function.defun(inner_function)
+ compiled_outer = eager_function.defun(outer_function)
+ # The init_scope lifts variables out of the graph functions
+ # constructed by defun; hence, compiled_outer should now appear to be
+ # stateful.
+ self.assertEqual(4, int(compiled_outer(inner=compiled_inner)))
+ self.assertEqual(7, int(compiled_outer(inner=compiled_inner)))
+
+
@test_util.with_c_api
class GraphTest(test_util.TensorFlowTestCase):