aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute/python/mirrored_strategy.py
diff options
context:
space:
mode:
authorGravatar Anjali Sridhar <anjalisridhar@google.com>2018-08-06 19:19:14 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-06 19:23:02 -0700
commit576e0000d8aa227163ff8df9d2b9f2e656191c76 (patch)
treee7de2c4f87d1553a7b03417c9b48cdf7bc8baa8c /tensorflow/contrib/distribute/python/mirrored_strategy.py
parentc26c9e6caa3f1e3e5027031db61ebde9bc3ba706 (diff)
Add comments to MirroredStrategy's reduce function. Also add more unit tests.
PiperOrigin-RevId: 207649000
Diffstat (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy.py')
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy.py11
1 files changed, 11 insertions, 0 deletions
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index 01b456d0d4..c5d6e978e7 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -186,12 +186,20 @@ def _reduce_non_distributed_value(distribution, aggregation, value,
raise ValueError("You are passing a `DistributedValue` to "
"`_reduce_non_distributed_value`, which is not allowed.")
+ # If the same value is present on all towers then the PerDevice value will
+ # be a single value. We also handle the case when `value` is a single value
+ # and equal to 0.
if value == 0:
return 0
+ # If the aggregation type is MEAN, then this essentially means that the same
+ # value should be on all destinations.
if aggregation == variable_scope.VariableAggregation.MEAN:
return distribution.broadcast(value, destinations)
cross_tower_ops_lib.validate_destinations(destinations)
+ # We do not support an aggregation type of SUM if the value is the same across
+ # all towers. We call this as part of assign functions for MirroredVariables
+ # and summing up identical values across towers is not clearly defined.
if (len(distribution.worker_devices) != 1 or
not cross_tower_ops_lib.check_destinations(destinations)):
raise ValueError("A non-DistributedValues value cannot be reduced with the "
@@ -386,6 +394,9 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
def _reduce(self, aggregation, value, destinations):
assert not isinstance(value, values.Mirrored)
if not isinstance(value, values.DistributedValues):
+ # This function handles reducing values that are not PerDevice or Mirrored
+ # values. For example, the same value could be present on all towers in
+ # which case `value` would be a single value or value could be 0.
return _reduce_non_distributed_value(self, aggregation, value,
destinations)
return self._get_cross_tower_ops().reduce(