aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/template.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <nobody@tensorflow.org>2016-05-20 14:52:16 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-05-20 16:03:17 -0700
commitcde02b84f55c0c3c09fb33a1ab5c58a476606d88 (patch)
treecc0196c668a1169af94736fd13ea8e46142a9539 /tensorflow/python/ops/template.py
parenta824c3c87226345de45432c329c2b21e17854d0e (diff)
Add create_scope_now option to make_template.
This is a non-breaking API change. This gives the option of having a Template capture a scope at construction time, rather than the time of first __call__. This can prove useful in a situation where the Template is created but than passed into other code to be called much later - it may be hard to be sure of the eventual scope name without understanding all the lower level parts of the model, and potentially configuration changes could lead to a different first __call__ location, meaning a differently named variable, and potential problems saving and loading. By contrast, if create_scope_now_ is set to True, then wherever the template is constructed defines the scope name. Lower level code which actually calls it is free to switch around the order of __call__ without changing the variable name. Change: 122874739
Diffstat (limited to 'tensorflow/python/ops/template.py')
-rw-r--r--tensorflow/python/ops/template.py72
1 files changed, 56 insertions, 16 deletions
diff --git a/tensorflow/python/ops/template.py b/tensorflow/python/ops/template.py
index b2293ad591..c71f9d948a 100644
--- a/tensorflow/python/ops/template.py
+++ b/tensorflow/python/ops/template.py
@@ -29,7 +29,7 @@ from tensorflow.python.platform import tf_logging as logging
__all__ = ["make_template"]
-def make_template(name_, func_, **kwargs):
+def make_template(name_, func_, create_scope_now_=False, **kwargs):
"""Given an arbitrary function, wrap it so that it does variable sharing.
This wraps `func_` in a Template and partially evaluates it. Templates are
@@ -98,29 +98,39 @@ def make_template(name_, func_, **kwargs):
w2 = scale_by_y2(input2)
```
- Note: The full variable scope is captured at the time of the first call.
+ Depending on the value of `create_scope_now_`, the full variable scope may be
+ captured either at the time of first call or at the time of construction. If
+ this option is set to True, then all Tensors created by repeated calls to the
+ template will have an extra trailing _N+1 to their name, as the first time the
+ scope is entered in the Template constructor no Tensors are created.
- Note: `name_` and `func_` have a following underscore to reduce the likelihood
- of collisions with kwargs.
+ Note: `name_`, `func_` and `create_scope_now_` have a trailing underscore to
+ reduce the likelihood of collisions with kwargs.
Args:
name_: A name for the scope created by this template. If necessary, the name
will be made unique by appending `_N` to the name.
func_: The function to wrap.
+ create_scope_now_: Boolean controlling whether the scope should be created
+ when the template is constructed or when the template is called. Default
+ is False, meaning the scope is created when the template is called.
**kwargs: Keyword arguments to apply to `func_`.
Returns:
- A function that will enter a `variable_scope` before calling `func_`. The
- first time it is called, it will create a non-reusing scope so that the
- variables will be unique. On each subsequent call, it will reuse those
- variables.
+ A function to encapsulate a set of variables which should be created once
+ and reused. An enclosing scope will created, either where `make_template`
+ is called, or wherever the result is called, depending on the value of
+ `create_scope_now_`. Regardless of the value, the first time the template
+ is called it will enter the scope with no reuse, and call `func_` to create
+ variables, which are guaranteed to be unique. All subsequent calls will
+ re-enter the scope and reuse those variables.
Raises:
ValueError: if the name is None.
"""
if kwargs:
func_ = functools.partial(func_, **kwargs)
- return Template(name_, func_)
+ return Template(name_, func_, create_scope_now=create_scope_now_)
def _skip_common_stack_elements(stacktrace, base_case):
@@ -137,10 +147,13 @@ class Template(object):
Templates are functions that create variables the first time they are called
and reuse them thereafter. See `make_template` for full documentation.
- Note: The full variable scope is captured at the time of the first call.
+ Note: By default, the full variable scope is captured at the time of first
+ call. If `create_scope_now_` is passed as True to the constructor, the full
+ scope will be captured there, but no variables will created until the first
+ call.
"""
- def __init__(self, name, func):
+ def __init__(self, name, func, create_scope_now=False):
"""Creates a template for the given function.
Args:
@@ -148,6 +161,15 @@ class Template(object):
name will be made unique by appending `_N` to the it (see how
`tf.variable_op_scope` treats the `default_name` for details).
func: The function to apply each time.
+ create_scope_now: Whether to create the scope at Template construction
+ time, rather than first call. Defaults to false. Creating the scope at
+ construction time may be more convenient if the template is to passed
+ through much lower level code, and you want to be sure of the scope
+ name without knowing exactly where it will be first called. If set to
+ True, the scope will be created in the constructor, and all subsequent
+ times in __call__, leading to a trailing numeral being added to the
+ names of all created Tensors. If set to False, the scope will be created
+ at the first call location.
Raises:
ValueError: if the name is None.
@@ -157,7 +179,14 @@ class Template(object):
self._name = name
if name is None:
raise ValueError("name cannot be None.")
- self._var_scope = None
+ if create_scope_now:
+ with variable_scope.variable_op_scope([], None, self._name) as vs:
+ self._var_scope = vs
+ else:
+ self._var_scope = None
+ # This variable keeps track of whether the template has been called yet,
+ # which is not the same as whether the scope has been created.
+ self._variables_created = False
def _call_func(self, args, kwargs, check_for_new_variables):
try:
@@ -204,12 +233,23 @@ class Template(object):
raise
def __call__(self, *args, **kwargs):
- # Capture the name of the variable_scope here because if we capture at
- # construction, then name_scopes would have a '_N+1' suffix.
if self._var_scope:
- with variable_scope.variable_scope(self._var_scope, reuse=True):
- return self._call_func(args, kwargs, check_for_new_variables=True)
+ if self._variables_created:
+ # This is not the first visit to __call__, so variables have already
+ # been created, and we want to reuse them.
+ with variable_scope.variable_scope(self._var_scope, reuse=True):
+ return self._call_func(args, kwargs, check_for_new_variables=True)
+ else:
+ # This is the first visit to __call__, but the scope has already been
+ # created in the constructor. Set _variables_created so that subsequent
+ # calls take the if branch above.
+ self._variables_created = True
+ with variable_scope.variable_scope(self._var_scope):
+ return self._call_func(args, kwargs, check_for_new_variables=False)
else:
+ # The scope was not created at construction time, so create it here.
+ # Subsequent calls should reuse variables.
+ self._variables_created = True
with variable_scope.variable_op_scope([], None, self._name) as vs:
self._var_scope = vs
return self._call_func(args, kwargs, check_for_new_variables=False)