diff options
Diffstat (limited to 'tensorflow/python/training/distribute.py')
-rw-r--r-- | tensorflow/python/training/distribute.py | 21 |
1 files changed, 16 insertions, 5 deletions
diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py index 20e031569b..ac92238d57 100644 --- a/tensorflow/python/training/distribute.py +++ b/tensorflow/python/training/distribute.py @@ -32,6 +32,7 @@ from tensorflow.python.ops.losses import losses_impl from tensorflow.python.platform import tf_logging from tensorflow.python.training import device_util from tensorflow.python.training import distribution_strategy_context +from tensorflow.python.util import deprecation from tensorflow.python.util import nest @@ -248,6 +249,7 @@ class DistributionStrategy(object): devices. We have then a few approaches we want to support: + * Code written (as if) with no knowledge of class `DistributionStrategy`. This code should work as before, even if some of the layers, etc. used by that code are written to be distribution-aware. This is done @@ -722,7 +724,8 @@ class DistributionStrategy(object): Args: aggregation: Indicates how a variable will be aggregated. Accepted values - are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`. + are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`, + `tf.VariableAggregation.ONLY_FIRST_TOWER`. value: A per-device value with one value per tower. destinations: An optional mirrored variable, a device string, list of device strings. The return value will be copied to all @@ -739,7 +742,8 @@ class DistributionStrategy(object): _require_cross_tower_context(self) assert aggregation in [ variable_scope.VariableAggregation.SUM, - variable_scope.VariableAggregation.MEAN + variable_scope.VariableAggregation.MEAN, + variable_scope.VariableAggregation.ONLY_FIRST_TOWER ] return self._reduce(aggregation, value, destinations) @@ -751,7 +755,8 @@ class DistributionStrategy(object): Args: aggregation: Indicates how a variable will be aggregated. Accepted values - are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`. + are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`, + `tf.VariableAggregation.ONLY_FIRST_TOWER`. value_destination_pairs: A sequence of (value, destinations) pairs. See `reduce()` for a description. @@ -762,7 +767,8 @@ class DistributionStrategy(object): _require_cross_tower_context(self) assert aggregation in [ variable_scope.VariableAggregation.SUM, - variable_scope.VariableAggregation.MEAN + variable_scope.VariableAggregation.MEAN, + variable_scope.VariableAggregation.ONLY_FIRST_TOWER ] return self._batch_reduce(aggregation, value_destination_pairs) @@ -1167,9 +1173,14 @@ class _DefaultDistributionStrategy(DistributionStrategy): # ------------------------------------------------------------------------------ -# Common operations +# Deprecated, use v.assign_add(amount) instead. Internal API, so expect +# it to be deleted soon. +@deprecation.deprecated(None, + "Use v.assign_add(amount) instead. You may need to set " + "aggregation=tf.VariableAggregation.ONLY_FIRST_TOWER " + "when creating the variable.") def increment_var(v, amount=1): """`v += amount`, distributed-aware version.""" def update(vu): |