From aca93368a979419360c1fd84b53b1766b19ba81a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 29 Aug 2018 10:17:53 -0700 Subject: Add new aggregation mode "ONLY_FIRST_TOWER" and use it for the global step counter. This allows us to get rid of the increment_var() function and just use a standard assign_add(). PiperOrigin-RevId: 210743165 --- .../contrib/distribute/python/mirrored_strategy.py | 31 +++++++++++++++------- 1 file changed, 21 insertions(+), 10 deletions(-) (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy.py') diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index b44edfbd27..b4233a5eed 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -65,7 +65,7 @@ class _RequestedStop(Exception): pass -# Make _call_for_each_tower and _reduce_non_distributed_value not members of +# _call_for_each_tower and _reduce_non_distributed_value are not members of # MirroredStrategy so that they are generally not allowed to use anything # specific to MirroredStrategy and thus can be shared with other distribution # strategies. @@ -197,10 +197,12 @@ def _reduce_non_distributed_value(distribution, aggregation, value, # and equal to 0. if value == 0: return 0 - # If the aggregation type is MEAN, then this essentially means that the same - # value should be on all destinations. - if aggregation == variable_scope.VariableAggregation.MEAN: - return distribution.broadcast(value, destinations) + # If the aggregation type is MEAN or ONLY_FIRST_TOWER, then this + # essentially means that the same value should be on all destinations. + if aggregation in ( + variable_scope.VariableAggregation.MEAN, + variable_scope.VariableAggregation.ONLY_FIRST_TOWER): + return value cross_tower_ops_lib.validate_destinations(destinations) # We do not support an aggregation type of SUM if the value is the same across @@ -208,8 +210,8 @@ def _reduce_non_distributed_value(distribution, aggregation, value, # and summing up identical values across towers is not clearly defined. if (len(distribution.worker_devices) != 1 or not cross_tower_ops_lib.check_destinations(destinations)): - raise ValueError("A non-DistributedValues value cannot be reduced with the " - "given aggregation.") + raise ValueError("A non-DistributedValues value %s cannot be reduced with " + "the given aggregation %s." % (value, aggregation)) # TODO(anjalisridhar): Moves these methods to a device utility file? devices = cross_tower_ops_lib.get_devices_from(destinations) if len(devices) == 1: @@ -254,11 +256,12 @@ def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs): # Get aggregation value aggregation = kwargs.pop("aggregation", variable_scope.VariableAggregation.NONE) - if aggregation not in [ + if aggregation not in ( variable_scope.VariableAggregation.NONE, variable_scope.VariableAggregation.SUM, - variable_scope.VariableAggregation.MEAN - ]: + variable_scope.VariableAggregation.MEAN, + variable_scope.VariableAggregation.ONLY_FIRST_TOWER + ): raise ValueError("Invalid variable aggregation mode: " + aggregation + " for variable: " + kwargs["name"]) @@ -591,10 +594,18 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): # which case `value` would be a single value or value could be 0. return _reduce_non_distributed_value(self, aggregation, value, destinations) + if aggregation == variable_scope.VariableAggregation.ONLY_FIRST_TOWER: + value = value.get(self._devices[0]) + if isinstance(value, (int, float)): + return value + return self.broadcast(value, destinations) return self._get_cross_tower_ops().reduce( aggregation, value, destinations=destinations) def _batch_reduce(self, aggregation, value_destination_pairs): + if aggregation == variable_scope.VariableAggregation.ONLY_FIRST_TOWER: + return [self.broadcast(v.get(self._devices[0]), d) + for v, d in value_destination_pairs] return self._get_cross_tower_ops().batch_reduce(aggregation, value_destination_pairs) -- cgit v1.2.3