author     Pavithra Vijay <psv@google.com>                    2018-07-09 16:58:04 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>    2018-07-09 17:04:47 -0700
commit     855d2a723c24ce69ee993144c16506dbef12ed69 (patch)
tree       8a0b8a9ce633bbd59eeef53ff5973cb2e49fb412 /tensorflow/python
parent     e0541b714d0df485a79ece616575ebb2a71b818c (diff)
Add `synchronization` and `aggregation` args to the layer `add_weight()` API. These args will be used for distributed variables.
Migrate all usages of `tower_local_var_scope` to the new args.
PiperOrigin-RevId: 203855963
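
For context (not part of the patch): a minimal sketch of the new `add_weight()` call site, assuming the TF 1.x `tensorflow.python.layers` and `variable_scope` modules touched below. The layer and weight names are hypothetical; the pattern mirrors the `base_test.py` case added in this change.

from tensorflow.python.layers import base as base_layers
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope as vs


class CounterLayer(base_layers.Layer):
  """Hypothetical layer keeping a tower-local, sum-aggregated counter."""

  def build(self, input_shape):
    # synchronization=ON_READ implies trainable=False; passing trainable=True
    # together with ON_READ raises the ValueError introduced by this commit.
    self.counter = self.add_variable(
        'counter',
        shape=[],
        initializer=init_ops.zeros_initializer(),
        synchronization=vs.VariableSynchronization.ON_READ,
        aggregation=vs.VariableAggregation.SUM)
    self.built = True

  def call(self, inputs):
    return inputs


# Calling the layer builds it and creates the counter under its scope.
layer = CounterLayer(name='counter_layer')
outputs = layer(array_ops.zeros([2, 2]))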
Diffstat (limited to 'tensorflow/python')
-rw-r--r--  tensorflow/python/eager/graph_callable.py               4
-rw-r--r--  tensorflow/python/keras/engine/base_layer.py            60
-rw-r--r--  tensorflow/python/keras/layers/normalization.py         24
-rw-r--r--  tensorflow/python/kernel_tests/variable_scope_test.py   68
-rw-r--r--  tensorflow/python/layers/base.py                        41
-rw-r--r--  tensorflow/python/layers/base_test.py                   26
-rw-r--r--  tensorflow/python/ops/metrics_impl.py                   20
-rw-r--r--  tensorflow/python/ops/variable_scope.py                 72
-rw-r--r--  tensorflow/python/training/distribute.py                56
9 files changed, 256 insertions(+), 115 deletions(-)
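
Before the per-file diff, a short illustration (not from the patch) of the `trainable` defaulting rule that `_get_trainable_value()` in `variable_scope.py` now applies: `trainable=None` resolves to `True` unless `synchronization` is `ON_READ`, and combining `trainable=True` with `ON_READ` is rejected. Variable names are hypothetical; assumes TF 1.x graph mode.

from tensorflow.python.ops import variable_scope

# trainable is left as None, so it defaults to True.
w = variable_scope.get_variable("w", shape=[2])

# ON_READ forces the variable to be non-trainable.
m = variable_scope.get_variable(
    "m", shape=[2],
    synchronization=variable_scope.VariableSynchronization.ON_READ)

try:
  variable_scope.get_variable(
      "bad", shape=[2], trainable=True,
      synchronization=variable_scope.VariableSynchronization.ON_READ)
except ValueError as err:
  print(err)  # ON_READ is only allowed for non-trainable variables.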
diff --git a/tensorflow/python/eager/graph_callable.py b/tensorflow/python/eager/graph_callable.py index 848adf4fd3..2c6f04d8ad 100644 --- a/tensorflow/python/eager/graph_callable.py +++ b/tensorflow/python/eager/graph_callable.py @@ -118,7 +118,7 @@ class _VariableCapturingScope(object): initializer=None, regularizer=None, reuse=None, - trainable=True, + trainable=None, collections=None, caching_device=None, # pylint: disable=redefined-outer-name partitioner=None, @@ -156,7 +156,7 @@ class _VariableCapturingScope(object): initializer=None, regularizer=None, reuse=None, - trainable=True, + trainable=None, collections=None, caching_device=None, # pylint: disable=redefined-outer-name partitioner=None, diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py index 5c668345b7..26ea9cd797 100644 --- a/tensorflow/python/keras/engine/base_layer.py +++ b/tensorflow/python/keras/engine/base_layer.py @@ -460,14 +460,18 @@ class Layer(checkpointable.CheckpointableBase): """Alias for `add_weight`.""" return self.add_weight(*args, **kwargs) - def add_weight(self, name, shape, + def add_weight(self, + name, + shape, dtype=None, initializer=None, regularizer=None, - trainable=True, + trainable=None, constraint=None, partitioner=None, use_resource=None, + synchronization=vs.VariableSynchronization.AUTO, + aggregation=vs.VariableAggregation.NONE, getter=None): """Adds a new variable to the layer, or gets an existing one; returns it. @@ -482,10 +486,20 @@ class Layer(checkpointable.CheckpointableBase): or "non_trainable_variables" (e.g. BatchNorm mean, stddev). Note, if the current variable scope is marked as non-trainable then this parameter is ignored and any added variables are also - marked as non-trainable. + marked as non-trainable. `trainable` defaults to `True` unless + `synchronization` is set to `ON_READ`. constraint: constraint instance (callable). partitioner: Partitioner to be passed to the `Checkpointable` API. use_resource: Whether to use `ResourceVariable`. + synchronization: Indicates when a distributed a variable will be + aggregated. Accepted values are constants defined in the class + @{tf.VariableSynchronization}. By default the synchronization is set to + `AUTO` and the current `DistributionStrategy` chooses + when to synchronize. If `synchronization` is set to `ON_READ`, + `trainable` must not be set to `True`. + aggregation: Indicates how a distributed variable will be aggregated. + Accepted values are constants defined in the class + @{tf.VariableAggregation}. getter: Variable getter argument to be passed to the `Checkpointable` API. Returns: @@ -496,7 +510,8 @@ class Layer(checkpointable.CheckpointableBase): Raises: RuntimeError: If called with partioned variable regularization and eager execution is enabled. - ValueError: When giving unsupported dtype and no initializer. + ValueError: When giving unsupported dtype and no initializer or when + trainable has been set to True with synchronization set as `ON_READ`. """ if dtype is None: dtype = self.dtype or backend.floatx() @@ -505,6 +520,19 @@ class Layer(checkpointable.CheckpointableBase): regularizer = regularizers.get(regularizer) constraint = constraints.get(constraint) + if synchronization == vs.VariableSynchronization.ON_READ: + if trainable: + raise ValueError( + 'Synchronization value can be set to ' + 'VariableSynchronization.ON_READ only for non-trainable variables. 
' + 'You have specified trainable=True and ' + 'synchronization=VariableSynchronization.ON_READ.') + else: + # Set trainable to be false when variable is to be synced on read. + trainable = False + elif trainable is None: + trainable = True + # Initialize variable when no initializer provided if initializer is None: # If dtype is DT_FLOAT, provide a uniform unit scaling initializer @@ -532,7 +560,9 @@ class Layer(checkpointable.CheckpointableBase): constraint=constraint, trainable=trainable and self.trainable, partitioner=partitioner, - use_resource=use_resource) + use_resource=use_resource, + synchronization=synchronization, + aggregation=aggregation) if regularizer is not None: # TODO(fchollet): in the future, this should be handled at the @@ -1806,11 +1836,13 @@ def make_variable(name, dtype=dtypes.float32, initializer=None, partition_info=None, - trainable=True, + trainable=None, caching_device=None, validate_shape=True, constraint=None, use_resource=None, + synchronization=vs.VariableSynchronization.AUTO, + aggregation=vs.VariableAggregation.NONE, partitioner=None): # pylint: disable=unused-argument """Temporary util to create a variable (relies on `variable_scope.variable`). @@ -1836,11 +1868,21 @@ def make_variable(name, or "non_trainable_variables" (e.g. BatchNorm mean, stddev). Note, if the current variable scope is marked as non-trainable then this parameter is ignored and any added variables are also - marked as non-trainable. + marked as non-trainable. `trainable` defaults to `True` unless + `synchronization` is set to `ON_READ`. caching_device: Passed to `vs.variable`. validate_shape: Passed to `vs.variable`. constraint: Constraint instance (callable). use_resource: Whether to use a `ResourceVariable`. + synchronization: Indicates when a distributed a variable will be + aggregated. Accepted values are constants defined in the class + @{tf.VariableSynchronization}. By default the synchronization is set to + `AUTO` and the current `DistributionStrategy` chooses + when to synchronize. If `synchronization` is set to `ON_READ`, + `trainable` must not be set to `True`. + aggregation: Indicates how a distributed variable will be aggregated. + Accepted values are constants defined in the class + @{tf.VariableAggregation}. partitioner: Not handled at this time. 
Returns: @@ -1872,5 +1914,7 @@ def make_variable(name, dtype=variable_dtype, validate_shape=validate_shape, constraint=constraint, - use_resource=use_resource) + use_resource=use_resource, + synchronization=synchronization, + aggregation=aggregation) return v diff --git a/tensorflow/python/keras/layers/normalization.py b/tensorflow/python/keras/layers/normalization.py index 8b894ca6b1..58c8a8a66d 100644 --- a/tensorflow/python/keras/layers/normalization.py +++ b/tensorflow/python/keras/layers/normalization.py @@ -181,12 +181,6 @@ class BatchNormalization(Layer): self.renorm_clipping = renorm_clipping self.renorm_momentum = renorm_momentum - def _add_tower_local_variable(self, *args, **kwargs): - tower_context = distribute_lib.get_tower_context() - with tower_context.tower_local_var_scope( - variable_scope.VariableAggregation.MEAN): - return self.add_weight(*args, **kwargs) - def build(self, input_shape): input_shape = tensor_shape.TensorShape(input_shape) if not input_shape.ndims: @@ -314,19 +308,23 @@ class BatchNormalization(Layer): self._scope.set_partitioner(None) else: partitioner = None - self.moving_mean = self._add_tower_local_variable( + self.moving_mean = self.add_weight( name='moving_mean', shape=param_shape, dtype=param_dtype, initializer=self.moving_mean_initializer, - trainable=False) + synchronization=variable_scope.VariableSynchronization.ON_READ, + trainable=False, + aggregation=variable_scope.VariableAggregation.MEAN) - self.moving_variance = self._add_tower_local_variable( + self.moving_variance = self.add_weight( name='moving_variance', shape=param_shape, dtype=param_dtype, initializer=self.moving_variance_initializer, - trainable=False) + synchronization=variable_scope.VariableSynchronization.ON_READ, + trainable=False, + aggregation=variable_scope.VariableAggregation.MEAN) if self.renorm: # Create variables to maintain the moving mean and standard deviation. @@ -337,12 +335,14 @@ class BatchNormalization(Layer): # stack to be cleared. The nested ones use a `lambda` to set the desired # device and ignore any devices that may be set by the custom getter. 
def _renorm_variable(name, shape): - var = self._add_tower_local_variable( + var = self.add_weight( name=name, shape=shape, dtype=param_dtype, initializer=init_ops.zeros_initializer(), - trainable=False) + synchronization=variable_scope.VariableSynchronization.ON_READ, + trainable=False, + aggregation=variable_scope.VariableAggregation.MEAN) return var with distribute_lib.get_distribution_strategy().colocate_vars_with( diff --git a/tensorflow/python/kernel_tests/variable_scope_test.py b/tensorflow/python/kernel_tests/variable_scope_test.py index 054c6f9dd7..ae2a0ab29a 100644 --- a/tensorflow/python/kernel_tests/variable_scope_test.py +++ b/tensorflow/python/kernel_tests/variable_scope_test.py @@ -1054,7 +1054,7 @@ class VariableScopeTest(test.TestCase): "testGetCollection_foo/testGetCollection_a:0" ]) - def testGetTrainableVariables(self): + def testGetTrainableVariablesWithGetVariable(self): with self.test_session(): _ = variable_scope.get_variable("testGetTrainableVariables_a", []) with variable_scope.variable_scope( @@ -1062,10 +1062,72 @@ class VariableScopeTest(test.TestCase): _ = variable_scope.get_variable("testGetTrainableVariables_b", []) _ = variable_scope.get_variable( "testGetTrainableVariables_c", [], trainable=False) + + # sync `ON_READ` sets trainable=False + _ = variable_scope.get_variable( + "testGetTrainableVariables_d", [], + synchronization=variable_scope.VariableSynchronization.ON_READ) self.assertEqual( [v.name for v in scope.trainable_variables()], - ["testGetTrainableVariables_foo/" - "testGetTrainableVariables_b:0"]) + ["testGetTrainableVariables_foo/testGetTrainableVariables_b:0"]) + + # All other sync values sets trainable=True + _ = variable_scope.get_variable( + "testGetTrainableVariables_e", [], + synchronization=variable_scope.VariableSynchronization.ON_WRITE) + self.assertEqual([v.name for v in scope.trainable_variables()], [ + "testGetTrainableVariables_foo/testGetTrainableVariables_b:0", + "testGetTrainableVariables_foo/testGetTrainableVariables_e:0" + ]) + + with self.assertRaisesRegexp( + ValueError, "Synchronization value can be set to " + "VariableSynchronization.ON_READ only for non-trainable variables. 
" + "You have specified trainable=True and " + "synchronization=VariableSynchronization.ON_READ."): + _ = variable_scope.get_variable( + "testGetTrainableVariables_e", [], + synchronization=variable_scope.VariableSynchronization.ON_READ, + trainable=True) + + def testGetTrainableVariablesWithVariable(self): + with self.test_session(): + _ = variable_scope.variable(1.0, name="testGetTrainableVariables_a") + with variable_scope.variable_scope( + "testGetTrainableVariables_foo") as scope: + _ = variable_scope.variable(1.0, name="testGetTrainableVariables_b") + _ = variable_scope.variable( + 1.0, name="testGetTrainableVariables_c", trainable=False) + + # sync `ON_READ` sets trainable=False + _ = variable_scope.variable( + 1.0, + name="testGetTrainableVariables_d", + synchronization=variable_scope.VariableSynchronization.ON_READ) + self.assertEqual( + [v.name for v in scope.trainable_variables()], + ["testGetTrainableVariables_foo/testGetTrainableVariables_b:0"]) + + # All other sync values sets trainable=True + _ = variable_scope.variable( + 1.0, + name="testGetTrainableVariables_e", + synchronization=variable_scope.VariableSynchronization.ON_WRITE) + self.assertEqual([v.name for v in scope.trainable_variables()], [ + "testGetTrainableVariables_foo/testGetTrainableVariables_b:0", + "testGetTrainableVariables_foo/testGetTrainableVariables_e:0" + ]) + + with self.assertRaisesRegexp( + ValueError, "Synchronization value can be set to " + "VariableSynchronization.ON_READ only for non-trainable variables. " + "You have specified trainable=True and " + "synchronization=VariableSynchronization.ON_READ."): + _ = variable_scope.variable( + 1.0, + name="testGetTrainableVariables_e", + synchronization=variable_scope.VariableSynchronization.ON_READ, + trainable=True) def testGetGlobalVariables(self): with self.test_session(): diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py index b8969a41ab..cf13b52617 100644 --- a/tensorflow/python/layers/base.py +++ b/tensorflow/python/layers/base.py @@ -152,10 +152,17 @@ class Layer(base_layer.Layer): scope, default_name=self._base_name) as captured_scope: self._scope = captured_scope - def add_weight(self, name, shape, dtype=None, - initializer=None, regularizer=None, - trainable=True, constraint=None, + def add_weight(self, + name, + shape, + dtype=None, + initializer=None, + regularizer=None, + trainable=None, + constraint=None, use_resource=None, + synchronization=vs.VariableSynchronization.AUTO, + aggregation=vs.VariableAggregation.NONE, partitioner=None): """Adds a new variable to the layer, or gets an existing one; returns it. @@ -170,9 +177,19 @@ class Layer(base_layer.Layer): or "non_trainable_variables" (e.g. BatchNorm mean, stddev). Note, if the current variable scope is marked as non-trainable then this parameter is ignored and any added variables are also - marked as non-trainable. + marked as non-trainable. `trainable` defaults to `True` unless + `synchronization` is set to `ON_READ`. constraint: constraint instance (callable). use_resource: Whether to use `ResourceVariable`. + synchronization: Indicates when a distributed a variable will be + aggregated. Accepted values are constants defined in the class + @{tf.VariableSynchronization}. By default the synchronization is set to + `AUTO` and the current `DistributionStrategy` chooses + when to synchronize. If `synchronization` is set to `ON_READ`, + `trainable` must not be set to `True`. + aggregation: Indicates how a distributed variable will be aggregated. 
+ Accepted values are constants defined in the class + @{tf.VariableAggregation}. partitioner: (optional) partitioner instance (callable). If provided, when the requested variable is created it will be split into multiple partitions according to `partitioner`. In this case, @@ -190,7 +207,21 @@ class Layer(base_layer.Layer): Raises: RuntimeError: If called with partioned variable regularization and eager execution is enabled. + ValueError: When trainable has been set to True with synchronization + set as `ON_READ`. """ + if synchronization == vs.VariableSynchronization.ON_READ: + if trainable: + raise ValueError( + 'Synchronization value can be set to ' + 'VariableSynchronization.ON_READ only for non-trainable variables. ' + 'You have specified trainable=True and ' + 'synchronization=VariableSynchronization.ON_READ.') + else: + # Set trainable to be false when variable is to be synced on read. + trainable = False + elif trainable is None: + trainable = True def _should_add_regularizer(variable, existing_variable_set): if isinstance(variable, tf_variables.PartitionedVariable): @@ -240,6 +271,8 @@ class Layer(base_layer.Layer): constraint=constraint, partitioner=partitioner, use_resource=use_resource, + synchronization=synchronization, + aggregation=aggregation, getter=vs.get_variable) if regularizer: diff --git a/tensorflow/python/layers/base_test.py b/tensorflow/python/layers/base_test.py index 298e96e711..d2443db665 100644 --- a/tensorflow/python/layers/base_test.py +++ b/tensorflow/python/layers/base_test.py @@ -90,12 +90,34 @@ class BaseLayerTest(test.TestCase): # regularizers only supported in GRAPH mode. regularizer = lambda x: math_ops.reduce_sum(x) * 1e-3 - variable = layer.add_variable( + _ = layer.add_variable( 'reg_var', [2, 2], initializer=init_ops.zeros_initializer(), regularizer=regularizer) self.assertEqual(len(layer.losses), 1) + # Test that sync `ON_READ` variables are defaulted to be non-trainable. + variable_3 = layer.add_variable( + 'sync_on_read_var', [2, 2], + initializer=init_ops.zeros_initializer(), + synchronization=variable_scope.VariableSynchronization.ON_READ, + aggregation=variable_scope.VariableAggregation.SUM) + self.assertEqual(layer.non_trainable_variables, [variable_2, variable_3]) + + def testInvalidTrainableSynchronizationCombination(self): + layer = base_layers.Layer(name='my_layer') + + with self.assertRaisesRegexp( + ValueError, 'Synchronization value can be set to ' + 'VariableSynchronization.ON_READ only for non-trainable variables. 
' + 'You have specified trainable=True and ' + 'synchronization=VariableSynchronization.ON_READ.'): + _ = layer.add_variable( + 'v', [2, 2], + initializer=init_ops.zeros_initializer(), + synchronization=variable_scope.VariableSynchronization.ON_READ, + trainable=True) + def testReusePartitionedVaraiblesAndRegularizers(self): regularizer = lambda x: math_ops.reduce_sum(x) * 1e-3 partitioner = partitioned_variables.fixed_size_partitioner(3) @@ -104,7 +126,7 @@ class BaseLayerTest(test.TestCase): partitioner=partitioner, reuse=reuse): layer = base_layers.Layer(name='my_layer') - variable = layer.add_variable( + _ = layer.add_variable( 'reg_part_var', [4, 4], initializer=init_ops.zeros_initializer(), regularizer=regularizer) diff --git a/tensorflow/python/ops/metrics_impl.py b/tensorflow/python/ops/metrics_impl.py index bfd225b0d8..3aedeb6acd 100644 --- a/tensorflow/python/ops/metrics_impl.py +++ b/tensorflow/python/ops/metrics_impl.py @@ -73,16 +73,16 @@ def metric_variable(shape, dtype, validate_shape=True, name=None): A (non-trainable) variable initialized to zero, or if inside a `DistributionStrategy` scope a tower-local variable container. """ - with distribute_lib.get_tower_context().tower_local_var_scope( - variable_scope.VariableAggregation.SUM): - # Note that "tower local" implies trainable=False. - return variable_scope.variable( - lambda: array_ops.zeros(shape, dtype), - collections=[ - ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.METRIC_VARIABLES - ], - validate_shape=validate_shape, - name=name) + # Note that synchronization "ON_READ" implies trainable=False. + return variable_scope.variable( + lambda: array_ops.zeros(shape, dtype), + collections=[ + ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.METRIC_VARIABLES + ], + validate_shape=validate_shape, + synchronization=variable_scope.VariableSynchronization.ON_READ, + aggregation=variable_scope.VariableAggregation.SUM, + name=name) def _remove_squeezable_dimensions(predictions, labels, weights): diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py index 1e06bf07d5..77f67c18ee 100644 --- a/tensorflow/python/ops/variable_scope.py +++ b/tensorflow/python/ops/variable_scope.py @@ -255,7 +255,7 @@ class _VariableStore(object): initializer=None, regularizer=None, reuse=None, - trainable=True, + trainable=None, collections=None, caching_device=None, partitioner=None, @@ -300,6 +300,8 @@ class _VariableStore(object): forced to be False. trainable: If `True` also add the variable to the graph collection `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`). + `trainable` defaults to `True` unless `synchronization` is + set to `ON_READ`. collections: List of graph collections keys to add the `Variable` to. Defaults to `[GraphKeys.GLOBAL_VARIABLES]` (see `tf.Variable`). caching_device: Optional device string or function describing where the @@ -341,7 +343,8 @@ class _VariableStore(object): aggregated. Accepted values are constants defined in the class @{tf.VariableSynchronization}. By default the synchronization is set to `AUTO` and the current `DistributionStrategy` chooses - when to synchronize. + when to synchronize. If `synchronization` is set to `ON_READ`, + `trainable` must not be set to `True`. aggregation: Indicates how a distributed variable will be aggregated. Accepted values are constants defined in the class @{tf.VariableAggregation}. 
@@ -404,7 +407,7 @@ class _VariableStore(object): initializer=None, regularizer=None, reuse=None, - trainable=True, + trainable=None, collections=None, caching_device=None, partitioner=None, @@ -477,6 +480,10 @@ class _VariableStore(object): synchronization=synchronization, aggregation=aggregation) + # Set trainable value based on synchronization value. + trainable = _get_trainable_value( + synchronization=synchronization, trainable=trainable) + if custom_getter is not None: # Handle backwards compatibility with getter arguments that were added # to the API after users started writing custom getters. @@ -519,11 +526,20 @@ class _VariableStore(object): synchronization=synchronization, aggregation=aggregation) - def _get_partitioned_variable( - self, name, partitioner, shape=None, dtype=dtypes.float32, - initializer=None, regularizer=None, reuse=None, - trainable=True, collections=None, caching_device=None, - validate_shape=True, use_resource=None, constraint=None): + def _get_partitioned_variable(self, + name, + partitioner, + shape=None, + dtype=dtypes.float32, + initializer=None, + regularizer=None, + reuse=None, + trainable=None, + collections=None, + caching_device=None, + validate_shape=True, + use_resource=None, + constraint=None): """Gets or creates a sharded variable list with these parameters. The `partitioner` must be a callable that accepts a fully defined @@ -773,7 +789,7 @@ class _VariableStore(object): regularizer=None, partition_info=None, reuse=None, - trainable=True, + trainable=None, collections=None, caching_device=None, validate_shape=True, @@ -1136,7 +1152,7 @@ class VariableScope(object): initializer=None, regularizer=None, reuse=None, - trainable=True, + trainable=None, collections=None, caching_device=None, partitioner=None, @@ -1207,7 +1223,7 @@ class VariableScope(object): dtype=None, initializer=None, regularizer=None, - trainable=True, + trainable=None, collections=None, caching_device=None, partitioner=None, @@ -1422,7 +1438,7 @@ def get_variable(name, dtype=None, initializer=None, regularizer=None, - trainable=True, + trainable=None, collections=None, caching_device=None, partitioner=None, @@ -2334,11 +2350,28 @@ def _compute_slice_dim_and_shape(full_shape, slicing): return slice_dim, slice_shape +def _get_trainable_value(synchronization, trainable): + """Computes the trainable value based on the given arguments.""" + if synchronization == VariableSynchronization.ON_READ: + if trainable: + raise ValueError( + "Synchronization value can be set to " + "VariableSynchronization.ON_READ only for non-trainable variables. " + "You have specified trainable=True and " + "synchronization=VariableSynchronization.ON_READ.") + else: + # Set trainable to be false when variable is to be synced on read. + trainable = False + elif trainable is None: + trainable = True + return trainable + + def default_variable_creator(next_creator=None, **kwargs): """Default variable creator.""" assert next_creator is None initial_value = kwargs.get("initial_value", None) - trainable = kwargs.get("trainable", True) + trainable = kwargs.get("trainable", None) collections = kwargs.get("collections", None) validate_shape = kwargs.get("validate_shape", True) caching_device = kwargs.get("caching_device", None) @@ -2347,10 +2380,10 @@ def default_variable_creator(next_creator=None, **kwargs): constraint = kwargs.get("constraint", None) use_resource = kwargs.get("use_resource", None) - # Enforce `ON_READ` variables to be not trainable. + # Set trainable value based on synchronization value. 
synchronization = kwargs.get("synchronization", VariableSynchronization.AUTO) - if synchronization == VariableSynchronization.ON_READ: - trainable = False + trainable = _get_trainable_value( + synchronization=synchronization, trainable=trainable) if use_resource is None: use_resource = get_variable_scope().use_resource @@ -2379,7 +2412,7 @@ def _make_getter(captured_getter, captured_previous): def variable(initial_value=None, - trainable=True, + trainable=None, collections=None, validate_shape=True, caching_device=None, @@ -2441,6 +2474,8 @@ def variable_creator_scope(variable_creator): trainable: If `True`, the default, also adds the variable to the graph collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as the default list of variables to use by the `Optimizer` classes. + `trainable` defaults to `True` unless `synchronization` is + set to `ON_READ`. collections: List of graph collections keys. The new variable is added to these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`. validate_shape: If `False`, allows the variable to be initialized with a @@ -2463,7 +2498,8 @@ def variable_creator_scope(variable_creator): aggregated. Accepted values are constants defined in the class @{tf.VariableSynchronization}. By default the synchronization is set to `AUTO` and the current `DistributionStrategy` chooses - when to synchronize. + when to synchronize. If `synchronization` is set to `ON_READ`, + `trainable` must not be set to `True`. aggregation: Indicates how a distributed variable will be aggregated. Accepted values are constants defined in the class @{tf.VariableAggregation}. diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py index d33fd7376a..c719045c7f 100644 --- a/tensorflow/python/training/distribute.py +++ b/tensorflow/python/training/distribute.py @@ -614,48 +614,6 @@ class DistributionStrategy(object): # Note: should support "colocate_with" argument. raise NotImplementedError("must be implemented in descendants") - def tower_local_var_scope(self, aggregation): - """Inside this scope, new variables will not be mirrored. - - There will still be one component variable per tower, but there is - no requirement that they stay in sync. Instead, when saving them - or calling `read_var()`, we use the value that results when - calling `reduce()` on all the towers' variables. - - Note: tower-local implies not trainable. Instead, it is expected - that each tower will directly update (using `assign_add()` or - whatever) its local variable instance but only the aggregated - value (accessible using `read_var()`) will be exported from the - model. When it is acceptable to only aggregate on export, we - greatly reduce communication overhead by using tower-local - variables. - - Note: All component variables will be initialized to the same - value, using the initialization expression from the first tower. - The values will match even if the initialization expression uses - random numbers. - - Args: - aggregation: Indicates how a variable will be aggregated. Accepted values - are @{tf.VariableAggregation.SUM}, @{tf.VariableAggregation.MEAN}. - - Returns: - A context manager. - """ - # TODO(psv): Remove this after adding support for synchronization and - # aggregation parameters in get_variable() and mirrored strategy. - def create_tower_local_variable(next_creator, *args, **kwargs): - _require_distribution_strategy_scope(self) - kwargs["use_resource"] = True - - # Set synchronization to be ON_READ for tower local variables. 
- kwargs["synchronization"] = variable_scope.VariableSynchronization.ON_READ - kwargs["aggregation"] = aggregation - return next_creator(*args, **kwargs) - - _require_distribution_strategy_scope(self) - return variable_scope.variable_creator_scope(create_tower_local_variable) - def read_var(self, v): """Reads the value of a variable. @@ -1103,10 +1061,6 @@ class TowerContext(object): finally: _pop_per_thread_mode() - def tower_local_var_scope(self, aggregation): - """Alias for distribution_strategy.tower_local_var_scope().""" - return self._distribution_strategy.tower_local_var_scope(aggregation) - @property def is_single_tower(self): """Returns whether there is a single tower or multiple.""" @@ -1158,16 +1112,6 @@ class _DefaultDistributionStrategy(DistributionStrategy): return _CurrentDistributionContext( self, variable_scope.variable_creator_scope(creator)) - def tower_local_var_scope(self, aggregation): - """Does not set to resource variables.""" - def create_tower_local_variable(next_creator, *args, **kwargs): - _require_distribution_strategy_scope(self) - kwargs["trainable"] = False - return next_creator(*args, **kwargs) - - _require_distribution_strategy_scope(self) - return variable_scope.variable_creator_scope(create_tower_local_variable) - def colocate_vars_with(self, colocate_with_variable): """Does not require `self.scope`.""" _require_distribution_strategy_scope(self) |