diff options
author | Anjali Sridhar <anjalisridhar@google.com> | 2018-06-20 10:54:40 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-20 10:57:38 -0700 |
commit | 2b45f14362aaa00cf7fc640f375048bffba98655 (patch) | |
tree | e7900b52b14cbb7f058032f3b104ce368d9759d8 /tensorflow/contrib/distribute/python/mirrored_strategy.py | |
parent | 88625ad7257ecf9d33f36f8395bf00a427a8f4e3 (diff) |
Allow TowerLocalVars to be updated with the same value across all towers.
PiperOrigin-RevId: 201379124
Diffstat (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy.py')
-rw-r--r-- | tensorflow/contrib/distribute/python/mirrored_strategy.py | 5 |
1 files changed, 2 insertions, 3 deletions
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index c1b4b870a5..dc270ac540 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -323,14 +323,13 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): value_destination_pairs) def _update(self, var, fn, *args, **kwargs): - # TODO(josh11b): Also support TowerLocalVariables here? If so, args and - # kwargs don't need to be mirrored. - assert isinstance(var, values.MirroredVariable) # TODO(josh11b): In eager mode, use one thread per device. + assert isinstance(var, values.DistributedVariable) updates = {} for d, v in var._index.items(): # pylint: disable=protected-access name = "update_%d" % self._device_index.get(d) with ops.device(d), distribute_lib.UpdateContext(d), ops.name_scope(name): + # If args and kwargs are not mirrored, the value is returned as is. updates[d] = fn(v, *values.select_device_mirrored(d, args), **values.select_device_mirrored(d, kwargs)) |