diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-06-20 16:14:18 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-20 16:17:07 -0700 |
commit | 34a12dff9812d291dff494dae9abecc13b494b8a (patch) | |
tree | 89cd50a6d0f4f42d3cff811bcddc0ab76f1e978c /tensorflow/contrib/distribute/python/values.py | |
parent | 185b862db1cda8f99e719b4f287c6c1eba1c2f73 (diff) |
Switch away from DistributionStrategy.fetch() (mostly just in tests)
so we can delete it. Frequently we can now delete the call entirely,
but in other cases we switch to read_var().
This revealed some bugs also fixed in this CL:
* For MirroredStrategy: fix read_var(mean_tower_local) bug.
* Support get() for Mirrored values that are not MirroredVariables,
and make them DistributedDelegates so we can operate on them in
cross-tower mode.
* Actually iterate through the available devices in MirroredStrategy.get().
With this and already-submitted 201390698, we can pass mirrored
variables and other mirrored values directly to self.evaluate() in
tests.
PiperOrigin-RevId: 201435436
Diffstat (limited to 'tensorflow/contrib/distribute/python/values.py')
-rw-r--r-- | tensorflow/contrib/distribute/python/values.py | 15 |
1 files changed, 11 insertions, 4 deletions
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py index aca544b7e7..72def62c79 100644 --- a/tensorflow/contrib/distribute/python/values.py +++ b/tensorflow/contrib/distribute/python/values.py @@ -43,7 +43,7 @@ from tensorflow.python.util import nest # pylint: disable=line-too-long -# TODO(josh11b): Should device values be strings or DeviceSpec objects +# TODO(josh11b): Should device values be strings or DeviceSpec objects? # Not sure DeviceSpec objects are usable as a dict key. class DistributedValues(object): """Holds a map from device to values. Either PerDevice or Mirrored.""" @@ -163,9 +163,16 @@ class PerDevice(DistributedValues): pass -class Mirrored(DistributedValues): +# Note that unlike PerDevice, Mirrored values inherit from +# DistributedDelegate and so can be used directly in cross-tower mode. +class Mirrored(DistributedDelegate): """Holds a map from device to values which are kept in sync.""" - pass + + def _get_cross_tower(self): + device = device_util.canonicalize(device_util.current()) + if device in self._index: + return self._index[device] + return list(self._index.values())[0] def _assign_on_device(device, variable, tensor): @@ -353,7 +360,7 @@ class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject): # We use a callable so that we don't have to evaluate this expression # in the case where we are trying to restore instead of save. def tensor(): - return distribute_lib.get_distribution_strategy().fetch( + return distribute_lib.get_distribution_strategy().read_var( tower_local_variable) spec = saver.BaseSaverBuilder.SaveSpec( tensor=tensor, |