diff options
author | A. Unique TensorFlower <nobody@tensorflow.org> | 2016-05-20 14:52:16 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-05-20 16:03:17 -0700 |
commit | cde02b84f55c0c3c09fb33a1ab5c58a476606d88 (patch) | |
tree | cc0196c668a1169af94736fd13ea8e46142a9539 /tensorflow/python/ops/template.py | |
parent | a824c3c87226345de45432c329c2b21e17854d0e (diff) |
Add create_scope_now option to make_template.
This is a non-breaking API change.
This gives the option of having a Template capture a scope at construction
time, rather than the time of first __call__. This can prove useful in a
situation where the Template is created but then passed into other code to be
called much later - it may be hard to be sure of the eventual scope name without
understanding all the lower level parts of the model, and potentially
configuration changes could lead to a different first __call__ location,
meaning a differently named variable, and potential problems saving and
loading.
By contrast, if create_scope_now_ is set to True, then wherever the template is
constructed defines the scope name. Lower level code which actually calls it is
free to switch around the order of __call__ without changing the variable name.
Change: 122874739
Diffstat (limited to 'tensorflow/python/ops/template.py')
-rw-r--r-- | tensorflow/python/ops/template.py | 72 |
1 files changed, 56 insertions, 16 deletions
diff --git a/tensorflow/python/ops/template.py b/tensorflow/python/ops/template.py index b2293ad591..c71f9d948a 100644 --- a/tensorflow/python/ops/template.py +++ b/tensorflow/python/ops/template.py @@ -29,7 +29,7 @@ from tensorflow.python.platform import tf_logging as logging __all__ = ["make_template"] -def make_template(name_, func_, **kwargs): +def make_template(name_, func_, create_scope_now_=False, **kwargs): """Given an arbitrary function, wrap it so that it does variable sharing. This wraps `func_` in a Template and partially evaluates it. Templates are @@ -98,29 +98,39 @@ def make_template(name_, func_, **kwargs): w2 = scale_by_y2(input2) ``` - Note: The full variable scope is captured at the time of the first call. + Depending on the value of `create_scope_now_`, the full variable scope may be + captured either at the time of first call or at the time of construction. If + this option is set to True, then all Tensors created by repeated calls to the + template will have an extra trailing _N+1 to their name, as the first time the + scope is entered in the Template constructor no Tensors are created. - Note: `name_` and `func_` have a following underscore to reduce the likelihood - of collisions with kwargs. + Note: `name_`, `func_` and `create_scope_now_` have a trailing underscore to + reduce the likelihood of collisions with kwargs. Args: name_: A name for the scope created by this template. If necessary, the name will be made unique by appending `_N` to the name. func_: The function to wrap. + create_scope_now_: Boolean controlling whether the scope should be created + when the template is constructed or when the template is called. Default + is False, meaning the scope is created when the template is called. **kwargs: Keyword arguments to apply to `func_`. Returns: - A function that will enter a `variable_scope` before calling `func_`. The - first time it is called, it will create a non-reusing scope so that the - variables will be unique. 
On each subsequent call, it will reuse those - variables. + A function to encapsulate a set of variables which should be created once + and reused. An enclosing scope will created, either where `make_template` + is called, or wherever the result is called, depending on the value of + `create_scope_now_`. Regardless of the value, the first time the template + is called it will enter the scope with no reuse, and call `func_` to create + variables, which are guaranteed to be unique. All subsequent calls will + re-enter the scope and reuse those variables. Raises: ValueError: if the name is None. """ if kwargs: func_ = functools.partial(func_, **kwargs) - return Template(name_, func_) + return Template(name_, func_, create_scope_now=create_scope_now_) def _skip_common_stack_elements(stacktrace, base_case): @@ -137,10 +147,13 @@ class Template(object): Templates are functions that create variables the first time they are called and reuse them thereafter. See `make_template` for full documentation. - Note: The full variable scope is captured at the time of the first call. + Note: By default, the full variable scope is captured at the time of first + call. If `create_scope_now_` is passed as True to the constructor, the full + scope will be captured there, but no variables will created until the first + call. """ - def __init__(self, name, func): + def __init__(self, name, func, create_scope_now=False): """Creates a template for the given function. Args: @@ -148,6 +161,15 @@ class Template(object): name will be made unique by appending `_N` to the it (see how `tf.variable_op_scope` treats the `default_name` for details). func: The function to apply each time. + create_scope_now: Whether to create the scope at Template construction + time, rather than first call. Defaults to false. 
Creating the scope at + construction time may be more convenient if the template is to passed + through much lower level code, and you want to be sure of the scope + name without knowing exactly where it will be first called. If set to + True, the scope will be created in the constructor, and all subsequent + times in __call__, leading to a trailing numeral being added to the + names of all created Tensors. If set to False, the scope will be created + at the first call location. Raises: ValueError: if the name is None. @@ -157,7 +179,14 @@ class Template(object): self._name = name if name is None: raise ValueError("name cannot be None.") - self._var_scope = None + if create_scope_now: + with variable_scope.variable_op_scope([], None, self._name) as vs: + self._var_scope = vs + else: + self._var_scope = None + # This variable keeps track of whether the template has been called yet, + # which is not the same as whether the scope has been created. + self._variables_created = False def _call_func(self, args, kwargs, check_for_new_variables): try: @@ -204,12 +233,23 @@ class Template(object): raise def __call__(self, *args, **kwargs): - # Capture the name of the variable_scope here because if we capture at - # construction, then name_scopes would have a '_N+1' suffix. if self._var_scope: - with variable_scope.variable_scope(self._var_scope, reuse=True): - return self._call_func(args, kwargs, check_for_new_variables=True) + if self._variables_created: + # This is not the first visit to __call__, so variables have already + # been created, and we want to reuse them. + with variable_scope.variable_scope(self._var_scope, reuse=True): + return self._call_func(args, kwargs, check_for_new_variables=True) + else: + # This is the first visit to __call__, but the scope has already been + # created in the constructor. Set _variables_created so that subsequent + # calls take the if branch above. 
+ self._variables_created = True + with variable_scope.variable_scope(self._var_scope): + return self._call_func(args, kwargs, check_for_new_variables=False) else: + # The scope was not created at construction time, so create it here. + # Subsequent calls should reuse variables. + self._variables_created = True with variable_scope.variable_op_scope([], None, self._name) as vs: self._var_scope = vs return self._call_func(args, kwargs, check_for_new_variables=False) |