aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/variable_scope.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/variable_scope.py')
-rw-r--r--tensorflow/python/ops/variable_scope.py292
1 files changed, 224 insertions, 68 deletions
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index 47414c28af..f862b62fad 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -1,4 +1,4 @@
- # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -44,9 +44,11 @@ from tensorflow.python.util import function_utils
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import tf_export
-__all__ = ["AUTO_REUSE", "VariableScope", "get_variable_scope",
- "get_variable", "get_local_variable", "variable_scope",
- "variable_op_scope", "no_regularizer"]
+__all__ = [
+ "AUTO_REUSE", "VariableScope", "get_variable_scope", "get_variable",
+ "get_local_variable", "variable_scope", "variable_op_scope",
+ "no_regularizer", "VariableSynchronization", "VariableAggregation"
+]
class _PartitionInfo(object):
@@ -188,6 +190,38 @@ class _ReuseMode(enum.Enum):
# REUSE_FALSE = 2
# REUSE_TRUE = 3
+
+@tf_export("VariableSynchronization")
+class VariableSynchronization(enum.Enum):
+ """Indicates when a distributed variable will be synced."""
+
+ # Indicates that the synchronization will be determined by the current
+ # `DistributionStrategy` (eg. With `MirroredStrategy` this would be
+ # `ON_WRITE`).
+ AUTO = 0
+
+ # Indicates that there will only be one copy of the variable, so there is no
+ # need to sync.
+ NONE = 1
+
+ # Indicates that the variable will be aggregated across devices
+ # every time it is updated.
+ ON_WRITE = 2
+
+ # Indicates that the variable will be aggregated across devices
+ # when it is read (eg. when checkpointing or when evaluating an op that uses
+ # the variable).
+ ON_READ = 3
+
+
+@tf_export("VariableAggregation")
+class VariableAggregation(enum.Enum):
+ """Indicates how a distributed variable will be aggregated."""
+ NONE = 0
+ SUM = 1
+ MEAN = 2
+
+
AUTO_REUSE = _ReuseMode.AUTO_REUSE
tf_export("AUTO_REUSE").export_constant(__name__, "AUTO_REUSE")
AUTO_REUSE.__doc__ = """
@@ -214,11 +248,23 @@ class _VariableStore(object):
self._partitioned_vars = {} # A dict of the stored PartitionedVariables.
self._store_eager_variables = False
- def get_variable(self, name, shape=None, dtype=dtypes.float32,
- initializer=None, regularizer=None, reuse=None,
- trainable=True, collections=None, caching_device=None,
- partitioner=None, validate_shape=True, use_resource=None,
- custom_getter=None, constraint=None):
+ def get_variable(self,
+ name,
+ shape=None,
+ dtype=dtypes.float32,
+ initializer=None,
+ regularizer=None,
+ reuse=None,
+ trainable=True,
+ collections=None,
+ caching_device=None,
+ partitioner=None,
+ validate_shape=True,
+ use_resource=None,
+ custom_getter=None,
+ constraint=None,
+ synchronization=VariableSynchronization.AUTO,
+ aggregation=VariableAggregation.NONE):
"""Gets an existing variable with these parameters or create a new one.
If a variable with the given name is already stored, we return the stored
@@ -291,6 +337,14 @@ class _VariableStore(object):
variable and return the Tensor for the projected value
(which must have the same shape). Constraints are not safe to
use when doing asynchronous distributed training.
+ synchronization: Indicates when a distributed a variable will be
+ aggregated. Accepted values are constants defined in the class
+ @{VariableSynchronization}. By default the synchronization is set to
+ `AUTO` and the current `DistributionStrategy` chooses
+ when to synchronize.
+ aggregation: Indicates how a distributed variable will be aggregated.
+ Accepted values are constants defined in the class
+ @{tf.VariableAggregation}.
Returns:
The created or existing `Variable` (or `PartitionedVariable`, if a
@@ -343,11 +397,22 @@ class _VariableStore(object):
# it to custom_getter.
# Note: the parameters of _true_getter, and their documentation, match
# *exactly* item-for-item with the docstring of this method.
- def _true_getter(name, shape=None, dtype=dtypes.float32, # pylint: disable=missing-docstring
- initializer=None, regularizer=None, reuse=None,
- trainable=True, collections=None, caching_device=None,
- partitioner=None, validate_shape=True, use_resource=None,
- constraint=None):
+ def _true_getter( # pylint: disable=missing-docstring
+ name,
+ shape=None,
+ dtype=dtypes.float32,
+ initializer=None,
+ regularizer=None,
+ reuse=None,
+ trainable=True,
+ collections=None,
+ caching_device=None,
+ partitioner=None,
+ validate_shape=True,
+ use_resource=None,
+ constraint=None,
+ synchronization=VariableSynchronization.AUTO,
+ aggregation=VariableAggregation.NONE):
is_scalar = (shape is not None
and isinstance(shape, collections_lib.Sequence)
and not shape)
@@ -397,11 +462,20 @@ class _VariableStore(object):
"name was already created with partitioning?" % name)
return self._get_single_variable(
- name=name, shape=shape, dtype=dtype,
- initializer=initializer, regularizer=regularizer, reuse=reuse,
- trainable=trainable, collections=collections,
- caching_device=caching_device, validate_shape=validate_shape,
- use_resource=use_resource, constraint=constraint)
+ name=name,
+ shape=shape,
+ dtype=dtype,
+ initializer=initializer,
+ regularizer=regularizer,
+ reuse=reuse,
+ trainable=trainable,
+ collections=collections,
+ caching_device=caching_device,
+ validate_shape=validate_shape,
+ use_resource=use_resource,
+ constraint=constraint,
+ synchronization=synchronization,
+ aggregation=aggregation)
if custom_getter is not None:
# Handle backwards compatibility with getter arguments that were added
@@ -420,6 +494,8 @@ class _VariableStore(object):
"partitioner": partitioner,
"validate_shape": validate_shape,
"use_resource": use_resource,
+ "synchronization": synchronization,
+ "aggregation": aggregation,
}
# `fn_args` can handle functions, `functools.partial`, `lambda`.
if "constraint" in function_utils.fn_args(custom_getter):
@@ -427,12 +503,21 @@ class _VariableStore(object):
return custom_getter(**custom_getter_kwargs)
else:
return _true_getter(
- name, shape=shape, dtype=dtype,
- initializer=initializer, regularizer=regularizer,
- reuse=reuse, trainable=trainable, collections=collections,
- caching_device=caching_device, partitioner=partitioner,
- validate_shape=validate_shape, use_resource=use_resource,
- constraint=constraint)
+ name,
+ shape=shape,
+ dtype=dtype,
+ initializer=initializer,
+ regularizer=regularizer,
+ reuse=reuse,
+ trainable=trainable,
+ collections=collections,
+ caching_device=caching_device,
+ partitioner=partitioner,
+ validate_shape=validate_shape,
+ use_resource=use_resource,
+ constraint=constraint,
+ synchronization=synchronization,
+ aggregation=aggregation)
def _get_partitioned_variable(
self, name, partitioner, shape=None, dtype=dtypes.float32,
@@ -693,7 +778,9 @@ class _VariableStore(object):
caching_device=None,
validate_shape=True,
use_resource=None,
- constraint=None):
+ constraint=None,
+ synchronization=VariableSynchronization.AUTO,
+ aggregation=VariableAggregation.NONE):
"""Get or create a single Variable (e.g. a shard or entire variable).
See the documentation of get_variable above (ignore partitioning components)
@@ -713,6 +800,8 @@ class _VariableStore(object):
validate_shape: see get_variable.
use_resource: see get_variable.
constraint: see get_variable.
+ synchronization: see get_variable.
+ aggregation: see get_variable.
Returns:
A Variable. See documentation of get_variable above.
@@ -793,7 +882,9 @@ class _VariableStore(object):
dtype=variable_dtype,
validate_shape=validate_shape,
constraint=constraint,
- use_resource=use_resource)
+ use_resource=use_resource,
+ synchronization=synchronization,
+ aggregation=aggregation)
if context.executing_eagerly() and self._store_eager_variables:
if collections:
ops.add_to_collections(collections, v)
@@ -1052,7 +1143,9 @@ class VariableScope(object):
validate_shape=True,
use_resource=None,
custom_getter=None,
- constraint=None):
+ constraint=None,
+ synchronization=VariableSynchronization.AUTO,
+ aggregation=VariableAggregation.NONE):
"""Gets an existing variable with this name or create a new one."""
if regularizer is None:
regularizer = self._regularizer
@@ -1090,12 +1183,22 @@ class VariableScope(object):
if dtype is None:
dtype = self._dtype
return var_store.get_variable(
- full_name, shape=shape, dtype=dtype, initializer=initializer,
- regularizer=regularizer, reuse=reuse, trainable=trainable,
- collections=collections, caching_device=caching_device,
- partitioner=partitioner, validate_shape=validate_shape,
- use_resource=use_resource, custom_getter=custom_getter,
- constraint=constraint)
+ full_name,
+ shape=shape,
+ dtype=dtype,
+ initializer=initializer,
+ regularizer=regularizer,
+ reuse=reuse,
+ trainable=trainable,
+ collections=collections,
+ caching_device=caching_device,
+ partitioner=partitioner,
+ validate_shape=validate_shape,
+ use_resource=use_resource,
+ custom_getter=custom_getter,
+ constraint=constraint,
+ synchronization=synchronization,
+ aggregation=aggregation)
def _get_partitioned_variable(self,
var_store,
@@ -1326,14 +1429,28 @@ def get_variable(name,
validate_shape=True,
use_resource=None,
custom_getter=None,
- constraint=None):
+ constraint=None,
+ synchronization=VariableSynchronization.AUTO,
+ aggregation=VariableAggregation.NONE):
return get_variable_scope().get_variable(
- _get_default_variable_store(), name, shape=shape, dtype=dtype,
- initializer=initializer, regularizer=regularizer, trainable=trainable,
- collections=collections, caching_device=caching_device,
- partitioner=partitioner, validate_shape=validate_shape,
- use_resource=use_resource, custom_getter=custom_getter,
- constraint=constraint)
+ _get_default_variable_store(),
+ name,
+ shape=shape,
+ dtype=dtype,
+ initializer=initializer,
+ regularizer=regularizer,
+ trainable=trainable,
+ collections=collections,
+ caching_device=caching_device,
+ partitioner=partitioner,
+ validate_shape=validate_shape,
+ use_resource=use_resource,
+ custom_getter=custom_getter,
+ constraint=constraint,
+ synchronization=synchronization,
+ aggregation=aggregation)
+
+
get_variable_or_local_docstring = (
"""%s
@@ -1430,29 +1547,44 @@ get_variable.__doc__ = get_variable_or_local_docstring % (
# The argument list for get_local_variable must match arguments to get_variable.
# So, if you are updating the arguments, also update arguments to get_variable.
@tf_export("get_local_variable")
-def get_local_variable(name,
- shape=None,
- dtype=None,
- initializer=None,
- regularizer=None,
- trainable=False, # pylint: disable=unused-argument
- collections=None,
- caching_device=None,
- partitioner=None,
- validate_shape=True,
- use_resource=None,
- custom_getter=None,
- constraint=None):
+def get_local_variable( # pylint: disable=missing-docstring
+ name,
+ shape=None,
+ dtype=None,
+ initializer=None,
+ regularizer=None,
+ trainable=False, # pylint: disable=unused-argument
+ collections=None,
+ caching_device=None,
+ partitioner=None,
+ validate_shape=True,
+ use_resource=None,
+ synchronization=VariableSynchronization.AUTO,
+ aggregation=VariableAggregation.NONE,
+ custom_getter=None,
+ constraint=None):
if collections:
collections += [ops.GraphKeys.LOCAL_VARIABLES]
else:
collections = [ops.GraphKeys.LOCAL_VARIABLES]
return get_variable(
- name, shape=shape, dtype=dtype, initializer=initializer,
- regularizer=regularizer, trainable=False, collections=collections,
- caching_device=caching_device, partitioner=partitioner,
- validate_shape=validate_shape, use_resource=use_resource,
- custom_getter=custom_getter, constraint=constraint)
+ name,
+ shape=shape,
+ dtype=dtype,
+ initializer=initializer,
+ regularizer=regularizer,
+ trainable=False,
+ collections=collections,
+ caching_device=caching_device,
+ partitioner=partitioner,
+ validate_shape=validate_shape,
+ use_resource=use_resource,
+ synchronization=synchronization,
+ aggregation=aggregation,
+ custom_getter=custom_getter,
+ constraint=constraint)
+
+
get_local_variable.__doc__ = get_variable_or_local_docstring % (
"Gets an existing *local* variable or creates a new one.",
"Behavior is the same as in `get_variable`, except that variables are\n"
@@ -2214,6 +2346,12 @@ def default_variable_creator(next_creator=None, **kwargs):
dtype = kwargs.get("dtype", None)
constraint = kwargs.get("constraint", None)
use_resource = kwargs.get("use_resource", None)
+
+ # Enforce `ON_READ` variables to be not trainable.
+ synchronization = kwargs.pop("synchronization", VariableSynchronization.AUTO)
+ if synchronization == VariableSynchronization.ON_READ:
+ trainable = False
+
if use_resource is None:
use_resource = get_variable_scope().use_resource
if use_resource or (use_resource is None and context.executing_eagerly()):
@@ -2248,18 +2386,28 @@ def variable(initial_value=None,
name=None,
dtype=None,
constraint=None,
- use_resource=None):
+ use_resource=None,
+ synchronization=VariableSynchronization.AUTO,
+ aggregation=VariableAggregation.NONE):
previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs)
for getter in ops.get_default_graph()._variable_creator_stack: # pylint: disable=protected-access
previous_getter = _make_getter(getter, previous_getter)
- return previous_getter(initial_value=initial_value,
- trainable=trainable,
- collections=collections,
- validate_shape=validate_shape,
- caching_device=caching_device,
- name=name, dtype=dtype,
- constraint=constraint,
- use_resource=use_resource)
+
+ # Reset `aggregation` that is explicitly set as `None` to the enum None value.
+ if aggregation is None:
+ aggregation = VariableAggregation.NONE
+ return previous_getter(
+ initial_value=initial_value,
+ trainable=trainable,
+ collections=collections,
+ validate_shape=validate_shape,
+ caching_device=caching_device,
+ name=name,
+ dtype=dtype,
+ constraint=constraint,
+ use_resource=use_resource,
+ synchronization=synchronization,
+ aggregation=aggregation)
@tf_contextlib.contextmanager
@@ -2311,6 +2459,14 @@ def variable_creator_scope(variable_creator):
constraint: A constraint function to be applied to the variable after
updates by some algorithms.
use_resource: if True, a ResourceVariable is always created.
+ synchronization: Indicates when a distributed a variable will be
+ aggregated. Accepted values are constants defined in the class
+ @{VariableSynchronization}. By default the synchronization is set to
+ `AUTO` and the current `DistributionStrategy` chooses
+ when to synchronize.
+ aggregation: Indicates how a distributed variable will be aggregated.
+ Accepted values are constants defined in the class
+ @{tf.VariableAggregation}.
This set may grow over time, so it's important the signature of creators is as
mentioned above.