aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops
diff options
context:
space:
mode:
authorGravatar Akshay Agrawal <akshayka@google.com>2018-01-10 13:08:52 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-10 13:14:34 -0800
commit47e06b7e53963cc45ee236cf86cc089185f47d32 (patch)
tree46c29753e626ac4b5fb19e2193027eb80f7a9020 /tensorflow/python/ops
parent3dc58f798660b98a0198b33edefb6f9f2aa7d827 (diff)
Support nesting EagerTemplate objects.
* Nesting is implemented by sharing a single EagerVariableStore among a top-level EagerTemplate and all children EagerTemplate objects that are nested underneath it. Variables added to an EagerTemplate object are also added to all EagerTemplate objects under which it is nested. * This change also simplifies the implementation of __call__ for both Template and EagerTemplate. PiperOrigin-RevId: 181506600
Diffstat (limited to 'tensorflow/python/ops')
-rw-r--r--tensorflow/python/ops/template.py150
-rw-r--r--tensorflow/python/ops/variable_scope.py11
2 files changed, 109 insertions, 52 deletions
diff --git a/tensorflow/python/ops/template.py b/tensorflow/python/ops/template.py
index e0cf1bff4f..169c4b5194 100644
--- a/tensorflow/python/ops/template.py
+++ b/tensorflow/python/ops/template.py
@@ -25,6 +25,7 @@ from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.deprecation import deprecated
@@ -259,20 +260,13 @@ class Template(object):
def __call__(self, *args, **kwargs):
if self._variable_scope:
- if self._variables_created:
- # This is not the first visit to __call__, so variables have already
- # been created, and we want to reuse them.
- with variable_scope.variable_scope(self._variable_scope, reuse=True):
- return self._call_func(args, kwargs, check_for_new_variables=True)
- else:
- # This is the first visit to __call__, but the scope has already been
- # created in the constructor. Set _variables_created after the inner
- # function is successfully called so that subsequent calls take the if
- # branch above.
- with variable_scope.variable_scope(self._variable_scope):
- result = self._call_func(args, kwargs, check_for_new_variables=False)
- self._variables_created = True
- return result
+ # Only reuse variables if they were already created.
+ with variable_scope.variable_scope(
+ self._variable_scope, reuse=self._variables_created):
+ result = self._call_func(
+ args, kwargs, check_for_new_variables=self._variables_created)
+ self._variables_created = True
+ return result
else:
# The scope was not created at construction time, so create it here.
# Subsequent calls should reuse variables.
@@ -372,6 +366,61 @@ class Template(object):
return self._variable_scope
+class _EagerTemplateVariableStore(object):
+ """Wrapper around EagerVariableStore to support nesting EagerTemplates.
+ """
+
+ def __init__(self, variable_scope_name):
+ self._variable_scope_name = variable_scope_name
+ default = variable_scope._get_default_variable_store() # pylint: disable=protected-access
+ if default._store_eager_variables: # pylint: disable=protected-access
+ self._eager_variable_store = variable_scope.EagerVariableStore(default)
+ else:
+ self._eager_variable_store = variable_scope.EagerVariableStore()
+
+ def set_variable_scope_name(self, variable_scope_name):
+ self._variable_scope_name = variable_scope_name
+
+ @tf_contextlib.contextmanager
+ def as_default(self):
+ try:
+ with self._eager_variable_store.as_default():
+ yield
+ finally:
+ # Each _EagerTemplateVariableStore object lives underneath a variable
+ # scope (see EagerTemplate.__call__). This variable scope's subscopes are
+ # closed when the EagerTemplate object returns from __call__. For
+ # top-level _EagerTemplateVariableStore objects, the variable store to
+ # which the variable scope is attached is different from the
+ # EagerVariableStore; as such it is necessary to close its subscopes
+ # here as well.
+ if self._variable_scope_name is None:
+ raise RuntimeError("A variable scope must be set before an "
+ "_EagerTemplateVariableStore object exits.")
+ self._eager_variable_store._store.close_variable_subscopes( # pylint: disable=protected-access
+ self._variable_scope_name)
+
+ def _variables_in_scope(self, variable_list):
+ if self._variable_scope_name is None:
+ raise RuntimeError(
+ "A variable scope must be set before variables can be accessed.")
+ return [
+ v for v in variable_list
+ if v.name.startswith(self._variable_scope_name + "/")
+ ]
+
+ def variables(self):
+ return self._variables_in_scope(self._eager_variable_store.variables())
+
+ def trainable_variables(self):
+ return self._variables_in_scope(
+ self._eager_variable_store.trainable_variables())
+
+ def non_trainable_variables(self):
+ return self._variables_in_scope(
+ self._eager_variable_store.non_trainable_variables())
+
+
class EagerTemplate(Template):
"""Wrap a function to aid in variable sharing in Eager mode.
@@ -416,26 +465,26 @@ class EagerTemplate(Template):
"{} objects can only be used when eager execution is enabled, use "
"tf.Template for graph construction".
format(type(self)))
- if unique_name:
+ if unique_name is not None:
raise ValueError("unique_name cannot be used in eager mode.")
super(EagerTemplate, self).__init__(name, func, create_scope_now,
unique_name, custom_getter)
- # Create an eager variable store only if the current variable store cannot
- # store eager variables. This should allow for correct nesting.
- default_vstore = variable_scope._get_default_variable_store() # pylint: disable=protected-access
- if default_vstore._store_eager_variables: # pylint: disable=protected-access
- raise ValueError("Nested EagerTemaplates are not currently supported.")
+ if self._variable_scope is not None:
+ variable_scope_name = self._variable_scope.name
else:
- self._eager_variable_store = variable_scope.EagerVariableStore()
+ # Defer setting the variable scope name until the variable scope
+ # is created in __call__.
+ variable_scope_name = None
+ self._template_store = _EagerTemplateVariableStore(variable_scope_name)
def _call_func(self, args, kwargs, check_for_new_variables):
try:
- vars_at_start = self._eager_variable_store.variables()
- trainable_at_start = self._eager_variable_store.trainable_variables()
+ vars_at_start = self._template_store.variables()
+ trainable_at_start = self._template_store.trainable_variables()
result = self._func(*args, **kwargs)
if check_for_new_variables:
- trainable_variables = self._eager_variable_store.trainable_variables()
+ trainable_variables = self._template_store.trainable_variables()
# If a variable that we intend to train is created as a side effect
# of creating a template, then that is almost certainly an error.
if len(trainable_at_start) != len(trainable_variables):
@@ -448,7 +497,7 @@ class EagerTemplate(Template):
# Non-trainable tracking variables are a legitimate reason why a new
# variable would be created, but it is a relatively advanced use-case,
# so log it.
- variables = self._eager_variable_store.variables()
+ variables = self._template_store.variables()
if len(vars_at_start) != len(variables):
logging.info("New variables created when calling a template after "
"the first time, perhaps you used tf.Variable when you "
@@ -472,26 +521,17 @@ class EagerTemplate(Template):
raise
def __call__(self, *args, **kwargs):
+ # In both branches below, the template store is installed as default after
+ # the variable scope is opened in order to ensure that templates nested at
+ # the same level correctly uniquify lower variable scope names.
if self._variable_scope:
- if self._variables_created:
- # This is not the first visit to __call__, so variables have already
- # been created, and we want to reuse them.
- with variable_scope.variable_scope(self._variable_scope,
- reuse=variable_scope.AUTO_REUSE):
- with self._eager_variable_store.as_default():
- return self._call_func(args, kwargs, check_for_new_variables=True)
- else:
- # This is the first visit to __call__, but the scope has already been
- # created in the constructor. Set _variables_created after the inner
- # function is successfully called so that subsequent calls take the if
- # branch above.
- with variable_scope.variable_scope(self._variable_scope,
- reuse=variable_scope.AUTO_REUSE):
- with self._eager_variable_store.as_default():
- result = self._call_func(args, kwargs,
- check_for_new_variables=False)
- self._variables_created = True
- return result
+ with variable_scope.variable_scope(
+ self._variable_scope, reuse=variable_scope.AUTO_REUSE):
+ with self._template_store.as_default():
+ result = self._call_func(
+ args, kwargs, check_for_new_variables=self._variables_created)
+ self._variables_created = True
+ return result
else:
# The scope was not created at construction time, so create it here.
# Subsequent calls should reuse variables.
@@ -499,9 +539,11 @@ class EagerTemplate(Template):
self._unique_name, self._name,
custom_getter=self._custom_getter) as vs:
self._variable_scope = vs
- with self._eager_variable_store.as_default():
- result = self._call_func(args, kwargs,
- check_for_new_variables=False)
+ # Because the scope was not created at construction time, the template
+ # store's variable scope name is unset; set it here.
+ self._template_store.set_variable_scope_name(vs.name)
+ with self._template_store.as_default():
+ result = self._call_func(args, kwargs, check_for_new_variables=False)
self._variables_created = True
return result
@@ -532,24 +574,32 @@ class EagerTemplate(Template):
def variables(self):
"""Returns the list of variables created by the Template."""
# Currently there is no local variable in Eager mode.
- return self._eager_variable_store.variables()
+ if not self._variables_created:
+ return []
+ return self._template_store.variables()
@property
def trainable_variables(self):
"""Returns the list of trainable variables created by the Template."""
# Currently there is no local variable in Eager mode.
- return self._eager_variable_store.trainable_variables()
+ if not self._variables_created:
+ return []
+ return self._template_store.trainable_variables()
@property
def non_trainable_variables(self):
"""Returns the list of non-trainable variables created by the Template."""
# Currently there is no local variable in Eager mode.
- return self._eager_variable_store.non_trainable_variables()
+ if not self._variables_created:
+ return []
+ return self._template_store.non_trainable_variables()
@property
def global_variables(self):
"""Returns the list of global variables created by the Template."""
# Currently there is no local variable in Eager mode.
+ if not self._variables_created:
+ return []
return self.variables
@property
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index e46ff529de..411d45ca1c 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -1217,8 +1217,15 @@ class EagerVariableStore(object):
```
"""
- def __init__(self):
- self._store = _VariableStore()
+ def __init__(self, store=None):
+ if store is not None:
+ if not store._store_eager_variables: # pylint: disable=protected-access
+ raise ValueError("Cannot construct EagerVariableStore from a "
+ "VariableStore object that does not hold eager "
+ "variables.")
+ self._store = store
+ else:
+ self._store = _VariableStore()
self._store._store_eager_variables = True # pylint: disable=protected-access
def as_default(self):