author | Pavithra Vijay <psv@google.com> | 2018-07-09 16:58:04 -0700
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-09 17:04:47 -0700
commit | 855d2a723c24ce69ee993144c16506dbef12ed69 (patch)
tree | 8a0b8a9ce633bbd59eeef53ff5973cb2e49fb412 /tensorflow/python/layers
parent | e0541b714d0df485a79ece616575ebb2a71b818c (diff)
Add `synchronization` and `aggregation` args to the layer `add_weight()` API. These args will be used for distributed variables.
Migrate all usages of `tower_local_var_scope` to using the new args.
PiperOrigin-RevId: 203855963
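
For context, a minimal sketch of what the new call surface enables, assuming the TF 1.9-era public API (`tf.layers.Layer`, `tf.VariableSynchronization`, `tf.VariableAggregation`, which the docstring below references as `@{tf.VariableSynchronization}` etc.); the `ReplicaCounter` layer and its `counter` variable are illustrative names, not part of this commit:

```python
import tensorflow as tf


class ReplicaCounter(tf.layers.Layer):
  """Illustrative layer with a replica-local, sum-aggregated counter."""

  def build(self, input_shape):
    # ON_READ: each replica updates its own copy of the variable, and reads
    # aggregate across replicas (here by summing). With this commit,
    # add_weight() defaults such variables to trainable=False.
    self.counter = self.add_weight(
        'counter',
        shape=(),
        dtype=tf.float32,
        initializer=tf.zeros_initializer(),
        synchronization=tf.VariableSynchronization.ON_READ,
        aggregation=tf.VariableAggregation.SUM)
    super(ReplicaCounter, self).build(input_shape)

  def call(self, inputs):
    return inputs
```

Because an `ON_READ` variable is updated independently on each replica, treating it as trainable would be ill-defined, which is why `add_weight()` now defaults it to non-trainable and rejects an explicit `trainable=True`.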
Diffstat (limited to 'tensorflow/python/layers')
-rw-r--r-- | tensorflow/python/layers/base.py | 41
-rw-r--r-- | tensorflow/python/layers/base_test.py | 26
2 files changed, 61 insertions, 6 deletions
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py
index b8969a41ab..cf13b52617 100644
--- a/tensorflow/python/layers/base.py
+++ b/tensorflow/python/layers/base.py
@@ -152,10 +152,17 @@ class Layer(base_layer.Layer):
         scope, default_name=self._base_name) as captured_scope:
       self._scope = captured_scope
 
-  def add_weight(self, name, shape, dtype=None,
-                 initializer=None, regularizer=None,
-                 trainable=True, constraint=None,
+  def add_weight(self,
+                 name,
+                 shape,
+                 dtype=None,
+                 initializer=None,
+                 regularizer=None,
+                 trainable=None,
+                 constraint=None,
                  use_resource=None,
+                 synchronization=vs.VariableSynchronization.AUTO,
+                 aggregation=vs.VariableAggregation.NONE,
                  partitioner=None):
     """Adds a new variable to the layer, or gets an existing one; returns it.
 
@@ -170,9 +177,19 @@ class Layer(base_layer.Layer):
         or "non_trainable_variables" (e.g. BatchNorm mean, stddev).
         Note, if the current variable scope is marked as non-trainable
         then this parameter is ignored and any added variables are also
-        marked as non-trainable.
+        marked as non-trainable. `trainable` defaults to `True` unless
+        `synchronization` is set to `ON_READ`.
       constraint: constraint instance (callable).
       use_resource: Whether to use `ResourceVariable`.
+      synchronization: Indicates when a distributed variable will be
+        aggregated. Accepted values are constants defined in the class
+        @{tf.VariableSynchronization}. By default the synchronization is set
+        to `AUTO` and the current `DistributionStrategy` chooses when to
+        synchronize. If `synchronization` is set to `ON_READ`, `trainable`
+        must not be set to `True`.
+      aggregation: Indicates how a distributed variable will be aggregated.
+        Accepted values are constants defined in the class
+        @{tf.VariableAggregation}.
       partitioner: (optional) partitioner instance (callable). If
         provided, when the requested variable is created it will be split
         into multiple partitions according to `partitioner`. In this case,
@@ -190,7 +207,21 @@ class Layer(base_layer.Layer):
     Raises:
       RuntimeError: If called with partitioned variable regularization and
         eager execution is enabled.
+      ValueError: When trainable has been set to True with synchronization
+        set as `ON_READ`.
     """
+    if synchronization == vs.VariableSynchronization.ON_READ:
+      if trainable:
+        raise ValueError(
+            'Synchronization value can be set to '
+            'VariableSynchronization.ON_READ only for non-trainable variables. '
+            'You have specified trainable=True and '
+            'synchronization=VariableSynchronization.ON_READ.')
+      else:
+        # Set trainable to be false when variable is to be synced on read.
+        trainable = False
+    elif trainable is None:
+      trainable = True
 
     def _should_add_regularizer(variable, existing_variable_set):
       if isinstance(variable, tf_variables.PartitionedVariable):
@@ -240,6 +271,8 @@ class Layer(base_layer.Layer):
         constraint=constraint,
         partitioner=partitioner,
         use_resource=use_resource,
+        synchronization=synchronization,
+        aggregation=aggregation,
         getter=vs.get_variable)
 
     if regularizer:
diff --git a/tensorflow/python/layers/base_test.py b/tensorflow/python/layers/base_test.py
index 298e96e711..d2443db665 100644
--- a/tensorflow/python/layers/base_test.py
+++ b/tensorflow/python/layers/base_test.py
@@ -90,12 +90,34 @@ class BaseLayerTest(test.TestCase):
       # regularizers only supported in GRAPH mode.
       regularizer = lambda x: math_ops.reduce_sum(x) * 1e-3
-      variable = layer.add_variable(
+      _ = layer.add_variable(
          'reg_var', [2, 2],
          initializer=init_ops.zeros_initializer(),
          regularizer=regularizer)
       self.assertEqual(len(layer.losses), 1)
 
+      # Test that sync `ON_READ` variables are defaulted to be non-trainable.
+      variable_3 = layer.add_variable(
+          'sync_on_read_var', [2, 2],
+          initializer=init_ops.zeros_initializer(),
+          synchronization=variable_scope.VariableSynchronization.ON_READ,
+          aggregation=variable_scope.VariableAggregation.SUM)
+      self.assertEqual(layer.non_trainable_variables, [variable_2, variable_3])
+
+  def testInvalidTrainableSynchronizationCombination(self):
+    layer = base_layers.Layer(name='my_layer')
+
+    with self.assertRaisesRegexp(
+        ValueError, 'Synchronization value can be set to '
+        'VariableSynchronization.ON_READ only for non-trainable variables. '
+        'You have specified trainable=True and '
+        'synchronization=VariableSynchronization.ON_READ.'):
+      _ = layer.add_variable(
+          'v', [2, 2],
+          initializer=init_ops.zeros_initializer(),
+          synchronization=variable_scope.VariableSynchronization.ON_READ,
+          trainable=True)
+
   def testReusePartitionedVaraiblesAndRegularizers(self):
     regularizer = lambda x: math_ops.reduce_sum(x) * 1e-3
     partitioner = partitioned_variables.fixed_size_partitioner(3)
@@ -104,7 +126,7 @@ class BaseLayerTest(test.TestCase):
           partitioner=partitioner,
           reuse=reuse):
         layer = base_layers.Layer(name='my_layer')
-        variable = layer.add_variable(
+        _ = layer.add_variable(
             'reg_part_var', [4, 4],
             initializer=init_ops.zeros_initializer(),
             regularizer=regularizer)
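
As a usage note, here is a hedged sketch of the call-site behavior the new tests pin down, run in TF 1.x graph mode; it assumes `tf.VariableSynchronization` and `tf.VariableAggregation` are exported as the docstring indicates, and the variable names (`sync_on_read_var`, `bad_var`) are illustrative:

```python
import tensorflow as tf
from tensorflow.python.layers import base as base_layers

layer = base_layers.Layer(name='my_layer')

# Without an explicit `trainable` flag, an ON_READ variable is created
# as non-trainable.
v = layer.add_variable(
    'sync_on_read_var', [2, 2],
    initializer=tf.zeros_initializer(),
    synchronization=tf.VariableSynchronization.ON_READ,
    aggregation=tf.VariableAggregation.SUM)
assert v in layer.non_trainable_variables

# Explicitly combining trainable=True with ON_READ raises a ValueError.
try:
  layer.add_variable(
      'bad_var', [2, 2],
      initializer=tf.zeros_initializer(),
      synchronization=tf.VariableSynchronization.ON_READ,
      trainable=True)
except ValueError as e:
  print(e)  # "Synchronization value can be set to ... ON_READ only for
            # non-trainable variables. ..."
```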