diff options
author | Anjali Sridhar <anjalisridhar@google.com> | 2018-08-06 19:19:14 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-06 19:23:02 -0700 |
commit | 576e0000d8aa227163ff8df9d2b9f2e656191c76 (patch) | |
tree | e7de2c4f87d1553a7b03417c9b48cdf7bc8baa8c /tensorflow/contrib/distribute/python/mirrored_strategy.py | |
parent | c26c9e6caa3f1e3e5027031db61ebde9bc3ba706 (diff) |
Add comments to MirroredStrategy's reduce function. Also add more unit tests.
PiperOrigin-RevId: 207649000
Diffstat (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy.py')
-rw-r--r-- | tensorflow/contrib/distribute/python/mirrored_strategy.py | 11 |
1 files changed, 11 insertions, 0 deletions
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index 01b456d0d4..c5d6e978e7 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -186,12 +186,20 @@ def _reduce_non_distributed_value(distribution, aggregation, value, raise ValueError("You are passing a `DistributedValue` to " "`_reduce_non_distributed_value`, which is not allowed.") + # If the same value is present on all towers then the PerDevice value will + # be a single value. We also handle the case when `value` is a single value + # and equal to 0. if value == 0: return 0 + # If the aggregation type is MEAN, then this essentially means that the same + # value should be on all destinations. if aggregation == variable_scope.VariableAggregation.MEAN: return distribution.broadcast(value, destinations) cross_tower_ops_lib.validate_destinations(destinations) + # We do not support an aggregation type of SUM if the value is the same across + # all towers. We call this as part of assign functions for MirroredVariables + # and summing up identical values across towers is not clearly defined. if (len(distribution.worker_devices) != 1 or not cross_tower_ops_lib.check_destinations(destinations)): raise ValueError("A non-DistributedValues value cannot be reduced with the " @@ -386,6 +394,9 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): def _reduce(self, aggregation, value, destinations): assert not isinstance(value, values.Mirrored) if not isinstance(value, values.DistributedValues): + # This function handles reducing values that are not PerDevice or Mirrored + # values. For example, the same value could be present on all towers in + # which case `value` would be a single value or value could be 0. return _reduce_non_distributed_value(self, aggregation, value, destinations) return self._get_cross_tower_ops().reduce( |