author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-10-01 16:31:13 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-10-01 16:35:32 -0700
commit    bb1f9e1a57c8bc18325b3c86298be96e6647a0a3 (patch)
tree      66809f6708a49a5b8d408ef0b349fe837d905cda /tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
parent    49bbfec04b729960999ef054e3acab719631b101 (diff)
Change semantics of DistributionStrategy.update() to make sure the
output depends on the updates across all mirrors.

Before this change, update() would return a Mirrored value where each
component was an update to a single mirror. This caused a problem: for
reading purposes, other DistributionStrategy methods would consider it
okay to read any single component, so something like
session.run(strategy.update(...)) would perform the update on only one
replica.

The fix is to have the output be a Mirrored value that is actually an
identity operation returning the output on that device, but with a
control dependency making sure that the update actually happens on all
the replicas. This fix was already present in
MirroredVariable._assign_func; this CL moves it into update() and
generalizes it to multiple return values.

To disable this new grouping behavior, you may now pass "grouped=False"
to update(). For example, some callers (like Optimizer) perform many
updates and prefer to group all of them together at once for
performance reasons. In that case, we still want the caller to execute
the update on all replicas, so we return an unwrapped value instead of
a Mirrored value. This has the happy side effect of removing a bunch of
unwrap calls in client code, since unwrapping was the only safe way to
use the Mirrored value we used to return.

PiperOrigin-RevId: 215301909
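The grouped-vs-ungrouped semantics described above can be sketched in plain Python, without TensorFlow. All names here (ToyStrategy, its update method, and the grouped flag's mechanics) are illustrative stand-ins for the real DistributionStrategy API, not its implementation: grouped=True returns a single handle that runs every replica's update, while grouped=False hands back the per-replica ops for the caller to execute.

```python
# Illustrative sketch only: ToyStrategy is a hypothetical stand-in for
# DistributionStrategy, modeling the grouped/ungrouped update() contract.

class ToyStrategy:
    def __init__(self, num_replicas):
        # One "mirrored variable" component per replica.
        self.values = [0.0] * num_replicas

    def update(self, amount, grouped=True):
        # Build one update thunk per replica.
        ops = [self._make_update(i, amount) for i in range(len(self.values))]
        if grouped:
            # Grouped: a single callable that runs *every* replica's update,
            # mirroring the control-dependency grouping in the CL. Running
            # any one output cannot skip the other replicas.
            def run_all():
                return [op() for op in ops]
            return run_all
        # Ungrouped: the caller receives the per-replica ops and is
        # responsible for executing all of them (e.g. an optimizer that
        # batches many updates for performance).
        return ops

    def _make_update(self, replica, amount):
        def op():
            self.values[replica] += amount
            return self.values[replica]
        return op

strategy = ToyStrategy(num_replicas=2)
strategy.update(5.0)()  # grouped: one call updates both replicas
assert strategy.values == [5.0, 5.0]

for op in strategy.update(1.0, grouped=False):  # ungrouped: run each op
    op()
assert strategy.values == [6.0, 6.0]
```

In the real API the grouping is expressed with TensorFlow control dependencies rather than Python callables, but the contract is the same: the grouped return value cannot be "read" without all replica updates taking effect.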
Diffstat (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py')
-rw-r--r--  tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py | 2
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
index f51e543624..eeac528329 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
@@ -826,7 +826,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
     with dist.scope():
       ret_v_sum = dist.call_for_each_tower(model_fn, run_concurrently=False)
-      update_ops = dist.unwrap(dist.update(ret_v_sum, update, 5.0))
+      update_ops = dist.update(ret_v_sum, update, 5.0, grouped=False)
       # Initialize variables.
       self.evaluate(variables.global_variables_initializer())