aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute/python/values.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-20 16:14:18 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-20 16:17:07 -0700
commit34a12dff9812d291dff494dae9abecc13b494b8a (patch)
tree89cd50a6d0f4f42d3cff811bcddc0ab76f1e978c /tensorflow/contrib/distribute/python/values.py
parent185b862db1cda8f99e719b4f287c6c1eba1c2f73 (diff)
Switch away from DistributionStrategy.fetch() (mostly just in tests)
so we can delete it. Frequently we can now delete the call entirely, but in other cases we switch to read_var(). This revealed some bugs also fixed in this CL: * For MirroredStrategy: fix read_var(mean_tower_local) bug. * Support get() for Mirrored values that are not MirroredVariables, and make them DistributedDelegates so we can operate on them in cross-tower mode. * Actually iterate through the available devices in MirroredStrategy.get(). With this and already-submitted 201390698, we can pass mirrored variables and other mirrored values directly to self.evaluate() in tests. PiperOrigin-RevId: 201435436
Diffstat (limited to 'tensorflow/contrib/distribute/python/values.py')
-rw-r--r--tensorflow/contrib/distribute/python/values.py15
1 files changed, 11 insertions, 4 deletions
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index aca544b7e7..72def62c79 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -43,7 +43,7 @@ from tensorflow.python.util import nest
# pylint: disable=line-too-long
-# TODO(josh11b): Should device values be strings or DeviceSpec objects
+# TODO(josh11b): Should device values be strings or DeviceSpec objects?
# Not sure DeviceSpec objects are usable as a dict key.
class DistributedValues(object):
"""Holds a map from device to values. Either PerDevice or Mirrored."""
@@ -163,9 +163,16 @@ class PerDevice(DistributedValues):
pass
-class Mirrored(DistributedValues):
+# Note that unlike PerDevice, Mirrored values inherit from
+# DistributedDelegate and so can be used directly in cross-tower mode.
+class Mirrored(DistributedDelegate):
"""Holds a map from device to values which are kept in sync."""
- pass
+
+ def _get_cross_tower(self):
+ device = device_util.canonicalize(device_util.current())
+ if device in self._index:
+ return self._index[device]
+ return list(self._index.values())[0]
def _assign_on_device(device, variable, tensor):
@@ -353,7 +360,7 @@ class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject):
# We use a callable so that we don't have to evaluate this expression
# in the case where we are trying to restore instead of save.
def tensor():
- return distribute_lib.get_distribution_strategy().fetch(
+ return distribute_lib.get_distribution_strategy().read_var(
tower_local_variable)
spec = saver.BaseSaverBuilder.SaveSpec(
tensor=tensor,