diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-08-29 10:17:53 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-29 10:22:42 -0700 |
commit | aca93368a979419360c1fd84b53b1766b19ba81a (patch) | |
tree | 2312ef53a30251ec2f5538d43ba066550679f6d9 /tensorflow/python/training | |
parent | 8a22fa7037332fc6066459ce8c6fabcd77c6ece4 (diff) |
Add new aggregation mode "ONLY_FIRST_TOWER" and use it for the global
step counter. This allows us to get rid of the increment_var()
function and just use a standard assign_add().
PiperOrigin-RevId: 210743165
Diffstat (limited to 'tensorflow/python/training')
-rw-r--r-- | tensorflow/python/training/distribute.py | 20 |
-rw-r--r-- | tensorflow/python/training/training_util.py | 2 |
2 files changed, 17 insertions, 5 deletions
```diff
diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py
index 1ac7c39872..ac92238d57 100644
--- a/tensorflow/python/training/distribute.py
+++ b/tensorflow/python/training/distribute.py
@@ -32,6 +32,7 @@ from tensorflow.python.ops.losses import losses_impl
 from tensorflow.python.platform import tf_logging
 from tensorflow.python.training import device_util
 from tensorflow.python.training import distribution_strategy_context
+from tensorflow.python.util import deprecation
 from tensorflow.python.util import nest
 
 
@@ -723,7 +724,8 @@ class DistributionStrategy(object):
 
     Args:
       aggregation: Indicates how a variable will be aggregated. Accepted values
-        are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`.
+        are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`,
+        `tf.VariableAggregation.ONLY_FIRST_TOWER`.
       value: A per-device value with one value per tower.
       destinations: An optional mirrored variable, a device string,
         list of device strings. The return value will be copied to all
@@ -740,7 +742,8 @@ class DistributionStrategy(object):
     _require_cross_tower_context(self)
     assert aggregation in [
         variable_scope.VariableAggregation.SUM,
-        variable_scope.VariableAggregation.MEAN
+        variable_scope.VariableAggregation.MEAN,
+        variable_scope.VariableAggregation.ONLY_FIRST_TOWER
     ]
     return self._reduce(aggregation, value, destinations)
 
@@ -752,7 +755,8 @@ class DistributionStrategy(object):
 
     Args:
       aggregation: Indicates how a variable will be aggregated. Accepted values
-        are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`.
+        are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`,
+        `tf.VariableAggregation.ONLY_FIRST_TOWER`.
       value_destination_pairs: A sequence of (value, destinations) pairs.
         See `reduce()` for a description.
 
@@ -763,7 +767,8 @@ class DistributionStrategy(object):
     _require_cross_tower_context(self)
     assert aggregation in [
         variable_scope.VariableAggregation.SUM,
-        variable_scope.VariableAggregation.MEAN
+        variable_scope.VariableAggregation.MEAN,
+        variable_scope.VariableAggregation.ONLY_FIRST_TOWER
     ]
     return self._batch_reduce(aggregation, value_destination_pairs)
 
@@ -1168,9 +1173,14 @@ class _DefaultDistributionStrategy(DistributionStrategy):
 
 
 # ------------------------------------------------------------------------------
-# Common operations
+# Deprecated, use v.assign_add(amount) instead. Internal API, so expect
+# it to be deleted soon.
 
 
+@deprecation.deprecated(None,
+                        "Use v.assign_add(amount) instead. You may need to set "
+                        "aggregation=tf.VariableAggregation.ONLY_FIRST_TOWER "
+                        "when creating the variable.")
 def increment_var(v, amount=1):
   """`v += amount`, distributed-aware version."""
   def update(vu):
diff --git a/tensorflow/python/training/training_util.py b/tensorflow/python/training/training_util.py
index 2ff3eeb153..d998d6af81 100644
--- a/tensorflow/python/training/training_util.py
+++ b/tensorflow/python/training/training_util.py
@@ -129,6 +129,7 @@ def create_global_step(graph=None):
           dtype=dtypes.int64,
           initializer=init_ops.zeros_initializer(),
           trainable=False,
+          aggregation=variables.VariableAggregation.ONLY_FIRST_TOWER,
           collections=[ops.GraphKeys.GLOBAL_VARIABLES,
                        ops.GraphKeys.GLOBAL_STEP])
   # Create in proper graph and base name_scope.
@@ -139,6 +140,7 @@ def create_global_step(graph=None):
         dtype=dtypes.int64,
         initializer=init_ops.zeros_initializer(),
         trainable=False,
+        aggregation=variables.VariableAggregation.ONLY_FIRST_TOWER,
         collections=[ops.GraphKeys.GLOBAL_VARIABLES,
                      ops.GraphKeys.GLOBAL_STEP])
 
```