author    Pavithra Vijay <psv@google.com>  2018-07-09 16:58:04 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-07-09 17:04:47 -0700
commit    855d2a723c24ce69ee993144c16506dbef12ed69 (patch)
tree      8a0b8a9ce633bbd59eeef53ff5973cb2e49fb412 /tensorflow/python/layers
parent    e0541b714d0df485a79ece616575ebb2a71b818c (diff)
Add `synchronization` and `aggregation` args to the layer `add_weight()` API. These args will be used for distributed variables.
Migrate all usages of `tower_local_var_scope` to the new args.

PiperOrigin-RevId: 203855963
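For illustration only (this snippet is not part of the commit; it is a minimal sketch assuming the TF 1.x modules present at this revision, and it mirrors the test added to base_test.py below):

    from tensorflow.python.layers import base as base_layers
    from tensorflow.python.ops import init_ops
    from tensorflow.python.ops import variable_scope

    layer = base_layers.Layer(name='my_layer')
    # Request a sync-on-read, sum-aggregated variable. With this change,
    # `trainable` now defaults to False whenever synchronization is ON_READ.
    v = layer.add_variable(
        'sync_on_read_var', [2, 2],
        initializer=init_ops.zeros_initializer(),
        synchronization=variable_scope.VariableSynchronization.ON_READ,
        aggregation=variable_scope.VariableAggregation.SUM)
    assert layer.non_trainable_variables == [v]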
Diffstat (limited to 'tensorflow/python/layers')
-rw-r--r-- tensorflow/python/layers/base.py      | 41
-rw-r--r-- tensorflow/python/layers/base_test.py | 26
2 files changed, 61 insertions(+), 6 deletions(-)
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py
index b8969a41ab..cf13b52617 100644
--- a/tensorflow/python/layers/base.py
+++ b/tensorflow/python/layers/base.py
@@ -152,10 +152,17 @@ class Layer(base_layer.Layer):
scope, default_name=self._base_name) as captured_scope:
self._scope = captured_scope
- def add_weight(self, name, shape, dtype=None,
- initializer=None, regularizer=None,
- trainable=True, constraint=None,
+ def add_weight(self,
+ name,
+ shape,
+ dtype=None,
+ initializer=None,
+ regularizer=None,
+ trainable=None,
+ constraint=None,
use_resource=None,
+ synchronization=vs.VariableSynchronization.AUTO,
+ aggregation=vs.VariableAggregation.NONE,
partitioner=None):
"""Adds a new variable to the layer, or gets an existing one; returns it.
@@ -170,9 +177,19 @@ class Layer(base_layer.Layer):
or "non_trainable_variables" (e.g. BatchNorm mean, stddev).
Note, if the current variable scope is marked as non-trainable
then this parameter is ignored and any added variables are also
- marked as non-trainable.
+ marked as non-trainable. `trainable` defaults to `True` unless
+ `synchronization` is set to `ON_READ`.
constraint: constraint instance (callable).
use_resource: Whether to use `ResourceVariable`.
+ synchronization: Indicates when a distributed variable will be
+ synchronized. Accepted values are constants defined in the class
+ @{tf.VariableSynchronization}. By default the synchronization is set to
+ `AUTO` and the current `DistributionStrategy` chooses
+ when to synchronize. If `synchronization` is set to `ON_READ`,
+ `trainable` must not be set to `True`.
+ aggregation: Indicates how a distributed variable will be aggregated.
+ Accepted values are constants defined in the class
+ @{tf.VariableAggregation}.
partitioner: (optional) partitioner instance (callable). If
provided, when the requested variable is created it will be split
into multiple partitions according to `partitioner`. In this case,
@@ -190,7 +207,21 @@ class Layer(base_layer.Layer):
Raises:
RuntimeError: If called with partitioned variable regularization and
eager execution is enabled.
+ ValueError: When `trainable` has been set to `True` with `synchronization`
+ set to `ON_READ`.
"""
+ if synchronization == vs.VariableSynchronization.ON_READ:
+ if trainable:
+ raise ValueError(
+ 'Synchronization value can be set to '
+ 'VariableSynchronization.ON_READ only for non-trainable variables. '
+ 'You have specified trainable=True and '
+ 'synchronization=VariableSynchronization.ON_READ.')
+ else:
+ # Set trainable to be false when variable is to be synced on read.
+ trainable = False
+ elif trainable is None:
+ trainable = True
def _should_add_regularizer(variable, existing_variable_set):
if isinstance(variable, tf_variables.PartitionedVariable):
@@ -240,6 +271,8 @@ class Layer(base_layer.Layer):
constraint=constraint,
partitioner=partitioner,
use_resource=use_resource,
+ synchronization=synchronization,
+ aggregation=aggregation,
getter=vs.get_variable)
if regularizer:
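To make the new branching above concrete, here is another hedged sketch (not from this commit; it reuses the imports and `layer` object from the example near the top): with the default `AUTO` synchronization and `trainable` left unset, the variable stays trainable, while combining `trainable=True` with `ON_READ` triggers the `ValueError` introduced here.

    # Default path: AUTO synchronization with trainable=None falls through
    # to trainable=True via the new default handling.
    w = layer.add_variable(
        'auto_sync_var', [2, 2],
        initializer=init_ops.zeros_initializer(),
        aggregation=variable_scope.VariableAggregation.MEAN)
    assert w in layer.trainable_variables

    # Invalid combination: ON_READ together with trainable=True raises.
    try:
      layer.add_variable(
          'bad_var', [2, 2],
          initializer=init_ops.zeros_initializer(),
          synchronization=variable_scope.VariableSynchronization.ON_READ,
          trainable=True)
    except ValueError as err:
      print(err)  # "Synchronization value can be set to ... ON_READ only for non-trainable variables. ..."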
diff --git a/tensorflow/python/layers/base_test.py b/tensorflow/python/layers/base_test.py
index 298e96e711..d2443db665 100644
--- a/tensorflow/python/layers/base_test.py
+++ b/tensorflow/python/layers/base_test.py
@@ -90,12 +90,34 @@ class BaseLayerTest(test.TestCase):
# regularizers only supported in GRAPH mode.
regularizer = lambda x: math_ops.reduce_sum(x) * 1e-3
- variable = layer.add_variable(
+ _ = layer.add_variable(
'reg_var', [2, 2],
initializer=init_ops.zeros_initializer(),
regularizer=regularizer)
self.assertEqual(len(layer.losses), 1)
+ # Test that sync `ON_READ` variables are defaulted to be non-trainable.
+ variable_3 = layer.add_variable(
+ 'sync_on_read_var', [2, 2],
+ initializer=init_ops.zeros_initializer(),
+ synchronization=variable_scope.VariableSynchronization.ON_READ,
+ aggregation=variable_scope.VariableAggregation.SUM)
+ self.assertEqual(layer.non_trainable_variables, [variable_2, variable_3])
+
+ def testInvalidTrainableSynchronizationCombination(self):
+ layer = base_layers.Layer(name='my_layer')
+
+ with self.assertRaisesRegexp(
+ ValueError, 'Synchronization value can be set to '
+ 'VariableSynchronization.ON_READ only for non-trainable variables. '
+ 'You have specified trainable=True and '
+ 'synchronization=VariableSynchronization.ON_READ.'):
+ _ = layer.add_variable(
+ 'v', [2, 2],
+ initializer=init_ops.zeros_initializer(),
+ synchronization=variable_scope.VariableSynchronization.ON_READ,
+ trainable=True)
+
def testReusePartitionedVaraiblesAndRegularizers(self):
regularizer = lambda x: math_ops.reduce_sum(x) * 1e-3
partitioner = partitioned_variables.fixed_size_partitioner(3)
@@ -104,7 +126,7 @@ class BaseLayerTest(test.TestCase):
partitioner=partitioner,
reuse=reuse):
layer = base_layers.Layer(name='my_layer')
- variable = layer.add_variable(
+ _ = layer.add_variable(
'reg_part_var', [4, 4],
initializer=init_ops.zeros_initializer(),
regularizer=regularizer)