diff options
author | Anjali Sridhar <anjalisridhar@google.com> | 2018-06-26 11:25:21 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-26 11:28:45 -0700 |
commit | bfda539bef38845809e3b0c5930458dc500d505d (patch) | |
tree | ca88e6879357c02b08263cf81e197a8dc47efae1 /tensorflow/contrib/distribute/python/mirrored_strategy.py | |
parent | d10213099df42d7138dd7479264e4c987a3d870f (diff) |
Enable assign, assign_add and assign_sub to be called on Mirrored Variables in cross tower and tower context.
PiperOrigin-RevId: 202162272
Diffstat (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy.py')
-rw-r--r-- | tensorflow/contrib/distribute/python/mirrored_strategy.py | 26 |
1 files changed, 23 insertions, 3 deletions
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index 98fea76b3d..d269bed1e5 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -309,9 +309,29 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): return self._cross_tower_ops def _reduce(self, method_string, value, destinations): - if len(self._devices) == 1 and not isinstance(value, values.PerDevice): - value = values.PerDevice({self._devices[0]: value}) - assert isinstance(value, values.PerDevice) + assert not isinstance(value, values.Mirrored) + if not isinstance(value, values.PerDevice): + if value == 0: + return 0 + if method_string == "mean": + return self._broadcast(value, destinations) + + cross_tower_ops_lib.validate_destinations(destinations) + if len(self._devices) == 1: + if destinations: + # TODO(anjalisridhar): Moves these methods to a device utility file? + devices = cross_tower_ops_lib.get_devices_from(destinations) + if len(devices) == 1: + with ops.device(devices[0]): + return array_ops.identity(value) + else: + value_updates = {} + for d in devices: + with ops.device(d): + value_updates[d] = array_ops.identity(value) + return values.Mirrored(value_updates) + raise ValueError("A non PerDevice value cannot be reduced with the given " + "method_string.") return self._get_cross_tower_ops().reduce( method_string, value, destinations=destinations) |