aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/distribute.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/training/distribute.py')
-rw-r--r--tensorflow/python/training/distribute.py21
1 files changed, 16 insertions, 5 deletions
diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py
index 20e031569b..ac92238d57 100644
--- a/tensorflow/python/training/distribute.py
+++ b/tensorflow/python/training/distribute.py
@@ -32,6 +32,7 @@ from tensorflow.python.ops.losses import losses_impl
from tensorflow.python.platform import tf_logging
from tensorflow.python.training import device_util
from tensorflow.python.training import distribution_strategy_context
+from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
@@ -248,6 +249,7 @@ class DistributionStrategy(object):
devices.
We have then a few approaches we want to support:
+
* Code written (as if) with no knowledge of class `DistributionStrategy`.
This code should work as before, even if some of the layers, etc.
used by that code are written to be distribution-aware. This is done
@@ -722,7 +724,8 @@ class DistributionStrategy(object):
Args:
aggregation: Indicates how a variable will be aggregated. Accepted values
- are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`.
+ are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`,
+ `tf.VariableAggregation.ONLY_FIRST_TOWER`.
value: A per-device value with one value per tower.
destinations: An optional mirrored variable, a device string,
list of device strings. The return value will be copied to all
@@ -739,7 +742,8 @@ class DistributionStrategy(object):
_require_cross_tower_context(self)
assert aggregation in [
variable_scope.VariableAggregation.SUM,
- variable_scope.VariableAggregation.MEAN
+ variable_scope.VariableAggregation.MEAN,
+ variable_scope.VariableAggregation.ONLY_FIRST_TOWER
]
return self._reduce(aggregation, value, destinations)
@@ -751,7 +755,8 @@ class DistributionStrategy(object):
Args:
aggregation: Indicates how a variable will be aggregated. Accepted values
- are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`.
+ are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`,
+ `tf.VariableAggregation.ONLY_FIRST_TOWER`.
value_destination_pairs: A sequence of (value, destinations)
pairs. See `reduce()` for a description.
@@ -762,7 +767,8 @@ class DistributionStrategy(object):
_require_cross_tower_context(self)
assert aggregation in [
variable_scope.VariableAggregation.SUM,
- variable_scope.VariableAggregation.MEAN
+ variable_scope.VariableAggregation.MEAN,
+ variable_scope.VariableAggregation.ONLY_FIRST_TOWER
]
return self._batch_reduce(aggregation, value_destination_pairs)
@@ -1167,9 +1173,14 @@ class _DefaultDistributionStrategy(DistributionStrategy):
# ------------------------------------------------------------------------------
-# Common operations
+# Deprecated, use v.assign_add(amount) instead. Internal API, so expect
+# it to be deleted soon.
+@deprecation.deprecated(None,
+ "Use v.assign_add(amount) instead. You may need to set "
+ "aggregation=tf.VariableAggregation.ONLY_FIRST_TOWER "
+ "when creating the variable.")
def increment_var(v, amount=1):
"""`v += amount`, distributed-aware version."""
def update(vu):