diff options
author | 2018-01-10 13:08:52 -0800 | |
---|---|---|
committer | 2018-01-10 13:14:34 -0800 | |
commit | 47e06b7e53963cc45ee236cf86cc089185f47d32 (patch) | |
tree | 46c29753e626ac4b5fb19e2193027eb80f7a9020 /tensorflow/python/ops | |
parent | 3dc58f798660b98a0198b33edefb6f9f2aa7d827 (diff) |
Support nesting EagerTemplate objects.
* Nesting is implemented by sharing a single EagerVariableStore among a top-level EagerTemplate and all child EagerTemplate objects that are nested underneath it. Variables added to an EagerTemplate object are also added to all EagerTemplate objects under which it is nested.
* This change also simplifies the implementation of __call__ for both Template and EagerTemplate.
PiperOrigin-RevId: 181506600
Diffstat (limited to 'tensorflow/python/ops')
-rw-r--r-- | tensorflow/python/ops/template.py | 150 | ||||
-rw-r--r-- | tensorflow/python/ops/variable_scope.py | 11 |
2 files changed, 109 insertions, 52 deletions
diff --git a/tensorflow/python/ops/template.py b/tensorflow/python/ops/template.py index e0cf1bff4f..169c4b5194 100644 --- a/tensorflow/python/ops/template.py +++ b/tensorflow/python/ops/template.py @@ -25,6 +25,7 @@ from tensorflow.python.eager import context from tensorflow.python.framework import ops from tensorflow.python.ops import variable_scope from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.util import tf_contextlib from tensorflow.python.util.deprecation import deprecated @@ -259,20 +260,13 @@ class Template(object): def __call__(self, *args, **kwargs): if self._variable_scope: - if self._variables_created: - # This is not the first visit to __call__, so variables have already - # been created, and we want to reuse them. - with variable_scope.variable_scope(self._variable_scope, reuse=True): - return self._call_func(args, kwargs, check_for_new_variables=True) - else: - # This is the first visit to __call__, but the scope has already been - # created in the constructor. Set _variables_created after the inner - # function is successfully called so that subsequent calls take the if - # branch above. - with variable_scope.variable_scope(self._variable_scope): - result = self._call_func(args, kwargs, check_for_new_variables=False) - self._variables_created = True - return result + # Only reuse variables if they were already created. + with variable_scope.variable_scope( + self._variable_scope, reuse=self._variables_created): + result = self._call_func( + args, kwargs, check_for_new_variables=self._variables_created) + self._variables_created = True + return result else: # The scope was not created at construction time, so create it here. # Subsequent calls should reuse variables. @@ -372,6 +366,61 @@ class Template(object): return self._variable_scope +class _EagerTemplateVariableStore(object): + """Wrapper around EagerVariableStore to support nesting EagerTemplates. 
+ """ + + def __init__(self, variable_scope_name): + self._variable_scope_name = variable_scope_name + default = variable_scope._get_default_variable_store() # pylint: disable=protected-access + if default._store_eager_variables: # pylint: disable=protected-access + self._eager_variable_store = variable_scope.EagerVariableStore(default) + else: + self._eager_variable_store = variable_scope.EagerVariableStore() + + def set_variable_scope_name(self, variable_scope_name): + self._variable_scope_name = variable_scope_name + + @tf_contextlib.contextmanager + def as_default(self): + try: + with self._eager_variable_store.as_default(): + yield + finally: + # Each _EagerTemplateVariableStore object lives underneath a variable + # scope (see EagerTemplate.__call__). This variable scope's subscopes are + # closed when the EagerTemplate object returns from __call__. For + # top-level _EagerTemplateVariableStore objects, the variable store to + # which the variable scope is attached is different from the + # EagerVariableStore; as such it is necessary to close its subscopes + # here as well. 
+ if self._variable_scope_name is None: + raise RuntimeError("A variable scope must be set before an " + "_EagerTemplateVariableStore object exits.") + self._eager_variable_store._store.close_variable_subscopes( # pylint: disable=protected-access + self._variable_scope_name) + + def _variables_in_scope(self, variable_list): + if self._variable_scope_name is None: + raise RuntimeError( + "A variable scope must be set before variables can be accessed.") + return [ + v for v in variable_list + if v.name.startswith(self._variable_scope_name + "/") + ] + + def variables(self): + return self._variables_in_scope(self._eager_variable_store.variables()) + + def trainable_variables(self): + return self._variables_in_scope( + self._eager_variable_store.trainable_variables()) + + def non_trainable_variables(self): + return self._variables_in_scope( + self._eager_variable_store.non_trainable_variables()) + + class EagerTemplate(Template): """Wrap a function to aid in variable sharing in Eager mode. @@ -416,26 +465,26 @@ class EagerTemplate(Template): "{} objects can only be used when eager execution is enabled, use " "tf.Template for graph construction". format(type(self))) - if unique_name: + if unique_name is not None: raise ValueError("unique_name cannot be used in eager mode.") super(EagerTemplate, self).__init__(name, func, create_scope_now, unique_name, custom_getter) - # Create an eager variable store only if the current variable store cannot - # store eager variables. This should allow for correct nesting. 
- default_vstore = variable_scope._get_default_variable_store() # pylint: disable=protected-access - if default_vstore._store_eager_variables: # pylint: disable=protected-access - raise ValueError("Nested EagerTemaplates are not currently supported.") + if self._variable_scope is not None: + variable_scope_name = self._variable_scope.name else: - self._eager_variable_store = variable_scope.EagerVariableStore() + # Defer setting the variable scope name until the variable scope + # is created in __call__. + variable_scope_name = None + self._template_store = _EagerTemplateVariableStore(variable_scope_name) def _call_func(self, args, kwargs, check_for_new_variables): try: - vars_at_start = self._eager_variable_store.variables() - trainable_at_start = self._eager_variable_store.trainable_variables() + vars_at_start = self._template_store.variables() + trainable_at_start = self._template_store.trainable_variables() result = self._func(*args, **kwargs) if check_for_new_variables: - trainable_variables = self._eager_variable_store.trainable_variables() + trainable_variables = self._template_store.trainable_variables() # If a variable that we intend to train is created as a side effect # of creating a template, then that is almost certainly an error. if len(trainable_at_start) != len(trainable_variables): @@ -448,7 +497,7 @@ class EagerTemplate(Template): # Non-trainable tracking variables are a legitimate reason why a new # variable would be created, but it is a relatively advanced use-case, # so log it. 
- variables = self._eager_variable_store.variables() + variables = self._template_store.variables() if len(vars_at_start) != len(variables): logging.info("New variables created when calling a template after " "the first time, perhaps you used tf.Variable when you " @@ -472,26 +521,17 @@ class EagerTemplate(Template): raise def __call__(self, *args, **kwargs): + # In both branches below, the template store is installed as default after + # the variable scope is opened in order to ensure that templates nested at + # the same level correctly uniquify lower variable scope names. if self._variable_scope: - if self._variables_created: - # This is not the first visit to __call__, so variables have already - # been created, and we want to reuse them. - with variable_scope.variable_scope(self._variable_scope, - reuse=variable_scope.AUTO_REUSE): - with self._eager_variable_store.as_default(): - return self._call_func(args, kwargs, check_for_new_variables=True) - else: - # This is the first visit to __call__, but the scope has already been - # created in the constructor. Set _variables_created after the inner - # function is successfully called so that subsequent calls take the if - # branch above. - with variable_scope.variable_scope(self._variable_scope, - reuse=variable_scope.AUTO_REUSE): - with self._eager_variable_store.as_default(): - result = self._call_func(args, kwargs, - check_for_new_variables=False) - self._variables_created = True - return result + with variable_scope.variable_scope( + self._variable_scope, reuse=variable_scope.AUTO_REUSE): + with self._template_store.as_default(): + result = self._call_func( + args, kwargs, check_for_new_variables=self._variables_created) + self._variables_created = True + return result else: # The scope was not created at construction time, so create it here. # Subsequent calls should reuse variables. 
@@ -499,9 +539,11 @@ class EagerTemplate(Template): self._unique_name, self._name, custom_getter=self._custom_getter) as vs: self._variable_scope = vs - with self._eager_variable_store.as_default(): - result = self._call_func(args, kwargs, - check_for_new_variables=False) + # Because the scope was not created at construction time, the template + # store's variable scope name is unset; set it here. + self._template_store.set_variable_scope_name(vs.name) + with self._template_store.as_default(): + result = self._call_func(args, kwargs, check_for_new_variables=False) self._variables_created = True return result @@ -532,24 +574,32 @@ class EagerTemplate(Template): def variables(self): """Returns the list of variables created by the Template.""" # Currently there is no local variable in Eager mode. - return self._eager_variable_store.variables() + if not self._variables_created: + return [] + return self._template_store.variables() @property def trainable_variables(self): """Returns the list of trainable variables created by the Template.""" # Currently there is no local variable in Eager mode. - return self._eager_variable_store.trainable_variables() + if not self._variables_created: + return [] + return self._template_store.trainable_variables() @property def non_trainable_variables(self): """Returns the list of non-trainable variables created by the Template.""" # Currently there is no local variable in Eager mode. - return self._eager_variable_store.non_trainable_variables() + if not self._variables_created: + return [] + return self._template_store.non_trainable_variables() @property def global_variables(self): """Returns the list of global variables created by the Template.""" # Currently there is no local variable in Eager mode. 
+ if not self._variables_created: + return [] return self.variables @property diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py index e46ff529de..411d45ca1c 100644 --- a/tensorflow/python/ops/variable_scope.py +++ b/tensorflow/python/ops/variable_scope.py @@ -1217,8 +1217,15 @@ class EagerVariableStore(object): ``` """ - def __init__(self): - self._store = _VariableStore() + def __init__(self, store=None): + if store is not None: + if not store._store_eager_variables: # pylint: disable=protected-access + raise ValueError("Cannot construct EagerVariableStore from a " + "VariableStore object that does not hold eager " + "variables.") + self._store = store + else: + self._store = _VariableStore() self._store._store_eager_variables = True # pylint: disable=protected-access def as_default(self): |