diff options
Diffstat (limited to 'tensorflow/python/ops/variable_scope.py')
-rw-r--r-- | tensorflow/python/ops/variable_scope.py | 333 |
1 file changed, 333 insertions, 0 deletions
"""A class to store named variables and a scope operator to manage sharing."""

import contextlib

from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import types
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import logging


class _VariableStore(object):
  """Variable store that carries a number of named Variables.

  New variable names and new variables can be created; all stored
  variables are initialized with the initializer passed to __init__.

  Attributes:
    vars: a dictionary with string names (same as passed in GetVar) as keys
          and the corresponding TensorFlow Variables as values.
  """

  def __init__(self):
    """Create a variable store."""
    self._vars = {}  # A dictionary of the stored TensorFlow variables.

  def get_variable(self, name, shape=None, dtype=types.float32,
                   initializer=None, reuse=None, trainable=True,
                   collections=None):
    """Gets an existing variable with these parameters or create a new one.

    If a variable with the given name is already stored, we return the stored
    variable. Otherwise, we create a new one.

    Set `reuse` to `True` when you only want to reuse existing Variables.
    Set `reuse` to `False` when you only want to create new Variables.
    If `reuse` is `None` (the default), both new and existing variables are
    returned.

    If initializer is `None` (the default), the default initializer passed in
    the constructor is used. If that one is `None` too, we use a new
    `UniformUnitScalingInitializer`.

    Args:
      name: the name of the new or existing variable.
      shape: shape of the new or existing variable.
      dtype: type of the new or existing variable (defaults to `DT_FLOAT`).
      initializer: initializer for the variable.
      reuse: a Boolean or `None`. Controls reuse or creation of variables.
      trainable: If `True` also add the variable to the graph collection
        `GraphKeys.TRAINABLE_VARIABLES` (see variables.Variable).
      collections: List of graph collections keys to add the Variable to.
        Defaults to `[GraphKeys.VARIABLES]` (see variables.Variable).

    Returns:
      The created or existing variable.

    Raises:
      ValueError: when creating a new variable and shape is not declared,
        when reusing a variable and specifying a conflicting shape,
        or when violating reuse during variable creation.
    """
    # Only enforce the reuse contract when the caller stated an intent;
    # reuse=None means "either create or reuse, whichever applies".
    should_check = reuse is not None
    dtype = types.as_dtype(dtype)
    shape = tensor_shape.as_shape(shape)
    if name in self._vars:
      # Here we handle the case when returning an existing variable.
      if should_check and not reuse:
        raise ValueError("Over-sharing: Variable %s already exists, disallowed."
                         " Did you mean to set reuse=True in VarScope?" % name)
      found_var = self._vars[name]
      # An unknown requested shape/dtype is compatible with anything stored,
      # so callers may omit them when reusing.
      if not shape.is_compatible_with(found_var.get_shape()):
        raise ValueError("Trying to share variable %s, but specified shape %s"
                         " and found shape %s." % (name, str(shape),
                                                   str(found_var.get_shape())))
      if not dtype.is_compatible_with(found_var.dtype):
        dtype_str = dtype.name
        found_type_str = found_var.dtype.name
        raise ValueError("Trying to share variable %s, but specified dtype %s"
                         " and found dtype %s." % (name, str(dtype_str),
                                                   str(found_type_str)))
      return found_var

    # The code below handles only the case of creating a new variable.
    if should_check and reuse:
      raise ValueError("Under-sharing: Variable %s does not exist, disallowed."
                       " Did you mean to set reuse=None in VarScope?" % name)
    if not shape.is_fully_defined():
      raise ValueError("Shape of a new variable (%s) must be fully defined, "
                       "but instead was %s." % (name, shape))
    if initializer is None:
      initializer = init_ops.uniform_unit_scaling_initializer()
    # Group the initializer ops under a dedicated sub-scope for graph clarity.
    with ops.name_scope(name + "/Initializer/"):
      init_val = initializer(shape.as_list(), dtype=dtype)
    v = variables.Variable(init_val, name=name, trainable=trainable,
                           collections=collections)
    self._vars[name] = v
    # Consistency fix: use str(shape) like the other messages, not format().
    logging.info("Created variable %s with shape %s and init %s", v.name,
                 str(shape), str(initializer))
    return v


class _VariableScope(object):
  """Variable scope object to carry defaults to provide to get_variable.

  Many of the arguments we need for get_variable in a variable store are most
  easily handled with a context. This object is used for the defaults.

  Attributes:
    name: name of the current scope, used as prefix in get_variable.
    initializer: default initializer passed to get_variable.
    reuse: Boolean or None, setting the reuse in get_variable.
  """

  def __init__(self, reuse, name="", initializer=None):
    self._name = name
    self._initializer = initializer
    self._reuse = reuse

  @property
  def name(self):
    return self._name

  @property
  def reuse(self):
    return self._reuse

  @property
  def initializer(self):
    return self._initializer

  def reuse_variables(self):
    """Reuse variables in this scope."""
    self._reuse = True

  def set_initializer(self, initializer):
    """Set initializer for this scope."""
    self._initializer = initializer

  def get_variable(self, var_store, name, shape=None, dtype=types.float32,
                   initializer=None, trainable=True, collections=None):
    """Gets an existing variable with this name or create a new one."""
    if initializer is None and self._initializer:
      initializer = self._initializer
    full_name = self.name + "/" + name if self.name else name
    # Variable names only depend on variable_scope (full_name here),
    # not name_scope, so we reset it below for the time of variable creation.
    with ops.name_scope(None):
      return var_store.get_variable(full_name, shape, dtype, initializer,
                                    self.reuse, trainable, collections)


# Graph-collection keys under which the singleton store and the current
# scope are stashed; tuples are used so they cannot collide with user names.
_VARSTORE_KEY = ("__variable_store",)
_VARSCOPE_KEY = ("__varscope",)


def get_variable_scope():
  """Returns the current variable scope."""
  scope = ops.get_collection(_VARSCOPE_KEY)
  if scope:  # This collection has at most 1 element, the default scope at [0].
    return scope[0]
  scope = _VariableScope(False)
  ops.add_to_collection(_VARSCOPE_KEY, scope)
  return scope


def _get_default_variable_store():
  """Returns the singleton variable store, creating it on first use."""
  store = ops.get_collection(_VARSTORE_KEY)
  if store:
    return store[0]
  store = _VariableStore()
  ops.add_to_collection(_VARSTORE_KEY, store)
  return store


def get_variable(name, shape=None, dtype=types.float32, initializer=None,
                 trainable=True, collections=None):
  """Gets an existing variable with these parameters or create a new one.

  This function prefixes the name with the current variable scope
  and performs reuse checks. See the
  [Variable Scope How To](../../how_tos/variable_scope/index.md)
  for an extensive description of how reusing works. Here is a basic example:

  ```python
  with tf.variable_scope("foo"):
    v = tf.get_variable("v", [1])  # v.name == "foo/v:0"
    w = tf.get_variable("w", [1])  # w.name == "foo/w:0"
  with tf.variable_scope("foo", reuse=True):
    v1 = tf.get_variable("v")  # The same as v above.
  ```

  If initializer is `None` (the default), the default initializer passed in
  the constructor is used. If that one is `None` too, a
  `UniformUnitScalingInitializer` will be used.

  Args:
    name: the name of the new or existing variable.
    shape: shape of the new or existing variable.
    dtype: type of the new or existing variable (defaults to `DT_FLOAT`).
    initializer: initializer for the variable if one is created.
    trainable: If `True` also add the variable to the graph collection
      `GraphKeys.TRAINABLE_VARIABLES` (see variables.Variable).
    collections: List of graph collections keys to add the Variable to.
      Defaults to `[GraphKeys.VARIABLES]` (see variables.Variable).

  Returns:
    The created or existing variable.

  Raises:
    ValueError: when creating a new variable and shape is not declared,
      or when violating reuse during variable creation. Reuse is set inside
      `variable_scope`.
  """
  return get_variable_scope().get_variable(_get_default_variable_store(), name,
                                           shape, dtype, initializer,
                                           trainable, collections)


@contextlib.contextmanager
def variable_scope(name_or_scope, reuse=None, initializer=None):
  """Returns a context for variable scope.

  Variable scope allows to create new variables and to share already created
  ones while providing checks to not create or share by accident. For details,
  see the [Variable Scope How To](../../how_tos/variable_scope/index.md),
  here we present only a few basic examples.

  Simple example of how to create a new variable:

  ```python
  with tf.variable_scope("foo"):
    with tf.variable_scope("bar"):
      v = tf.get_variable("v", [1])
      assert v.name == "foo/bar/v:0"
  ```

  Basic example of sharing a variable:

  ```python
  with tf.variable_scope("foo"):
    v = tf.get_variable("v", [1])
  with tf.variable_scope("foo", reuse=True):
    v1 = tf.get_variable("v", [1])
  assert v1 == v
  ```

  Sharing a variable by capturing a scope and setting reuse:

  ```python
  with tf.variable_scope("foo") as scope:
    v = tf.get_variable("v", [1])
    scope.reuse_variables()
    v1 = tf.get_variable("v", [1])
  assert v1 == v
  ```

  To prevent accidental sharing of variables, we raise an exception when
  getting an existing variable in a non-reusing scope.

  ```python
  with tf.variable_scope("foo") as scope:
    v = tf.get_variable("v", [1])
    v1 = tf.get_variable("v", [1])
    #  Raises ValueError("... v already exists ...").
  ```

  Similarly, we raise an exception when trying to get a variable that
  does not exist in reuse mode.

  ```python
  with tf.variable_scope("foo", reuse=True):
    v = tf.get_variable("v", [1])
    #  Raises ValueError("... v does not exist ...").
  ```

  Note that the `reuse` flag is inherited: if we open a reusing scope,
  then all its sub-scopes become reusing as well.

  Args:
    name_or_scope: `string` or `VariableScope`: the scope to open.
    reuse: `True` or `None`; if `True`, we go into reuse mode for this scope as
      well as all sub-scopes; if `None`, we just inherit the parent scope reuse.
    initializer: default initializer for variables within this scope.

  Yields:
    A scope that can be to captured and reused.

  Raises:
    ValueError: when trying to reuse within a create scope, or create within
      a reuse scope, or if reuse is not `None` or `True`.
    TypeError: when the types of some arguments are not appropriate.
  """
  if not isinstance(name_or_scope, (_VariableScope, basestring)):
    raise TypeError("VariableScope: name_scope must be a string or "
                    "VariableScope.")
  if reuse not in [None, True]:
    raise ValueError("VariableScope reuse parameter must be True or None.")
  if not reuse and isinstance(name_or_scope, _VariableScope):
    logging.info("Passing VariableScope to a non-reusing scope, intended?")
  if reuse and isinstance(name_or_scope, basestring):
    logging.info("Re-using string-named scope, consider capturing as object.")
  get_variable_scope()  # Ensure that a default exists, then get a pointer.
  default_varscope = ops.get_collection(_VARSCOPE_KEY)
  try:
    old = default_varscope[0]
    reuse = reuse or old.reuse  # Re-using is inherited by sub-scopes.
    if isinstance(name_or_scope, _VariableScope):
      # Handler for the case when we jump to a shared scope.
      # In this case, we leave the current name_scope unchanged.
      # We create a new VariableScope (default_varscope[0]) that contains
      # a copy of the provided shared scope, possibly with changed reuse
      # and initializer, if the user requested this.
      default_varscope[0] = _VariableScope(reuse, name_or_scope.name,
                                           name_or_scope.initializer)
      if initializer:
        default_varscope[0].set_initializer(initializer)
      yield default_varscope[0]
    else:
      # Handler for the case when we just prolong current variable scope.
      # In this case we prolong the current name_scope and create a new
      # VariableScope with name extended by the provided one, and inherited
      # reuse and initializer (except if the user provided values to set).
      with ops.name_scope(name_or_scope):
        new_name = old.name + "/" + name_or_scope if old.name else name_or_scope
        default_varscope[0] = _VariableScope(reuse, name=new_name,
                                             initializer=old.initializer)
        if initializer:
          default_varscope[0].set_initializer(initializer)
        yield default_varscope[0]
  finally:
    # Always restore the previous scope, even if the body raised.
    default_varscope[0] = old