aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/variable_scope.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/variable_scope.py')
-rw-r--r--tensorflow/python/ops/variable_scope.py339
1 files changed, 240 insertions, 99 deletions
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index 47414c28af..aca44bcd44 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -1,4 +1,4 @@
- # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -44,9 +44,11 @@ from tensorflow.python.util import function_utils
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import tf_export
-__all__ = ["AUTO_REUSE", "VariableScope", "get_variable_scope",
- "get_variable", "get_local_variable", "variable_scope",
- "variable_op_scope", "no_regularizer"]
+__all__ = [
+ "AUTO_REUSE", "VariableScope", "get_variable_scope", "get_variable",
+ "get_local_variable", "variable_scope", "variable_op_scope",
+ "no_regularizer", "VariableSynchronization", "VariableAggregation"
+]
class _PartitionInfo(object):
@@ -188,6 +190,11 @@ class _ReuseMode(enum.Enum):
# REUSE_FALSE = 2
# REUSE_TRUE = 3
+
+# TODO(apassos) remove these forwarding symbols.
+VariableSynchronization = variables.VariableSynchronization # pylint: disable=invalid-name
+VariableAggregation = variables.VariableAggregation # pylint: disable=invalid-name
+
AUTO_REUSE = _ReuseMode.AUTO_REUSE
tf_export("AUTO_REUSE").export_constant(__name__, "AUTO_REUSE")
AUTO_REUSE.__doc__ = """
@@ -214,11 +221,23 @@ class _VariableStore(object):
self._partitioned_vars = {} # A dict of the stored PartitionedVariables.
self._store_eager_variables = False
- def get_variable(self, name, shape=None, dtype=dtypes.float32,
- initializer=None, regularizer=None, reuse=None,
- trainable=True, collections=None, caching_device=None,
- partitioner=None, validate_shape=True, use_resource=None,
- custom_getter=None, constraint=None):
+ def get_variable(self,
+ name,
+ shape=None,
+ dtype=dtypes.float32,
+ initializer=None,
+ regularizer=None,
+ reuse=None,
+ trainable=None,
+ collections=None,
+ caching_device=None,
+ partitioner=None,
+ validate_shape=True,
+ use_resource=None,
+ custom_getter=None,
+ constraint=None,
+ synchronization=VariableSynchronization.AUTO,
+ aggregation=VariableAggregation.NONE):
"""Gets an existing variable with these parameters or create a new one.
If a variable with the given name is already stored, we return the stored
@@ -254,6 +273,8 @@ class _VariableStore(object):
forced to be False.
trainable: If `True` also add the variable to the graph collection
`GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+ `trainable` defaults to `True` unless `synchronization` is
+ set to `ON_READ`.
collections: List of graph collections keys to add the `Variable` to.
Defaults to `[GraphKeys.GLOBAL_VARIABLES]` (see `tf.Variable`).
caching_device: Optional device string or function describing where the
@@ -291,6 +312,15 @@ class _VariableStore(object):
variable and return the Tensor for the projected value
(which must have the same shape). Constraints are not safe to
use when doing asynchronous distributed training.
+ synchronization: Indicates when a distributed a variable will be
+ aggregated. Accepted values are constants defined in the class
+ @{tf.VariableSynchronization}. By default the synchronization is set to
+ `AUTO` and the current `DistributionStrategy` chooses
+ when to synchronize. If `synchronization` is set to `ON_READ`,
+ `trainable` must not be set to `True`.
+ aggregation: Indicates how a distributed variable will be aggregated.
+ Accepted values are constants defined in the class
+ @{tf.VariableAggregation}.
Returns:
The created or existing `Variable` (or `PartitionedVariable`, if a
@@ -343,11 +373,22 @@ class _VariableStore(object):
# it to custom_getter.
# Note: the parameters of _true_getter, and their documentation, match
# *exactly* item-for-item with the docstring of this method.
- def _true_getter(name, shape=None, dtype=dtypes.float32, # pylint: disable=missing-docstring
- initializer=None, regularizer=None, reuse=None,
- trainable=True, collections=None, caching_device=None,
- partitioner=None, validate_shape=True, use_resource=None,
- constraint=None):
+ def _true_getter( # pylint: disable=missing-docstring
+ name,
+ shape=None,
+ dtype=dtypes.float32,
+ initializer=None,
+ regularizer=None,
+ reuse=None,
+ trainable=None,
+ collections=None,
+ caching_device=None,
+ partitioner=None,
+ validate_shape=True,
+ use_resource=None,
+ constraint=None,
+ synchronization=VariableSynchronization.AUTO,
+ aggregation=VariableAggregation.NONE):
is_scalar = (shape is not None
and isinstance(shape, collections_lib.Sequence)
and not shape)
@@ -397,11 +438,24 @@ class _VariableStore(object):
"name was already created with partitioning?" % name)
return self._get_single_variable(
- name=name, shape=shape, dtype=dtype,
- initializer=initializer, regularizer=regularizer, reuse=reuse,
- trainable=trainable, collections=collections,
- caching_device=caching_device, validate_shape=validate_shape,
- use_resource=use_resource, constraint=constraint)
+ name=name,
+ shape=shape,
+ dtype=dtype,
+ initializer=initializer,
+ regularizer=regularizer,
+ reuse=reuse,
+ trainable=trainable,
+ collections=collections,
+ caching_device=caching_device,
+ validate_shape=validate_shape,
+ use_resource=use_resource,
+ constraint=constraint,
+ synchronization=synchronization,
+ aggregation=aggregation)
+
+ # Set trainable value based on synchronization value.
+ trainable = _get_trainable_value(
+ synchronization=synchronization, trainable=trainable)
if custom_getter is not None:
# Handle backwards compatibility with getter arguments that were added
@@ -420,6 +474,8 @@ class _VariableStore(object):
"partitioner": partitioner,
"validate_shape": validate_shape,
"use_resource": use_resource,
+ "synchronization": synchronization,
+ "aggregation": aggregation,
}
# `fn_args` can handle functions, `functools.partial`, `lambda`.
if "constraint" in function_utils.fn_args(custom_getter):
@@ -427,18 +483,36 @@ class _VariableStore(object):
return custom_getter(**custom_getter_kwargs)
else:
return _true_getter(
- name, shape=shape, dtype=dtype,
- initializer=initializer, regularizer=regularizer,
- reuse=reuse, trainable=trainable, collections=collections,
- caching_device=caching_device, partitioner=partitioner,
- validate_shape=validate_shape, use_resource=use_resource,
- constraint=constraint)
-
- def _get_partitioned_variable(
- self, name, partitioner, shape=None, dtype=dtypes.float32,
- initializer=None, regularizer=None, reuse=None,
- trainable=True, collections=None, caching_device=None,
- validate_shape=True, use_resource=None, constraint=None):
+ name,
+ shape=shape,
+ dtype=dtype,
+ initializer=initializer,
+ regularizer=regularizer,
+ reuse=reuse,
+ trainable=trainable,
+ collections=collections,
+ caching_device=caching_device,
+ partitioner=partitioner,
+ validate_shape=validate_shape,
+ use_resource=use_resource,
+ constraint=constraint,
+ synchronization=synchronization,
+ aggregation=aggregation)
+
+ def _get_partitioned_variable(self,
+ name,
+ partitioner,
+ shape=None,
+ dtype=dtypes.float32,
+ initializer=None,
+ regularizer=None,
+ reuse=None,
+ trainable=None,
+ collections=None,
+ caching_device=None,
+ validate_shape=True,
+ use_resource=None,
+ constraint=None):
"""Gets or creates a sharded variable list with these parameters.
The `partitioner` must be a callable that accepts a fully defined
@@ -688,12 +762,14 @@ class _VariableStore(object):
regularizer=None,
partition_info=None,
reuse=None,
- trainable=True,
+ trainable=None,
collections=None,
caching_device=None,
validate_shape=True,
use_resource=None,
- constraint=None):
+ constraint=None,
+ synchronization=VariableSynchronization.AUTO,
+ aggregation=VariableAggregation.NONE):
"""Get or create a single Variable (e.g. a shard or entire variable).
See the documentation of get_variable above (ignore partitioning components)
@@ -713,6 +789,8 @@ class _VariableStore(object):
validate_shape: see get_variable.
use_resource: see get_variable.
constraint: see get_variable.
+ synchronization: see get_variable.
+ aggregation: see get_variable.
Returns:
A Variable. See documentation of get_variable above.
@@ -793,7 +871,9 @@ class _VariableStore(object):
dtype=variable_dtype,
validate_shape=validate_shape,
constraint=constraint,
- use_resource=use_resource)
+ use_resource=use_resource,
+ synchronization=synchronization,
+ aggregation=aggregation)
if context.executing_eagerly() and self._store_eager_variables:
if collections:
ops.add_to_collections(collections, v)
@@ -1045,14 +1125,16 @@ class VariableScope(object):
initializer=None,
regularizer=None,
reuse=None,
- trainable=True,
+ trainable=None,
collections=None,
caching_device=None,
partitioner=None,
validate_shape=True,
use_resource=None,
custom_getter=None,
- constraint=None):
+ constraint=None,
+ synchronization=VariableSynchronization.AUTO,
+ aggregation=VariableAggregation.NONE):
"""Gets an existing variable with this name or create a new one."""
if regularizer is None:
regularizer = self._regularizer
@@ -1090,12 +1172,22 @@ class VariableScope(object):
if dtype is None:
dtype = self._dtype
return var_store.get_variable(
- full_name, shape=shape, dtype=dtype, initializer=initializer,
- regularizer=regularizer, reuse=reuse, trainable=trainable,
- collections=collections, caching_device=caching_device,
- partitioner=partitioner, validate_shape=validate_shape,
- use_resource=use_resource, custom_getter=custom_getter,
- constraint=constraint)
+ full_name,
+ shape=shape,
+ dtype=dtype,
+ initializer=initializer,
+ regularizer=regularizer,
+ reuse=reuse,
+ trainable=trainable,
+ collections=collections,
+ caching_device=caching_device,
+ partitioner=partitioner,
+ validate_shape=validate_shape,
+ use_resource=use_resource,
+ custom_getter=custom_getter,
+ constraint=constraint,
+ synchronization=synchronization,
+ aggregation=aggregation)
def _get_partitioned_variable(self,
var_store,
@@ -1104,7 +1196,7 @@ class VariableScope(object):
dtype=None,
initializer=None,
regularizer=None,
- trainable=True,
+ trainable=None,
collections=None,
caching_device=None,
partitioner=None,
@@ -1319,21 +1411,35 @@ def get_variable(name,
dtype=None,
initializer=None,
regularizer=None,
- trainable=True,
+ trainable=None,
collections=None,
caching_device=None,
partitioner=None,
validate_shape=True,
use_resource=None,
custom_getter=None,
- constraint=None):
+ constraint=None,
+ synchronization=VariableSynchronization.AUTO,
+ aggregation=VariableAggregation.NONE):
return get_variable_scope().get_variable(
- _get_default_variable_store(), name, shape=shape, dtype=dtype,
- initializer=initializer, regularizer=regularizer, trainable=trainable,
- collections=collections, caching_device=caching_device,
- partitioner=partitioner, validate_shape=validate_shape,
- use_resource=use_resource, custom_getter=custom_getter,
- constraint=constraint)
+ _get_default_variable_store(),
+ name,
+ shape=shape,
+ dtype=dtype,
+ initializer=initializer,
+ regularizer=regularizer,
+ trainable=trainable,
+ collections=collections,
+ caching_device=caching_device,
+ partitioner=partitioner,
+ validate_shape=validate_shape,
+ use_resource=use_resource,
+ custom_getter=custom_getter,
+ constraint=constraint,
+ synchronization=synchronization,
+ aggregation=aggregation)
+
+
get_variable_or_local_docstring = (
"""%s
@@ -1430,29 +1536,44 @@ get_variable.__doc__ = get_variable_or_local_docstring % (
# The argument list for get_local_variable must match arguments to get_variable.
# So, if you are updating the arguments, also update arguments to get_variable.
@tf_export("get_local_variable")
-def get_local_variable(name,
- shape=None,
- dtype=None,
- initializer=None,
- regularizer=None,
- trainable=False, # pylint: disable=unused-argument
- collections=None,
- caching_device=None,
- partitioner=None,
- validate_shape=True,
- use_resource=None,
- custom_getter=None,
- constraint=None):
+def get_local_variable( # pylint: disable=missing-docstring
+ name,
+ shape=None,
+ dtype=None,
+ initializer=None,
+ regularizer=None,
+ trainable=False, # pylint: disable=unused-argument
+ collections=None,
+ caching_device=None,
+ partitioner=None,
+ validate_shape=True,
+ use_resource=None,
+ synchronization=VariableSynchronization.AUTO,
+ aggregation=VariableAggregation.NONE,
+ custom_getter=None,
+ constraint=None):
if collections:
collections += [ops.GraphKeys.LOCAL_VARIABLES]
else:
collections = [ops.GraphKeys.LOCAL_VARIABLES]
return get_variable(
- name, shape=shape, dtype=dtype, initializer=initializer,
- regularizer=regularizer, trainable=False, collections=collections,
- caching_device=caching_device, partitioner=partitioner,
- validate_shape=validate_shape, use_resource=use_resource,
- custom_getter=custom_getter, constraint=constraint)
+ name,
+ shape=shape,
+ dtype=dtype,
+ initializer=initializer,
+ regularizer=regularizer,
+ trainable=False,
+ collections=collections,
+ caching_device=caching_device,
+ partitioner=partitioner,
+ validate_shape=validate_shape,
+ use_resource=use_resource,
+ synchronization=synchronization,
+ aggregation=aggregation,
+ custom_getter=custom_getter,
+ constraint=constraint)
+
+
get_local_variable.__doc__ = get_variable_or_local_docstring % (
"Gets an existing *local* variable or creates a new one.",
"Behavior is the same as in `get_variable`, except that variables are\n"
@@ -2202,37 +2323,64 @@ def _compute_slice_dim_and_shape(full_shape, slicing):
return slice_dim, slice_shape
+def _get_trainable_value(synchronization, trainable):
+ """Computes the trainable value based on the given arguments."""
+ if synchronization == VariableSynchronization.ON_READ:
+ if trainable:
+ raise ValueError(
+ "Synchronization value can be set to "
+ "VariableSynchronization.ON_READ only for non-trainable variables. "
+ "You have specified trainable=True and "
+ "synchronization=VariableSynchronization.ON_READ.")
+ else:
+ # Set trainable to be false when variable is to be synced on read.
+ trainable = False
+ elif trainable is None:
+ trainable = True
+ return trainable
+
+
def default_variable_creator(next_creator=None, **kwargs):
"""Default variable creator."""
assert next_creator is None
initial_value = kwargs.get("initial_value", None)
- trainable = kwargs.get("trainable", True)
+ trainable = kwargs.get("trainable", None)
collections = kwargs.get("collections", None)
validate_shape = kwargs.get("validate_shape", True)
caching_device = kwargs.get("caching_device", None)
name = kwargs.get("name", None)
+ variable_def = kwargs.get("variable_def", None)
dtype = kwargs.get("dtype", None)
+ expected_shape = kwargs.get("expected_shape", None)
+ import_scope = kwargs.get("import_scope", None)
constraint = kwargs.get("constraint", None)
use_resource = kwargs.get("use_resource", None)
+
+ # Set trainable value based on synchronization value.
+ synchronization = kwargs.get("synchronization", VariableSynchronization.AUTO)
+ trainable = _get_trainable_value(
+ synchronization=synchronization, trainable=trainable)
+
if use_resource is None:
use_resource = get_variable_scope().use_resource
- if use_resource or (use_resource is None and context.executing_eagerly()):
+ use_resource = use_resource or context.executing_eagerly()
+ if use_resource:
return resource_variable_ops.ResourceVariable(
initial_value=initial_value, trainable=trainable,
collections=collections, validate_shape=validate_shape,
caching_device=caching_device, name=name, dtype=dtype,
- constraint=constraint)
- elif not use_resource and context.executing_eagerly():
- raise RuntimeError(
- "VariableScope should use resource variable when eager execution is"
- " enabled, but use_resource is False."
- )
+ constraint=constraint, variable_def=variable_def,
+ import_scope=import_scope)
else:
- return variables.Variable(
+ return variables.RefVariable(
initial_value=initial_value, trainable=trainable,
collections=collections, validate_shape=validate_shape,
caching_device=caching_device, name=name, dtype=dtype,
- constraint=constraint)
+ constraint=constraint, variable_def=variable_def,
+ expected_shape=expected_shape, import_scope=import_scope)
+
+
+variables.default_variable_creator = default_variable_creator
def _make_getter(captured_getter, captured_previous):
@@ -2240,26 +2388,8 @@ def _make_getter(captured_getter, captured_previous):
return lambda **kwargs: captured_getter(captured_previous, **kwargs)
-def variable(initial_value=None,
- trainable=True,
- collections=None,
- validate_shape=True,
- caching_device=None,
- name=None,
- dtype=None,
- constraint=None,
- use_resource=None):
- previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs)
- for getter in ops.get_default_graph()._variable_creator_stack: # pylint: disable=protected-access
- previous_getter = _make_getter(getter, previous_getter)
- return previous_getter(initial_value=initial_value,
- trainable=trainable,
- collections=collections,
- validate_shape=validate_shape,
- caching_device=caching_device,
- name=name, dtype=dtype,
- constraint=constraint,
- use_resource=use_resource)
+# TODO(apassos) remove forwarding symbol
+variable = variables.Variable
@tf_contextlib.contextmanager
@@ -2293,6 +2423,8 @@ def variable_creator_scope(variable_creator):
trainable: If `True`, the default, also adds the variable to the graph
collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
the default list of variables to use by the `Optimizer` classes.
+ `trainable` defaults to `True` unless `synchronization` is
+ set to `ON_READ`.
collections: List of graph collections keys. The new variable is added to
these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
validate_shape: If `False`, allows the variable to be initialized with a
@@ -2311,6 +2443,15 @@ def variable_creator_scope(variable_creator):
constraint: A constraint function to be applied to the variable after
updates by some algorithms.
use_resource: if True, a ResourceVariable is always created.
+ synchronization: Indicates when a distributed a variable will be
+ aggregated. Accepted values are constants defined in the class
+ @{tf.VariableSynchronization}. By default the synchronization is set to
+ `AUTO` and the current `DistributionStrategy` chooses
+ when to synchronize. If `synchronization` is set to `ON_READ`,
+ `trainable` must not be set to `True`.
+ aggregation: Indicates how a distributed variable will be aggregated.
+ Accepted values are constants defined in the class
+ @{tf.VariableAggregation}.
This set may grow over time, so it's important the signature of creators is as
mentioned above.