aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute/python/mirrored_strategy.py
diff options
context:
space:
mode:
authorGravatar Anjali Sridhar <anjalisridhar@google.com>2018-06-26 11:25:21 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-26 11:28:45 -0700
commitbfda539bef38845809e3b0c5930458dc500d505d (patch)
treeca88e6879357c02b08263cf81e197a8dc47efae1 /tensorflow/contrib/distribute/python/mirrored_strategy.py
parentd10213099df42d7138dd7479264e4c987a3d870f (diff)
Enable assign, assign_add and assign_sub to be called on Mirrored Variables in cross tower and tower context.
PiperOrigin-RevId: 202162272
Diffstat (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy.py')
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy.py26
1 files changed, 23 insertions, 3 deletions
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index 98fea76b3d..d269bed1e5 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -309,9 +309,29 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
return self._cross_tower_ops
def _reduce(self, method_string, value, destinations):
- if len(self._devices) == 1 and not isinstance(value, values.PerDevice):
- value = values.PerDevice({self._devices[0]: value})
- assert isinstance(value, values.PerDevice)
+ assert not isinstance(value, values.Mirrored)
+ if not isinstance(value, values.PerDevice):
+ if value == 0:
+ return 0
+ if method_string == "mean":
+ return self._broadcast(value, destinations)
+
+ cross_tower_ops_lib.validate_destinations(destinations)
+ if len(self._devices) == 1:
+ if destinations:
+ # TODO(anjalisridhar): Moves these methods to a device utility file?
+ devices = cross_tower_ops_lib.get_devices_from(destinations)
+ if len(devices) == 1:
+ with ops.device(devices[0]):
+ return array_ops.identity(value)
+ else:
+ value_updates = {}
+ for d in devices:
+ with ops.device(d):
+ value_updates[d] = array_ops.identity(value)
+ return values.Mirrored(value_updates)
+ raise ValueError("A non PerDevice value cannot be reduced with the given "
+ "method_string.")
return self._get_cross_tower_ops().reduce(
method_string, value, destinations=destinations)