diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-06-20 16:14:18 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-20 16:17:07 -0700 |
commit | 34a12dff9812d291dff494dae9abecc13b494b8a (patch) | |
tree | 89cd50a6d0f4f42d3cff811bcddc0ab76f1e978c /tensorflow/contrib/distribute/python/mirrored_strategy.py | |
parent | 185b862db1cda8f99e719b4f287c6c1eba1c2f73 (diff) |
Switch away from DistributionStrategy.fetch() (mostly just in tests)
so we can delete it. Frequently we can now delete the call entirely,
but in other cases we switch to read_var().
This revealed some bugs also fixed in this CL:
* For MirroredStrategy: fix read_var(mean_tower_local) bug.
* Support get() for Mirrored values that are not MirroredVariables,
and make them DistributedDelegates so we can operate on them in
cross-tower mode.
* Actually iterate through the available devices in MirroredStrategy.get().
With this and already-submitted 201390698, we can pass mirrored
variables and other mirrored values directly to self.evaluate() in
tests.
PiperOrigin-RevId: 201435436
Diffstat (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy.py')
-rw-r--r-- | tensorflow/contrib/distribute/python/mirrored_strategy.py | 6 |
1 files changed, 2 insertions, 4 deletions
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index dc270ac540..d8668b398f 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -31,7 +31,6 @@ from tensorflow.python.eager import tape from tensorflow.python.framework import device as tf_device from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops -from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope from tensorflow.python.training import coordinator from tensorflow.python.training import device_util @@ -286,8 +285,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): def map(self, map_over, fn, *args, **kwargs): # TODO(josh11b): In eager mode, use one thread per device. index = {} - i = 0 - for m in map_over: + for i, m in enumerate(map_over): d = self._devices[i % len(self._devices)] with ops.device(d): l = index.get(d, []) @@ -349,7 +347,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): def read_var(self, tower_local_var): """Read the aggregate value of a tower-local variable.""" if isinstance(tower_local_var, values.TowerLocalVariable): - return math_ops.add_n(self.unwrap(tower_local_var)) + return tower_local_var._get_cross_tower() # pylint: disable=protected-access assert isinstance(tower_local_var, values.Mirrored) return array_ops.identity(tower_local_var.get()) |