about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/contrib/distribute/python/mirrored_strategy.py
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org>, 2018-08-29 10:17:53 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org>, 2018-08-29 10:22:42 -0700
commit: aca93368a979419360c1fd84b53b1766b19ba81a (patch)
tree: 2312ef53a30251ec2f5538d43ba066550679f6d9 /tensorflow/contrib/distribute/python/mirrored_strategy.py
parent: 8a22fa7037332fc6066459ce8c6fabcd77c6ece4 (diff)
Add new aggregation mode "ONLY_FIRST_TOWER" and use it for the global step counter.
This allows us to get rid of the increment_var() function and just use a standard assign_add().
PiperOrigin-RevId: 210743165
Diffstat (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy.py')
-rw-r--r-- tensorflow/contrib/distribute/python/mirrored_strategy.py | 31
1 files changed, 21 insertions, 10 deletions
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index b44edfbd27..b4233a5eed 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -65,7 +65,7 @@ class _RequestedStop(Exception):
pass
-# Make _call_for_each_tower and _reduce_non_distributed_value not members of
+# _call_for_each_tower and _reduce_non_distributed_value are not members of
# MirroredStrategy so that they are generally not allowed to use anything
# specific to MirroredStrategy and thus can be shared with other distribution
# strategies.
@@ -197,10 +197,12 @@ def _reduce_non_distributed_value(distribution, aggregation, value,
# and equal to 0.
if value == 0:
return 0
- # If the aggregation type is MEAN, then this essentially means that the same
- # value should be on all destinations.
- if aggregation == variable_scope.VariableAggregation.MEAN:
- return distribution.broadcast(value, destinations)
+ # If the aggregation type is MEAN or ONLY_FIRST_TOWER, then this
+ # essentially means that the same value should be on all destinations.
+ if aggregation in (
+ variable_scope.VariableAggregation.MEAN,
+ variable_scope.VariableAggregation.ONLY_FIRST_TOWER):
+ return value
cross_tower_ops_lib.validate_destinations(destinations)
# We do not support an aggregation type of SUM if the value is the same across
@@ -208,8 +210,8 @@ def _reduce_non_distributed_value(distribution, aggregation, value,
# and summing up identical values across towers is not clearly defined.
if (len(distribution.worker_devices) != 1 or
not cross_tower_ops_lib.check_destinations(destinations)):
- raise ValueError("A non-DistributedValues value cannot be reduced with the "
- "given aggregation.")
+ raise ValueError("A non-DistributedValues value %s cannot be reduced with "
+ "the given aggregation %s." % (value, aggregation))
# TODO(anjalisridhar): Moves these methods to a device utility file?
devices = cross_tower_ops_lib.get_devices_from(destinations)
if len(devices) == 1:
@@ -254,11 +256,12 @@ def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs):
# Get aggregation value
aggregation = kwargs.pop("aggregation",
variable_scope.VariableAggregation.NONE)
- if aggregation not in [
+ if aggregation not in (
variable_scope.VariableAggregation.NONE,
variable_scope.VariableAggregation.SUM,
- variable_scope.VariableAggregation.MEAN
- ]:
+ variable_scope.VariableAggregation.MEAN,
+ variable_scope.VariableAggregation.ONLY_FIRST_TOWER
+ ):
raise ValueError("Invalid variable aggregation mode: " + aggregation +
" for variable: " + kwargs["name"])
@@ -591,10 +594,18 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
# which case `value` would be a single value or value could be 0.
return _reduce_non_distributed_value(self, aggregation, value,
destinations)
+ if aggregation == variable_scope.VariableAggregation.ONLY_FIRST_TOWER:
+ value = value.get(self._devices[0])
+ if isinstance(value, (int, float)):
+ return value
+ return self.broadcast(value, destinations)
return self._get_cross_tower_ops().reduce(
aggregation, value, destinations=destinations)
def _batch_reduce(self, aggregation, value_destination_pairs):
+ if aggregation == variable_scope.VariableAggregation.ONLY_FIRST_TOWER:
+ return [self.broadcast(v.get(self._devices[0]), d)
+ for v, d in value_destination_pairs]
return self._get_cross_tower_ops().batch_reduce(aggregation,
value_destination_pairs)