diff options
author | Igor Ganichev <iga@google.com> | 2018-08-06 17:57:09 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-06 18:01:04 -0700 |
commit | 99f10ce2574995a5234409981fbf6df991dd3c7c (patch) | |
tree | c76adf92558a73be4716dbf6fc88f86c2bea0c5d /tensorflow/contrib/distribute/python/mirrored_strategy.py | |
parent | 67284b2efd11b8c684afdc3cfa28d885efb99fa3 (diff) |
Resolve distributed variables captured by defun at call time
Before this change, when was function is called in a distribution
strategy context, it would capture the component variables from some
device and always use these variables, even when the function is
executed on a different device.
This CL "reevaluates" distributed variables to get the correct variable
at call time. These correct variables are then passed to the function.
We don't handle distributed tensors. First, because the mechanics for handling
distributed tensors are different from handling distributed variables,
their support added significant complexity to already complex defuns.
Second, there is no easy way for users have a function capture a distributed
tensor or feed a distributed tensor explicitly. If this changes, we can
support them (the code exists in this CL's history).
We also don't handle distributed variables explicitly passed into the
function for similar reasons.
PiperOrigin-RevId: 207640908
Diffstat (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy.py')
-rw-r--r-- | tensorflow/contrib/distribute/python/mirrored_strategy.py | 3 |
1 files changed, 3 insertions, 0 deletions
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index 0c26ae8dbc..01b456d0d4 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -434,6 +434,9 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): return [val.get(device=d) for d in sorted(val.devices)] return [val] + def value_container(self, val): + return values.value_container(val) + @property def is_single_tower(self): return len(self._devices) == 1 |