path: root/tensorflow/contrib/distribute/python/mirrored_strategy.py
author Yuefeng Zhou <yuefengz@google.com> 2018-08-30 17:09:50 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-08-30 17:14:53 -0700
commit 11f970c61bc21ae81b76cdee58871f0509ec9f0f (patch)
tree 6de4bc8c2ac06737629b8aaf09cfbc20db630f32 /tensorflow/contrib/distribute/python/mirrored_strategy.py
parent 9d5b6e6627b04abbfecb5c7a95576e70f29dd7cf (diff)
Add `num_gpus_per_worker` argument to MirroredStrategy.
PiperOrigin-RevId: 211008923
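
The new argument gives multi-worker setups a name that matches their intent. A minimal usage sketch (hypothetical GPU counts; assumes the contrib import path of the file patched below):

    from tensorflow.contrib.distribute.python import mirrored_strategy

    # Local training: specify either `devices` or `num_gpus`.
    local_strategy = mirrored_strategy.MirroredStrategy(num_gpus=2)

    # Distributed training: `num_gpus_per_worker` is an alias for `num_gpus`
    # that makes the per-worker meaning explicit.
    dist_strategy = mirrored_strategy.MirroredStrategy(num_gpus_per_worker=2)
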
Diffstat (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy.py')
-rw-r--r--  tensorflow/contrib/distribute/python/mirrored_strategy.py  14
1 file changed, 12 insertions, 2 deletions
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index b4233a5eed..d1235b7afb 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -340,6 +340,9 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
num_gpus: number of GPUs. For local training, either specify `devices` or
`num_gpus`. In distributed training, this must be specified as the number
of GPUs on each worker.
+ num_gpus_per_worker: number of GPUs per worker. This is the same as
+ `num_gpus`; only one of `num_gpus` and `num_gpus_per_worker` may be
+ specified.
cross_tower_ops: optional, a descendant of `CrossTowerOps`. If this is not
set, the `configure` method will try to find the best one.
prefetch_on_device: optional boolean to specify whether to prefetch input
@@ -349,6 +352,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
def __init__(self,
devices=None,
num_gpus=None,
+ num_gpus_per_worker=None,
cross_tower_ops=None,
prefetch_on_device=None):
super(MirroredStrategy, self).__init__()
@@ -356,9 +360,15 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
self._cross_tower_ops = cross_tower_ops
self._prefetch_on_device = prefetch_on_device
# Remember num GPUs, which might be needed by the `configure` method.
- self._num_gpus = num_gpus
+ if num_gpus is not None and num_gpus_per_worker is not None:
+ raise ValueError(
+ "You cannot specify both `num_gpus` and `num_gpus_per_worker`.")
+ if num_gpus is not None:
+ self._num_gpus = num_gpus
+ else:
+ self._num_gpus = num_gpus_per_worker
- self._initialize_local(num_gpus, devices)
+ self._initialize_local(self._num_gpus, devices)
def _initialize_local(self, num_gpus, devices):
"""Initializes the object for local training."""