aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/distribute/python/mirrored_strategy.py
diff options
context:
space:
mode:
authorGravatar Igor Ganichev <iga@google.com>2018-08-06 17:57:09 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-06 18:01:04 -0700
commit99f10ce2574995a5234409981fbf6df991dd3c7c (patch)
treec76adf92558a73be4716dbf6fc88f86c2bea0c5d /tensorflow/contrib/distribute/python/mirrored_strategy.py
parent67284b2efd11b8c684afdc3cfa28d885efb99fa3 (diff)
Resolve distributed variables captured by defun at call time
Before this change, when was function is called in a distribution strategy context, it would capture the component variables from some device and always use these variables, even when the function is executed on a different device. This CL "reevaluates" distributed variables to get the correct variable at call time. These correct variables are then passed to the function. We don't handle distributed tensors. First, because the mechanics for handling distributed tensors are different from handling distributed variables, their support added significant complexity to already complex defuns. Second, there is no easy way for users have a function capture a distributed tensor or feed a distributed tensor explicitly. If this changes, we can support them (the code exists in this CL's history). We also don't handle distributed variables explicitly passed into the function for similar reasons. PiperOrigin-RevId: 207640908
Diffstat (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy.py')
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy.py3
1 files changed, 3 insertions, 0 deletions
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index 0c26ae8dbc..01b456d0d4 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -434,6 +434,9 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
return [val.get(device=d) for d in sorted(val.devices)]
return [val]
+ def value_container(self, val):
+ return values.value_container(val)
+
@property
def is_single_tower(self):
return len(self._devices) == 1