path: root/tensorflow/contrib/distribute/python/mirrored_strategy.py
author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-06-20 16:14:18 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-06-20 16:17:07 -0700
commit    34a12dff9812d291dff494dae9abecc13b494b8a (patch)
tree      89cd50a6d0f4f42d3cff811bcddc0ab76f1e978c /tensorflow/contrib/distribute/python/mirrored_strategy.py
parent    185b862db1cda8f99e719b4f287c6c1eba1c2f73 (diff)
Switch away from DistributionStrategy.fetch() (mostly just in tests)
so we can delete it. Frequently we can now delete the call entirely, but in
other cases we switch to read_var(). This revealed some bugs also fixed in
this CL:

* For MirroredStrategy: fix read_var(mean_tower_local) bug.
* Support get() for Mirrored values that are not MirroredVariables, and make
  them DistributedDelegates so we can operate on them in cross-tower mode.
* Actually iterate through the available devices in MirroredStrategy.get().

With this and already-submitted 201390698, we can pass mirrored variables and
other mirrored values directly to self.evaluate() in tests.

PiperOrigin-RevId: 201435436
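For context, a minimal sketch of the test-side change this message describes, assuming the contrib-era (TF 1.x) API; the single-CPU device list and the variable setup are illustrative, not taken from this CL:

import tensorflow as tf
from tensorflow.contrib.distribute import MirroredStrategy

dist = MirroredStrategy(devices=["/cpu:0"])
with dist.scope():
  v = tf.get_variable("v", initializer=1.0)  # created as a mirrored variable

# Previously a test would call dist.fetch(v); after this change it reads the
# aggregate value with read_var(), or passes v directly to self.evaluate().
value_tensor = dist.read_var(v)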
Diffstat (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy.py')
-rw-r--r--  tensorflow/contrib/distribute/python/mirrored_strategy.py  |  6
1 file changed, 2 insertions(+), 4 deletions(-)
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index dc270ac540..d8668b398f 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -31,7 +31,6 @@ from tensorflow.python.eager import tape
 from tensorflow.python.framework import device as tf_device
 from tensorflow.python.framework import ops
 from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import variable_scope
 from tensorflow.python.training import coordinator
 from tensorflow.python.training import device_util
@@ -286,8 +285,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
   def map(self, map_over, fn, *args, **kwargs):
     # TODO(josh11b): In eager mode, use one thread per device.
     index = {}
-    i = 0
-    for m in map_over:
+    for i, m in enumerate(map_over):
       d = self._devices[i % len(self._devices)]
       with ops.device(d):
         l = index.get(d, [])
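The rewritten loop assigns work items to devices round-robin by index. A toy, self-contained illustration of that assignment pattern (device names hypothetical):

devices = ["/gpu:0", "/gpu:1"]
index = {}
for i, m in enumerate(["a", "b", "c", "d"]):
  d = devices[i % len(devices)]
  index.setdefault(d, []).append(m)
assert index == {"/gpu:0": ["a", "c"], "/gpu:1": ["b", "d"]}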
@@ -349,7 +347,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
   def read_var(self, tower_local_var):
     """Read the aggregate value of a tower-local variable."""
     if isinstance(tower_local_var, values.TowerLocalVariable):
-      return math_ops.add_n(self.unwrap(tower_local_var))
+      return tower_local_var._get_cross_tower()  # pylint: disable=protected-access
     assert isinstance(tower_local_var, values.Mirrored)
     return array_ops.identity(tower_local_var.get())
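This hunk is the read_var(mean_tower_local) fix mentioned in the commit message: add_n() over the unwrapped per-tower components always sums, which is wrong when the tower-local variable aggregates by mean, whereas _get_cross_tower() reads the value according to the variable's own aggregation. A toy numeric illustration in plain Python (not TensorFlow code):

per_tower = [2.0, 4.0]  # per-tower components of a MEAN tower-local variable
old_read = sum(per_tower)                    # add_n-style read -> 6.0 (wrong)
new_read = sum(per_tower) / len(per_tower)   # aggregation-aware read -> 3.0
assert (old_read, new_read) == (6.0, 3.0)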