aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute/python/values.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-21 11:13:17 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-21 11:17:14 -0700
commit86fb0cdb3b1f521496ef474e215e338de3cf696d (patch)
treec35900737354db9542bcd7399f0c936e26d5bef3 /tensorflow/contrib/distribute/python/values.py
parent780e7714d1ddc3480e64ed484df3c0cb5b665e0d (diff)
Make regroup work on tower-local variables as well.
PiperOrigin-RevId: 201554738
Diffstat (limited to 'tensorflow/contrib/distribute/python/values.py')
-rw-r--r--tensorflow/contrib/distribute/python/values.py50
1 files changed, 25 insertions, 25 deletions
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index 389b01d3cd..9a48928a95 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -192,6 +192,10 @@ class DistributedVariable(DistributedDelegate):
# Child class must set self._primary_var before calling
# super(...).__init__(index).
self._common_name = self._primary_var.name.split(":")[0]
+ # Use a weakref to make it easy to map from the contained values
+ # to the container without introducing a reference cycle.
+ for v in six.itervalues(index):
+ v._distributed_container = weakref.ref(self) # pylint: disable=protected-access
super(DistributedVariable, self).__init__(index)
@property
@@ -287,10 +291,6 @@ class MirroredVariable(DistributedVariable, Mirrored,
"""Holds a map from device to variables whose values are kept in sync."""
def __init__(self, index, primary_var):
- # Use a weakref to make it easy to map from the contained values
- # to the container without introducing a reference cycle.
- for v in six.itervalues(index):
- v._mirrored_container = weakref.ref(self) # pylint: disable=protected-access
self._primary_var = primary_var
super(MirroredVariable, self).__init__(index)
@@ -498,40 +498,40 @@ def regroup(per_device, wrap_class=PerDevice):
same_id = False
break
# Consider three cases where same_id is true:
- # * If v0 is a MirroredVariable (and same_id means it is the same
- # across all devices), we want to return it. We check
- # MirroredVariable specifically since it can look like it
- # has a _mirrored_container member since its members do.
- # * If v0 is a member of a mirrored variable, in which case
- # hasattr(v0, "_mirrored_container") is true, we want to
- # return the MirroredVariable that contains it using the
- # _mirrored_container logic below. This case can trigger
+ # * If v0 is a DistributedVariable (a MirroredVariable or
+ # TowerLocalVariable, and same_id means it is the same across all
+ # devices), we want to return it. We check DistributedVariable
+ # specifically since it can look like it has a
+ # _distributed_container member since its members do.
+ # * If v0 is a member of a distributed variable, in which case
+ # hasattr(v0, "_distributed_container") is true, we want to
+ # return the DistributedVariable that contains it using the
+ # _distributed_container logic below. This case can trigger
# same_id when there is only one device.
# * In any other situation, same_id means we return v0.
- if same_id and (isinstance(v0, MirroredVariable) or
- not hasattr(v0, "_mirrored_container")):
+ if same_id and (isinstance(v0, DistributedVariable) or
+ not hasattr(v0, "_distributed_container")):
return v0
# Detect the case where each device has a parallel component of the
- # same MirroredVariable. In this case we want to return the
- # containing MirroredVariable, after a bunch of sanity checking.
- # In particular, each component should have the same container,
- # and the devices of the variables should match the keys of the
- # per-device dictionary.
- # TODO(josh11b): Do we need similar logic for TowerLocalVariables?
- if hasattr(v0, "_mirrored_container"):
+ # same MirroredVariable (or TowerLocalVariable). In this case we
+ # want to return the containing MirroredVariable, after a bunch of
+ # sanity checking. In particular, each component should have the
+ # same container, and the devices of the variables should match the
+ # keys of the per-device dictionary.
+ if hasattr(v0, "_distributed_container"):
# pylint: disable=protected-access
assert not isinstance(v0, MirroredVariable), (
"ids = %s, items = %s" % ([id(v[1]) for v in items], items))
assert _devices_match(v0.device, items[0][0]), (
"v0.device = %s, items = %s" % (v0.device, items))
- mirrored_container = v0._mirrored_container()
- assert mirrored_container is not None
+ distributed_container = v0._distributed_container()
+ assert distributed_container is not None
for d, v in items[1:]:
assert _devices_match(v.device, d), (
"v.device = %s, d = %s, items = %s" % (v.device, d, items))
- assert mirrored_container is v._mirrored_container()
- return mirrored_container
+ assert distributed_container is v._distributed_container()
+ return distributed_container
# pylint: enable=protected-access
return wrap_class(per_device)