aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute/python/values.py
diff options
context:
space:
mode:
authorGravatar Igor Ganichev <iga@google.com>2018-08-06 17:57:09 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-06 18:01:04 -0700
commit99f10ce2574995a5234409981fbf6df991dd3c7c (patch)
treec76adf92558a73be4716dbf6fc88f86c2bea0c5d /tensorflow/contrib/distribute/python/values.py
parent67284b2efd11b8c684afdc3cfa28d885efb99fa3 (diff)
Resolve distributed variables captured by defun at call time
Before this change, when was function is called in a distribution strategy context, it would capture the component variables from some device and always use these variables, even when the function is executed on a different device. This CL "reevaluates" distributed variables to get the correct variable at call time. These correct variables are then passed to the function. We don't handle distributed tensors. First, because the mechanics for handling distributed tensors are different from handling distributed variables, their support added significant complexity to already complex defuns. Second, there is no easy way for users have a function capture a distributed tensor or feed a distributed tensor explicitly. If this changes, we can support them (the code exists in this CL's history). We also don't handle distributed variables explicitly passed into the function for similar reasons. PiperOrigin-RevId: 207640908
Diffstat (limited to 'tensorflow/contrib/distribute/python/values.py')
-rw-r--r--tensorflow/contrib/distribute/python/values.py24
1 files changed, 24 insertions, 0 deletions
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index f4da91a8ac..6f34dd4746 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -995,3 +995,27 @@ class MultiStepContext(object):
assert o.dtype == i.dtype, (
"Dtype {} of left {} doesn't match dtype {} of right {}.".
format(o.dtype, o, i.dtype, i))
+
+
+def value_container(val):
+ """Returns the container that this per-device `value` belongs to.
+
+ Args:
+ val: A value returned by `call_for_each_tower()` or a variable
+ created in `scope()`.
+
+ Returns:
+ A container that `value` belongs to.
+ If value does not belong to any container (including the case of
+ container having been destroyed), returns the value itself.
+ """
+ # pylint: disable=protected-access
+ if (hasattr(val, "_distributed_container") and
+ # DistributedVariable has _distributed_container defined
+ # but we don't want to return it.
+ not isinstance(val, DistributedVariable)):
+ container = val._distributed_container()
+ # pylint: disable=protected-access
+ if container is not None:
+ return container
+ return val