diff options
author | 2018-08-30 17:09:50 -0700 | |
---|---|---|
committer | 2018-08-30 17:14:53 -0700 | |
commit | 11f970c61bc21ae81b76cdee58871f0509ec9f0f (patch) | |
tree | 6de4bc8c2ac06737629b8aaf09cfbc20db630f32 /tensorflow/contrib/distribute/python/mirrored_strategy.py | |
parent | 9d5b6e6627b04abbfecb5c7a95576e70f29dd7cf (diff) |
Add `num_gpus_per_worker` argument to MirroredStrategy.
PiperOrigin-RevId: 211008923
Diffstat (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy.py')
-rw-r--r-- | tensorflow/contrib/distribute/python/mirrored_strategy.py | 14 |
1 file changed, 12 insertions(+), 2 deletions(-)
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py index b4233a5eed..d1235b7afb 100644 --- a/tensorflow/contrib/distribute/python/mirrored_strategy.py +++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py @@ -340,6 +340,9 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): num_gpus: number of GPUs. For local training, either specify `devices` or `num_gpus`. In distributed training, this must be specified as number of GPUs on each worker. + num_gpus_per_worker: number of GPUs per worker. This is the same as + `num_gpus` and only one of `num_gpus` and `num_gpus_per_worker` can be + specified. cross_tower_ops: optional, a descedant of `CrossTowerOps`. If this is not set, the `configure` method will try to find the best one. prefetch_on_device: optional boolean to specify whether to prefetch input @@ -349,6 +352,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): def __init__(self, devices=None, num_gpus=None, + num_gpus_per_worker=None, cross_tower_ops=None, prefetch_on_device=None): super(MirroredStrategy, self).__init__() @@ -356,9 +360,15 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): self._cross_tower_ops = cross_tower_ops self._prefetch_on_device = prefetch_on_device # Rememeber num GPUs which might be needed by `configure` method. - self._num_gpus = num_gpus + if num_gpus is not None and num_gpus_per_worker is not None: + raise ValueError( + "You cannot specify both `num_gpus` and `num_gpus_per_worker`.") + if num_gpus is not None: + self._num_gpus = num_gpus + else: + self._num_gpus = num_gpus_per_worker - self._initialize_local(num_gpus, devices) + self._initialize_local(self._num_gpus, devices) def _initialize_local(self, num_gpus, devices): """Initializes the object for local training.""" |