path: root/tensorflow/contrib/distribute/python/mirrored_strategy.py
author Yuefeng Zhou <yuefengz@google.com> 2018-08-28 23:28:27 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-08-28 23:32:54 -0700
commit 13d6aec1241ddf27365dc555d997ed40d43c2e06 (patch)
tree 08cb2e1f34e2767cf9db6f63d398885937ee2e9c /tensorflow/contrib/distribute/python/mirrored_strategy.py
parent 82993516eef05ba6074b1e613e45c12faa3c5793 (diff)
Use nccl if there is only one worker in multi-worker MirroredStrategies.
PiperOrigin-RevId: 210669284
Diffstat (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy.py')
-rw-r--r-- tensorflow/contrib/distribute/python/mirrored_strategy.py | 11
1 file changed, 9 insertions(+), 2 deletions(-)
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index e87b48ba41..b44edfbd27 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -564,8 +564,15 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
     if self._cross_tower_ops is None:
       if self._cluster_spec:
-        self._cross_tower_ops = cross_tower_ops_lib.MultiWorkerAllReduce(
-            self._workers, self._num_gpus)
+        # It currently cannot detect the topology of remote workers. So we
+        # hard-code the multi-worker all-reduce algorithm for now.
+        if len(self._workers) == 1:
+          # The default is "nccl".
+          self._cross_tower_ops = cross_tower_ops_lib.AllReduceCrossTowerOps()
+        else:
+          # The default is hierarchical reduce and broadcast.
+          self._cross_tower_ops = cross_tower_ops_lib.MultiWorkerAllReduce(
+              self._workers, self._num_gpus)
       else:
         self._cross_tower_ops = cross_tower_ops_lib.choose_the_best(
             self._devices, session_config=session_config)
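
For reference, the selection rule this hunk introduces can be restated as a small standalone helper. The sketch below is illustrative only and is not part of the commit: _pick_cross_tower_ops is a hypothetical name, and it assumes the cross_tower_ops_lib module referenced in the diff (tensorflow/contrib/distribute/python/cross_tower_ops.py) exposes AllReduceCrossTowerOps, MultiWorkerAllReduce, and choose_the_best with the signatures used above.

# Illustrative sketch only (hypothetical helper, not part of this commit).
# Assumes the cross_tower_ops_lib API shown in the diff above.
from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib


def _pick_cross_tower_ops(cluster_spec, workers, num_gpus, devices,
                          session_config=None):
  """Restates the cross-tower-ops selection rule added by this commit."""
  if cluster_spec:
    if len(workers) == 1:
      # A "multi-worker" cluster with a single worker can still use the
      # in-graph all-reduce, whose default algorithm is nccl.
      return cross_tower_ops_lib.AllReduceCrossTowerOps()
    # More than one worker: hierarchical reduce and broadcast across workers.
    return cross_tower_ops_lib.MultiWorkerAllReduce(workers, num_gpus)
  # No cluster spec: let the library pick the best local implementation.
  return cross_tower_ops_lib.choose_the_best(
      devices, session_config=session_config)

In the commit itself this logic sits inside MirroredStrategy's configure path, so callers only see the effect: a cluster spec with a single worker now gets the nccl-backed AllReduceCrossTowerOps instead of MultiWorkerAllReduce.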