aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute/python/mirrored_strategy.py
diff options
context:
space:
mode:
authorGravatar Anjali Sridhar <anjalisridhar@google.com>2018-06-20 10:54:40 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-20 10:57:38 -0700
commit2b45f14362aaa00cf7fc640f375048bffba98655 (patch)
treee7900b52b14cbb7f058032f3b104ce368d9759d8 /tensorflow/contrib/distribute/python/mirrored_strategy.py
parent88625ad7257ecf9d33f36f8395bf00a427a8f4e3 (diff)
Allow TowerLocalVars to be updated with the same value across all towers.
PiperOrigin-RevId: 201379124
Diffstat (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy.py')
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy.py5
1 files changed, 2 insertions, 3 deletions
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index c1b4b870a5..dc270ac540 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -323,14 +323,13 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
value_destination_pairs)
def _update(self, var, fn, *args, **kwargs):
- # TODO(josh11b): Also support TowerLocalVariables here? If so, args and
- # kwargs don't need to be mirrored.
- assert isinstance(var, values.MirroredVariable)
# TODO(josh11b): In eager mode, use one thread per device.
+ assert isinstance(var, values.DistributedVariable)
updates = {}
for d, v in var._index.items(): # pylint: disable=protected-access
name = "update_%d" % self._device_index.get(d)
with ops.device(d), distribute_lib.UpdateContext(d), ops.name_scope(name):
+ # If args and kwargs are not mirrored, the value is returned as is.
updates[d] = fn(v,
*values.select_device_mirrored(d, args),
**values.select_device_mirrored(d, kwargs))