aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-29 10:17:53 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-29 10:22:42 -0700
commitaca93368a979419360c1fd84b53b1766b19ba81a (patch)
tree2312ef53a30251ec2f5538d43ba066550679f6d9 /tensorflow/python/training
parent8a22fa7037332fc6066459ce8c6fabcd77c6ece4 (diff)
Add new aggregation mode "ONLY_FIRST_TOWER" and use it for the global
step counter. This allows us to get rid of the increment_var() function and just use a standard assign_add().

PiperOrigin-RevId: 210743165
Diffstat (limited to 'tensorflow/python/training')
-rw-r--r--tensorflow/python/training/distribute.py20
-rw-r--r--tensorflow/python/training/training_util.py2
2 files changed, 17 insertions, 5 deletions
diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py
index 1ac7c39872..ac92238d57 100644
--- a/tensorflow/python/training/distribute.py
+++ b/tensorflow/python/training/distribute.py
@@ -32,6 +32,7 @@ from tensorflow.python.ops.losses import losses_impl
from tensorflow.python.platform import tf_logging
from tensorflow.python.training import device_util
from tensorflow.python.training import distribution_strategy_context
+from tensorflow.python.util import deprecation
from tensorflow.python.util import nest
@@ -723,7 +724,8 @@ class DistributionStrategy(object):
Args:
aggregation: Indicates how a variable will be aggregated. Accepted values
- are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`.
+ are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`,
+ `tf.VariableAggregation.ONLY_FIRST_TOWER`.
value: A per-device value with one value per tower.
destinations: An optional mirrored variable, a device string,
list of device strings. The return value will be copied to all
@@ -740,7 +742,8 @@ class DistributionStrategy(object):
_require_cross_tower_context(self)
assert aggregation in [
variable_scope.VariableAggregation.SUM,
- variable_scope.VariableAggregation.MEAN
+ variable_scope.VariableAggregation.MEAN,
+ variable_scope.VariableAggregation.ONLY_FIRST_TOWER
]
return self._reduce(aggregation, value, destinations)
@@ -752,7 +755,8 @@ class DistributionStrategy(object):
Args:
aggregation: Indicates how a variable will be aggregated. Accepted values
- are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`.
+ are `tf.VariableAggregation.SUM`, `tf.VariableAggregation.MEAN`,
+ `tf.VariableAggregation.ONLY_FIRST_TOWER`.
value_destination_pairs: A sequence of (value, destinations)
pairs. See `reduce()` for a description.
@@ -763,7 +767,8 @@ class DistributionStrategy(object):
_require_cross_tower_context(self)
assert aggregation in [
variable_scope.VariableAggregation.SUM,
- variable_scope.VariableAggregation.MEAN
+ variable_scope.VariableAggregation.MEAN,
+ variable_scope.VariableAggregation.ONLY_FIRST_TOWER
]
return self._batch_reduce(aggregation, value_destination_pairs)
@@ -1168,9 +1173,14 @@ class _DefaultDistributionStrategy(DistributionStrategy):
# ------------------------------------------------------------------------------
-# Common operations
+# Deprecated, use v.assign_add(amount) instead. Internal API, so expect
+# it to be deleted soon.
+@deprecation.deprecated(None,
+ "Use v.assign_add(amount) instead. You may need to set "
+ "aggregation=tf.VariableAggregation.ONLY_FIRST_TOWER "
+ "when creating the variable.")
def increment_var(v, amount=1):
"""`v += amount`, distributed-aware version."""
def update(vu):
diff --git a/tensorflow/python/training/training_util.py b/tensorflow/python/training/training_util.py
index 2ff3eeb153..d998d6af81 100644
--- a/tensorflow/python/training/training_util.py
+++ b/tensorflow/python/training/training_util.py
@@ -129,6 +129,7 @@ def create_global_step(graph=None):
dtype=dtypes.int64,
initializer=init_ops.zeros_initializer(),
trainable=False,
+ aggregation=variables.VariableAggregation.ONLY_FIRST_TOWER,
collections=[ops.GraphKeys.GLOBAL_VARIABLES,
ops.GraphKeys.GLOBAL_STEP])
# Create in proper graph and base name_scope.
@@ -139,6 +140,7 @@ def create_global_step(graph=None):
dtype=dtypes.int64,
initializer=init_ops.zeros_initializer(),
trainable=False,
+ aggregation=variables.VariableAggregation.ONLY_FIRST_TOWER,
collections=[ops.GraphKeys.GLOBAL_VARIABLES,
ops.GraphKeys.GLOBAL_STEP])