diff options
author | Igor Ganichev <iga@google.com> | 2018-08-06 17:57:09 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-06 18:01:04 -0700 |
commit | 99f10ce2574995a5234409981fbf6df991dd3c7c (patch) | |
tree | c76adf92558a73be4716dbf6fc88f86c2bea0c5d /tensorflow/contrib/distribute/python/values.py | |
parent | 67284b2efd11b8c684afdc3cfa28d885efb99fa3 (diff) |
Resolve distributed variables captured by defun at call time
Before this change, when was function is called in a distribution
strategy context, it would capture the component variables from some
device and always use these variables, even when the function is
executed on a different device.
This CL "reevaluates" distributed variables to get the correct variable
at call time. These correct variables are then passed to the function.
We don't handle distributed tensors. First, because the mechanics for handling
distributed tensors are different from handling distributed variables,
their support added significant complexity to already complex defuns.
Second, there is no easy way for users have a function capture a distributed
tensor or feed a distributed tensor explicitly. If this changes, we can
support them (the code exists in this CL's history).
We also don't handle distributed variables explicitly passed into the
function for similar reasons.
PiperOrigin-RevId: 207640908
Diffstat (limited to 'tensorflow/contrib/distribute/python/values.py')
-rw-r--r-- | tensorflow/contrib/distribute/python/values.py | 24 |
1 files changed, 24 insertions, 0 deletions
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py index f4da91a8ac..6f34dd4746 100644 --- a/tensorflow/contrib/distribute/python/values.py +++ b/tensorflow/contrib/distribute/python/values.py @@ -995,3 +995,27 @@ class MultiStepContext(object): assert o.dtype == i.dtype, ( "Dtype {} of left {} doesn't match dtype {} of right {}.". format(o.dtype, o, i.dtype, i)) + + +def value_container(val): + """Returns the container that this per-device `value` belongs to. + + Args: + val: A value returned by `call_for_each_tower()` or a variable + created in `scope()`. + + Returns: + A container that `value` belongs to. + If value does not belong to any container (including the case of + container having been destroyed), returns the value itself. + """ + # pylint: disable=protected-access + if (hasattr(val, "_distributed_container") and + # DistributedVariable has _distributed_container defined + # but we don't want to return it. + not isinstance(val, DistributedVariable)): + container = val._distributed_container() + # pylint: disable=protected-access + if container is not None: + return container + return val |