author     Pavithra Vijay <psv@google.com>  2018-07-09 16:58:04 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-07-09 17:04:47 -0700
commit     855d2a723c24ce69ee993144c16506dbef12ed69 (patch)
tree       8a0b8a9ce633bbd59eeef53ff5973cb2e49fb412 /tensorflow/python
parent     e0541b714d0df485a79ece616575ebb2a71b818c (diff)
Add `synchronization` and `aggregation` args to the layer `add_weight()` API. These args will be used for distributed variables.
Migrate all usages of `tower_local_var_scope` to use the new args.
PiperOrigin-RevId: 203855963
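For orientation, a minimal sketch (not part of the patch; the layer and variable names are made up, and it assumes the public `tf.VariableSynchronization` / `tf.VariableAggregation` aliases referenced in the docstrings below) of how a layer can now request a tower-local variable directly through `add_weight()`:

    import tensorflow as tf


    class RunningTotal(tf.keras.layers.Layer):
      """Illustrative layer with a non-trainable, read-aggregated accumulator."""

      def build(self, input_shape):
        # `ON_READ` makes the variable non-trainable by default; each tower keeps
        # its own copy, and the copies are summed only when the value is read.
        self.total = self.add_weight(
            name='total',
            shape=(),
            initializer=tf.zeros_initializer(),
            synchronization=tf.VariableSynchronization.ON_READ,
            aggregation=tf.VariableAggregation.SUM)
        super(RunningTotal, self).build(input_shape)

      def call(self, inputs):
        self.add_update(tf.assign_add(self.total, tf.reduce_sum(inputs)))
        return inputs

Passing `trainable=True` together with `ON_READ` raises a `ValueError`, as exercised by the tests in this change.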
Diffstat (limited to 'tensorflow/python')
-rw-r--r--  tensorflow/python/eager/graph_callable.py             |  4
-rw-r--r--  tensorflow/python/keras/engine/base_layer.py          | 60
-rw-r--r--  tensorflow/python/keras/layers/normalization.py       | 24
-rw-r--r--  tensorflow/python/kernel_tests/variable_scope_test.py | 68
-rw-r--r--  tensorflow/python/layers/base.py                      | 41
-rw-r--r--  tensorflow/python/layers/base_test.py                 | 26
-rw-r--r--  tensorflow/python/ops/metrics_impl.py                 | 20
-rw-r--r--  tensorflow/python/ops/variable_scope.py               | 72
-rw-r--r--  tensorflow/python/training/distribute.py              | 56
9 files changed, 256 insertions, 115 deletions
diff --git a/tensorflow/python/eager/graph_callable.py b/tensorflow/python/eager/graph_callable.py
index 848adf4fd3..2c6f04d8ad 100644
--- a/tensorflow/python/eager/graph_callable.py
+++ b/tensorflow/python/eager/graph_callable.py
@@ -118,7 +118,7 @@ class _VariableCapturingScope(object):
initializer=None,
regularizer=None,
reuse=None,
- trainable=True,
+ trainable=None,
collections=None,
caching_device=None, # pylint: disable=redefined-outer-name
partitioner=None,
@@ -156,7 +156,7 @@ class _VariableCapturingScope(object):
initializer=None,
regularizer=None,
reuse=None,
- trainable=True,
+ trainable=None,
collections=None,
caching_device=None, # pylint: disable=redefined-outer-name
partitioner=None,
diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py
index 5c668345b7..26ea9cd797 100644
--- a/tensorflow/python/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/engine/base_layer.py
@@ -460,14 +460,18 @@ class Layer(checkpointable.CheckpointableBase):
"""Alias for `add_weight`."""
return self.add_weight(*args, **kwargs)
- def add_weight(self, name, shape,
+ def add_weight(self,
+ name,
+ shape,
dtype=None,
initializer=None,
regularizer=None,
- trainable=True,
+ trainable=None,
constraint=None,
partitioner=None,
use_resource=None,
+ synchronization=vs.VariableSynchronization.AUTO,
+ aggregation=vs.VariableAggregation.NONE,
getter=None):
"""Adds a new variable to the layer, or gets an existing one; returns it.
@@ -482,10 +486,20 @@ class Layer(checkpointable.CheckpointableBase):
or "non_trainable_variables" (e.g. BatchNorm mean, stddev).
Note, if the current variable scope is marked as non-trainable
then this parameter is ignored and any added variables are also
- marked as non-trainable.
+ marked as non-trainable. `trainable` defaults to `True` unless
+ `synchronization` is set to `ON_READ`.
constraint: constraint instance (callable).
partitioner: Partitioner to be passed to the `Checkpointable` API.
use_resource: Whether to use `ResourceVariable`.
+ synchronization: Indicates when a distributed variable will be
+ aggregated. Accepted values are constants defined in the class
+ @{tf.VariableSynchronization}. By default the synchronization is set to
+ `AUTO` and the current `DistributionStrategy` chooses
+ when to synchronize. If `synchronization` is set to `ON_READ`,
+ `trainable` must not be set to `True`.
+ aggregation: Indicates how a distributed variable will be aggregated.
+ Accepted values are constants defined in the class
+ @{tf.VariableAggregation}.
getter: Variable getter argument to be passed to the `Checkpointable` API.
Returns:
@@ -496,7 +510,8 @@ class Layer(checkpointable.CheckpointableBase):
Raises:
RuntimeError: If called with partioned variable regularization and
eager execution is enabled.
- ValueError: When giving unsupported dtype and no initializer.
+ ValueError: When giving unsupported dtype and no initializer, or when
+ trainable has been set to True with synchronization set to `ON_READ`.
"""
if dtype is None:
dtype = self.dtype or backend.floatx()
@@ -505,6 +520,19 @@ class Layer(checkpointable.CheckpointableBase):
regularizer = regularizers.get(regularizer)
constraint = constraints.get(constraint)
+ if synchronization == vs.VariableSynchronization.ON_READ:
+ if trainable:
+ raise ValueError(
+ 'Synchronization value can be set to '
+ 'VariableSynchronization.ON_READ only for non-trainable variables. '
+ 'You have specified trainable=True and '
+ 'synchronization=VariableSynchronization.ON_READ.')
+ else:
+ # Set trainable to be false when variable is to be synced on read.
+ trainable = False
+ elif trainable is None:
+ trainable = True
+
# Initialize variable when no initializer provided
if initializer is None:
# If dtype is DT_FLOAT, provide a uniform unit scaling initializer
@@ -532,7 +560,9 @@ class Layer(checkpointable.CheckpointableBase):
constraint=constraint,
trainable=trainable and self.trainable,
partitioner=partitioner,
- use_resource=use_resource)
+ use_resource=use_resource,
+ synchronization=synchronization,
+ aggregation=aggregation)
if regularizer is not None:
# TODO(fchollet): in the future, this should be handled at the
@@ -1806,11 +1836,13 @@ def make_variable(name,
dtype=dtypes.float32,
initializer=None,
partition_info=None,
- trainable=True,
+ trainable=None,
caching_device=None,
validate_shape=True,
constraint=None,
use_resource=None,
+ synchronization=vs.VariableSynchronization.AUTO,
+ aggregation=vs.VariableAggregation.NONE,
partitioner=None): # pylint: disable=unused-argument
"""Temporary util to create a variable (relies on `variable_scope.variable`).
@@ -1836,11 +1868,21 @@ def make_variable(name,
or "non_trainable_variables" (e.g. BatchNorm mean, stddev).
Note, if the current variable scope is marked as non-trainable
then this parameter is ignored and any added variables are also
- marked as non-trainable.
+ marked as non-trainable. `trainable` defaults to `True` unless
+ `synchronization` is set to `ON_READ`.
caching_device: Passed to `vs.variable`.
validate_shape: Passed to `vs.variable`.
constraint: Constraint instance (callable).
use_resource: Whether to use a `ResourceVariable`.
+ synchronization: Indicates when a distributed variable will be
+ aggregated. Accepted values are constants defined in the class
+ @{tf.VariableSynchronization}. By default the synchronization is set to
+ `AUTO` and the current `DistributionStrategy` chooses
+ when to synchronize. If `synchronization` is set to `ON_READ`,
+ `trainable` must not be set to `True`.
+ aggregation: Indicates how a distributed variable will be aggregated.
+ Accepted values are constants defined in the class
+ @{tf.VariableAggregation}.
partitioner: Not handled at this time.
Returns:
@@ -1872,5 +1914,7 @@ def make_variable(name,
dtype=variable_dtype,
validate_shape=validate_shape,
constraint=constraint,
- use_resource=use_resource)
+ use_resource=use_resource,
+ synchronization=synchronization,
+ aggregation=aggregation)
return v
diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py
index 8b894ca6b1..58c8a8a66d 100644
--- a/tensorflow/python/keras/layers/normalization.py
+++ b/tensorflow/python/keras/layers/normalization.py
@@ -181,12 +181,6 @@ class BatchNormalization(Layer):
self.renorm_clipping = renorm_clipping
self.renorm_momentum = renorm_momentum
- def _add_tower_local_variable(self, *args, **kwargs):
- tower_context = distribute_lib.get_tower_context()
- with tower_context.tower_local_var_scope(
- variable_scope.VariableAggregation.MEAN):
- return self.add_weight(*args, **kwargs)
-
def build(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape)
if not input_shape.ndims:
@@ -314,19 +308,23 @@ class BatchNormalization(Layer):
self._scope.set_partitioner(None)
else:
partitioner = None
- self.moving_mean = self._add_tower_local_variable(
+ self.moving_mean = self.add_weight(
name='moving_mean',
shape=param_shape,
dtype=param_dtype,
initializer=self.moving_mean_initializer,
- trainable=False)
+ synchronization=variable_scope.VariableSynchronization.ON_READ,
+ trainable=False,
+ aggregation=variable_scope.VariableAggregation.MEAN)
- self.moving_variance = self._add_tower_local_variable(
+ self.moving_variance = self.add_weight(
name='moving_variance',
shape=param_shape,
dtype=param_dtype,
initializer=self.moving_variance_initializer,
- trainable=False)
+ synchronization=variable_scope.VariableSynchronization.ON_READ,
+ trainable=False,
+ aggregation=variable_scope.VariableAggregation.MEAN)
if self.renorm:
# Create variables to maintain the moving mean and standard deviation.
@@ -337,12 +335,14 @@ class BatchNormalization(Layer):
# stack to be cleared. The nested ones use a `lambda` to set the desired
# device and ignore any devices that may be set by the custom getter.
def _renorm_variable(name, shape):
- var = self._add_tower_local_variable(
+ var = self.add_weight(
name=name,
shape=shape,
dtype=param_dtype,
initializer=init_ops.zeros_initializer(),
- trainable=False)
+ synchronization=variable_scope.VariableSynchronization.ON_READ,
+ trainable=False,
+ aggregation=variable_scope.VariableAggregation.MEAN)
return var
with distribute_lib.get_distribution_strategy().colocate_vars_with(
diff --git a/tensorflow/python/kernel_tests/variable_scope_test.py b/tensorflow/python/kernel_tests/variable_scope_test.py
index 054c6f9dd7..ae2a0ab29a 100644
--- a/tensorflow/python/kernel_tests/variable_scope_test.py
+++ b/tensorflow/python/kernel_tests/variable_scope_test.py
@@ -1054,7 +1054,7 @@ class VariableScopeTest(test.TestCase):
"testGetCollection_foo/testGetCollection_a:0"
])
- def testGetTrainableVariables(self):
+ def testGetTrainableVariablesWithGetVariable(self):
with self.test_session():
_ = variable_scope.get_variable("testGetTrainableVariables_a", [])
with variable_scope.variable_scope(
@@ -1062,10 +1062,72 @@ class VariableScopeTest(test.TestCase):
_ = variable_scope.get_variable("testGetTrainableVariables_b", [])
_ = variable_scope.get_variable(
"testGetTrainableVariables_c", [], trainable=False)
+
+ # sync `ON_READ` sets trainable=False
+ _ = variable_scope.get_variable(
+ "testGetTrainableVariables_d", [],
+ synchronization=variable_scope.VariableSynchronization.ON_READ)
self.assertEqual(
[v.name for v in scope.trainable_variables()],
- ["testGetTrainableVariables_foo/"
- "testGetTrainableVariables_b:0"])
+ ["testGetTrainableVariables_foo/testGetTrainableVariables_b:0"])
+
+ # All other sync values set trainable=True
+ _ = variable_scope.get_variable(
+ "testGetTrainableVariables_e", [],
+ synchronization=variable_scope.VariableSynchronization.ON_WRITE)
+ self.assertEqual([v.name for v in scope.trainable_variables()], [
+ "testGetTrainableVariables_foo/testGetTrainableVariables_b:0",
+ "testGetTrainableVariables_foo/testGetTrainableVariables_e:0"
+ ])
+
+ with self.assertRaisesRegexp(
+ ValueError, "Synchronization value can be set to "
+ "VariableSynchronization.ON_READ only for non-trainable variables. "
+ "You have specified trainable=True and "
+ "synchronization=VariableSynchronization.ON_READ."):
+ _ = variable_scope.get_variable(
+ "testGetTrainableVariables_e", [],
+ synchronization=variable_scope.VariableSynchronization.ON_READ,
+ trainable=True)
+
+ def testGetTrainableVariablesWithVariable(self):
+ with self.test_session():
+ _ = variable_scope.variable(1.0, name="testGetTrainableVariables_a")
+ with variable_scope.variable_scope(
+ "testGetTrainableVariables_foo") as scope:
+ _ = variable_scope.variable(1.0, name="testGetTrainableVariables_b")
+ _ = variable_scope.variable(
+ 1.0, name="testGetTrainableVariables_c", trainable=False)
+
+ # sync `ON_READ` sets trainable=False
+ _ = variable_scope.variable(
+ 1.0,
+ name="testGetTrainableVariables_d",
+ synchronization=variable_scope.VariableSynchronization.ON_READ)
+ self.assertEqual(
+ [v.name for v in scope.trainable_variables()],
+ ["testGetTrainableVariables_foo/testGetTrainableVariables_b:0"])
+
+ # All other sync values set trainable=True
+ _ = variable_scope.variable(
+ 1.0,
+ name="testGetTrainableVariables_e",
+ synchronization=variable_scope.VariableSynchronization.ON_WRITE)
+ self.assertEqual([v.name for v in scope.trainable_variables()], [
+ "testGetTrainableVariables_foo/testGetTrainableVariables_b:0",
+ "testGetTrainableVariables_foo/testGetTrainableVariables_e:0"
+ ])
+
+ with self.assertRaisesRegexp(
+ ValueError, "Synchronization value can be set to "
+ "VariableSynchronization.ON_READ only for non-trainable variables. "
+ "You have specified trainable=True and "
+ "synchronization=VariableSynchronization.ON_READ."):
+ _ = variable_scope.variable(
+ 1.0,
+ name="testGetTrainableVariables_e",
+ synchronization=variable_scope.VariableSynchronization.ON_READ,
+ trainable=True)
def testGetGlobalVariables(self):
with self.test_session():
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py
index b8969a41ab..cf13b52617 100644
--- a/tensorflow/python/layers/base.py
+++ b/tensorflow/python/layers/base.py
@@ -152,10 +152,17 @@ class Layer(base_layer.Layer):
scope, default_name=self._base_name) as captured_scope:
self._scope = captured_scope
- def add_weight(self, name, shape, dtype=None,
- initializer=None, regularizer=None,
- trainable=True, constraint=None,
+ def add_weight(self,
+ name,
+ shape,
+ dtype=None,
+ initializer=None,
+ regularizer=None,
+ trainable=None,
+ constraint=None,
use_resource=None,
+ synchronization=vs.VariableSynchronization.AUTO,
+ aggregation=vs.VariableAggregation.NONE,
partitioner=None):
"""Adds a new variable to the layer, or gets an existing one; returns it.
@@ -170,9 +177,19 @@ class Layer(base_layer.Layer):
or "non_trainable_variables" (e.g. BatchNorm mean, stddev).
Note, if the current variable scope is marked as non-trainable
then this parameter is ignored and any added variables are also
- marked as non-trainable.
+ marked as non-trainable. `trainable` defaults to `True` unless
+ `synchronization` is set to `ON_READ`.
constraint: constraint instance (callable).
use_resource: Whether to use `ResourceVariable`.
+ synchronization: Indicates when a distributed variable will be
+ aggregated. Accepted values are constants defined in the class
+ @{tf.VariableSynchronization}. By default the synchronization is set to
+ `AUTO` and the current `DistributionStrategy` chooses
+ when to synchronize. If `synchronization` is set to `ON_READ`,
+ `trainable` must not be set to `True`.
+ aggregation: Indicates how a distributed variable will be aggregated.
+ Accepted values are constants defined in the class
+ @{tf.VariableAggregation}.
partitioner: (optional) partitioner instance (callable). If
provided, when the requested variable is created it will be split
into multiple partitions according to `partitioner`. In this case,
@@ -190,7 +207,21 @@ class Layer(base_layer.Layer):
Raises:
RuntimeError: If called with partioned variable regularization and
eager execution is enabled.
+ ValueError: When trainable has been set to True with synchronization
+ set to `ON_READ`.
"""
+ if synchronization == vs.VariableSynchronization.ON_READ:
+ if trainable:
+ raise ValueError(
+ 'Synchronization value can be set to '
+ 'VariableSynchronization.ON_READ only for non-trainable variables. '
+ 'You have specified trainable=True and '
+ 'synchronization=VariableSynchronization.ON_READ.')
+ else:
+ # Set trainable to be false when variable is to be synced on read.
+ trainable = False
+ elif trainable is None:
+ trainable = True
def _should_add_regularizer(variable, existing_variable_set):
if isinstance(variable, tf_variables.PartitionedVariable):
@@ -240,6 +271,8 @@ class Layer(base_layer.Layer):
constraint=constraint,
partitioner=partitioner,
use_resource=use_resource,
+ synchronization=synchronization,
+ aggregation=aggregation,
getter=vs.get_variable)
if regularizer:
diff --git a/tensorflow/python/layers/base_test.py b/tensorflow/python/layers/base_test.py
index 298e96e711..d2443db665 100644
--- a/tensorflow/python/layers/base_test.py
+++ b/tensorflow/python/layers/base_test.py
@@ -90,12 +90,34 @@ class BaseLayerTest(test.TestCase):
# regularizers only supported in GRAPH mode.
regularizer = lambda x: math_ops.reduce_sum(x) * 1e-3
- variable = layer.add_variable(
+ _ = layer.add_variable(
'reg_var', [2, 2],
initializer=init_ops.zeros_initializer(),
regularizer=regularizer)
self.assertEqual(len(layer.losses), 1)
+ # Test that sync `ON_READ` variables default to non-trainable.
+ variable_3 = layer.add_variable(
+ 'sync_on_read_var', [2, 2],
+ initializer=init_ops.zeros_initializer(),
+ synchronization=variable_scope.VariableSynchronization.ON_READ,
+ aggregation=variable_scope.VariableAggregation.SUM)
+ self.assertEqual(layer.non_trainable_variables, [variable_2, variable_3])
+
+ def testInvalidTrainableSynchronizationCombination(self):
+ layer = base_layers.Layer(name='my_layer')
+
+ with self.assertRaisesRegexp(
+ ValueError, 'Synchronization value can be set to '
+ 'VariableSynchronization.ON_READ only for non-trainable variables. '
+ 'You have specified trainable=True and '
+ 'synchronization=VariableSynchronization.ON_READ.'):
+ _ = layer.add_variable(
+ 'v', [2, 2],
+ initializer=init_ops.zeros_initializer(),
+ synchronization=variable_scope.VariableSynchronization.ON_READ,
+ trainable=True)
+
def testReusePartitionedVaraiblesAndRegularizers(self):
regularizer = lambda x: math_ops.reduce_sum(x) * 1e-3
partitioner = partitioned_variables.fixed_size_partitioner(3)
@@ -104,7 +126,7 @@ class BaseLayerTest(test.TestCase):
partitioner=partitioner,
reuse=reuse):
layer = base_layers.Layer(name='my_layer')
- variable = layer.add_variable(
+ _ = layer.add_variable(
'reg_part_var', [4, 4],
initializer=init_ops.zeros_initializer(),
regularizer=regularizer)
diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py
index bfd225b0d8..3aedeb6acd 100644
--- a/tensorflow/python/ops/metrics_impl.py
+++ b/tensorflow/python/ops/metrics_impl.py
@@ -73,16 +73,16 @@ def metric_variable(shape, dtype, validate_shape=True, name=None):
A (non-trainable) variable initialized to zero, or if inside a
`DistributionStrategy` scope a tower-local variable container.
"""
- with distribute_lib.get_tower_context().tower_local_var_scope(
- variable_scope.VariableAggregation.SUM):
- # Note that "tower local" implies trainable=False.
- return variable_scope.variable(
- lambda: array_ops.zeros(shape, dtype),
- collections=[
- ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.METRIC_VARIABLES
- ],
- validate_shape=validate_shape,
- name=name)
+ # Note that synchronization "ON_READ" implies trainable=False.
+ return variable_scope.variable(
+ lambda: array_ops.zeros(shape, dtype),
+ collections=[
+ ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.METRIC_VARIABLES
+ ],
+ validate_shape=validate_shape,
+ synchronization=variable_scope.VariableSynchronization.ON_READ,
+ aggregation=variable_scope.VariableAggregation.SUM,
+ name=name)
def _remove_squeezable_dimensions(predictions, labels, weights):
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index 1e06bf07d5..77f67c18ee 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -255,7 +255,7 @@ class _VariableStore(object):
initializer=None,
regularizer=None,
reuse=None,
- trainable=True,
+ trainable=None,
collections=None,
caching_device=None,
partitioner=None,
@@ -300,6 +300,8 @@ class _VariableStore(object):
forced to be False.
trainable: If `True` also add the variable to the graph collection
`GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+ `trainable` defaults to `True` unless `synchronization` is
+ set to `ON_READ`.
collections: List of graph collections keys to add the `Variable` to.
Defaults to `[GraphKeys.GLOBAL_VARIABLES]` (see `tf.Variable`).
caching_device: Optional device string or function describing where the
@@ -341,7 +343,8 @@ class _VariableStore(object):
aggregated. Accepted values are constants defined in the class
@{tf.VariableSynchronization}. By default the synchronization is set to
`AUTO` and the current `DistributionStrategy` chooses
- when to synchronize.
+ when to synchronize. If `synchronization` is set to `ON_READ`,
+ `trainable` must not be set to `True`.
aggregation: Indicates how a distributed variable will be aggregated.
Accepted values are constants defined in the class
@{tf.VariableAggregation}.
@@ -404,7 +407,7 @@ class _VariableStore(object):
initializer=None,
regularizer=None,
reuse=None,
- trainable=True,
+ trainable=None,
collections=None,
caching_device=None,
partitioner=None,
@@ -477,6 +480,10 @@ class _VariableStore(object):
synchronization=synchronization,
aggregation=aggregation)
+ # Set trainable value based on synchronization value.
+ trainable = _get_trainable_value(
+ synchronization=synchronization, trainable=trainable)
+
if custom_getter is not None:
# Handle backwards compatibility with getter arguments that were added
# to the API after users started writing custom getters.
@@ -519,11 +526,20 @@ class _VariableStore(object):
synchronization=synchronization,
aggregation=aggregation)
- def _get_partitioned_variable(
- self, name, partitioner, shape=None, dtype=dtypes.float32,
- initializer=None, regularizer=None, reuse=None,
- trainable=True, collections=None, caching_device=None,
- validate_shape=True, use_resource=None, constraint=None):
+ def _get_partitioned_variable(self,
+ name,
+ partitioner,
+ shape=None,
+ dtype=dtypes.float32,
+ initializer=None,
+ regularizer=None,
+ reuse=None,
+ trainable=None,
+ collections=None,
+ caching_device=None,
+ validate_shape=True,
+ use_resource=None,
+ constraint=None):
"""Gets or creates a sharded variable list with these parameters.
The `partitioner` must be a callable that accepts a fully defined
@@ -773,7 +789,7 @@ class _VariableStore(object):
regularizer=None,
partition_info=None,
reuse=None,
- trainable=True,
+ trainable=None,
collections=None,
caching_device=None,
validate_shape=True,
@@ -1136,7 +1152,7 @@ class VariableScope(object):
initializer=None,
regularizer=None,
reuse=None,
- trainable=True,
+ trainable=None,
collections=None,
caching_device=None,
partitioner=None,
@@ -1207,7 +1223,7 @@ class VariableScope(object):
dtype=None,
initializer=None,
regularizer=None,
- trainable=True,
+ trainable=None,
collections=None,
caching_device=None,
partitioner=None,
@@ -1422,7 +1438,7 @@ def get_variable(name,
dtype=None,
initializer=None,
regularizer=None,
- trainable=True,
+ trainable=None,
collections=None,
caching_device=None,
partitioner=None,
@@ -2334,11 +2350,28 @@ def _compute_slice_dim_and_shape(full_shape, slicing):
return slice_dim, slice_shape
+def _get_trainable_value(synchronization, trainable):
+ """Computes the trainable value based on the given arguments."""
+ if synchronization == VariableSynchronization.ON_READ:
+ if trainable:
+ raise ValueError(
+ "Synchronization value can be set to "
+ "VariableSynchronization.ON_READ only for non-trainable variables. "
+ "You have specified trainable=True and "
+ "synchronization=VariableSynchronization.ON_READ.")
+ else:
+ # Set trainable to be false when variable is to be synced on read.
+ trainable = False
+ elif trainable is None:
+ trainable = True
+ return trainable
+
+
def default_variable_creator(next_creator=None, **kwargs):
"""Default variable creator."""
assert next_creator is None
initial_value = kwargs.get("initial_value", None)
- trainable = kwargs.get("trainable", True)
+ trainable = kwargs.get("trainable", None)
collections = kwargs.get("collections", None)
validate_shape = kwargs.get("validate_shape", True)
caching_device = kwargs.get("caching_device", None)
@@ -2347,10 +2380,10 @@ def default_variable_creator(next_creator=None, **kwargs):
constraint = kwargs.get("constraint", None)
use_resource = kwargs.get("use_resource", None)
- # Enforce `ON_READ` variables to be not trainable.
+ # Set trainable value based on synchronization value.
synchronization = kwargs.get("synchronization", VariableSynchronization.AUTO)
- if synchronization == VariableSynchronization.ON_READ:
- trainable = False
+ trainable = _get_trainable_value(
+ synchronization=synchronization, trainable=trainable)
if use_resource is None:
use_resource = get_variable_scope().use_resource
@@ -2379,7 +2412,7 @@ def _make_getter(captured_getter, captured_previous):
def variable(initial_value=None,
- trainable=True,
+ trainable=None,
collections=None,
validate_shape=True,
caching_device=None,
@@ -2441,6 +2474,8 @@ def variable_creator_scope(variable_creator):
trainable: If `True`, the default, also adds the variable to the graph
collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
the default list of variables to use by the `Optimizer` classes.
+ `trainable` defaults to `True` unless `synchronization` is
+ set to `ON_READ`.
collections: List of graph collections keys. The new variable is added to
these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
validate_shape: If `False`, allows the variable to be initialized with a
@@ -2463,7 +2498,8 @@ def variable_creator_scope(variable_creator):
aggregated. Accepted values are constants defined in the class
@{tf.VariableSynchronization}. By default the synchronization is set to
`AUTO` and the current `DistributionStrategy` chooses
- when to synchronize.
+ when to synchronize. If `synchronization` is set to `ON_READ`,
+ `trainable` must not be set to `True`.
aggregation: Indicates how a distributed variable will be aggregated.
Accepted values are constants defined in the class
@{tf.VariableAggregation}.
diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py
index d33fd7376a..c719045c7f 100644
--- a/tensorflow/python/training/distribute.py
+++ b/tensorflow/python/training/distribute.py
@@ -614,48 +614,6 @@ class DistributionStrategy(object):
# Note: should support "colocate_with" argument.
raise NotImplementedError("must be implemented in descendants")
- def tower_local_var_scope(self, aggregation):
- """Inside this scope, new variables will not be mirrored.
-
- There will still be one component variable per tower, but there is
- no requirement that they stay in sync. Instead, when saving them
- or calling `read_var()`, we use the value that results when
- calling `reduce()` on all the towers' variables.
-
- Note: tower-local implies not trainable. Instead, it is expected
- that each tower will directly update (using `assign_add()` or
- whatever) its local variable instance but only the aggregated
- value (accessible using `read_var()`) will be exported from the
- model. When it is acceptable to only aggregate on export, we
- greatly reduce communication overhead by using tower-local
- variables.
-
- Note: All component variables will be initialized to the same
- value, using the initialization expression from the first tower.
- The values will match even if the initialization expression uses
- random numbers.
-
- Args:
- aggregation: Indicates how a variable will be aggregated. Accepted values
- are @{tf.VariableAggregation.SUM}, @{tf.VariableAggregation.MEAN}.
-
- Returns:
- A context manager.
- """
- # TODO(psv): Remove this after adding support for synchronization and
- # aggregation parameters in get_variable() and mirrored strategy.
- def create_tower_local_variable(next_creator, *args, **kwargs):
- _require_distribution_strategy_scope(self)
- kwargs["use_resource"] = True
-
- # Set synchronization to be ON_READ for tower local variables.
- kwargs["synchronization"] = variable_scope.VariableSynchronization.ON_READ
- kwargs["aggregation"] = aggregation
- return next_creator(*args, **kwargs)
-
- _require_distribution_strategy_scope(self)
- return variable_scope.variable_creator_scope(create_tower_local_variable)
-
def read_var(self, v):
"""Reads the value of a variable.
@@ -1103,10 +1061,6 @@ class TowerContext(object):
finally:
_pop_per_thread_mode()
- def tower_local_var_scope(self, aggregation):
- """Alias for distribution_strategy.tower_local_var_scope()."""
- return self._distribution_strategy.tower_local_var_scope(aggregation)
-
@property
def is_single_tower(self):
"""Returns whether there is a single tower or multiple."""
@@ -1158,16 +1112,6 @@ class _DefaultDistributionStrategy(DistributionStrategy):
return _CurrentDistributionContext(
self, variable_scope.variable_creator_scope(creator))
- def tower_local_var_scope(self, aggregation):
- """Does not set to resource variables."""
- def create_tower_local_variable(next_creator, *args, **kwargs):
- _require_distribution_strategy_scope(self)
- kwargs["trainable"] = False
- return next_creator(*args, **kwargs)
-
- _require_distribution_strategy_scope(self)
- return variable_scope.variable_creator_scope(create_tower_local_variable)
-
def colocate_vars_with(self, colocate_with_variable):
"""Does not require `self.scope`."""
_require_distribution_strategy_scope(self)
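Taken together, the migration replaces the removed `tower_local_var_scope` wrapper with per-variable arguments. A rough before/after sketch (illustrative only; the variable name is made up), mirroring the `metric_variable` change above:

    from tensorflow.python.ops import variable_scope

    # Before (removed in this change): wrap creation in a tower-local scope.
    #   with distribute_lib.get_tower_context().tower_local_var_scope(
    #       variable_scope.VariableAggregation.SUM):
    #     count = variable_scope.variable(0., name="count")

    # After: declare the behavior on the variable itself; `ON_READ` implies
    # trainable=False, and each tower's value is summed when the variable is read.
    count = variable_scope.variable(
        0.,
        name="count",
        synchronization=variable_scope.VariableSynchronization.ON_READ,
        aggregation=variable_scope.VariableAggregation.SUM)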