author     Yuefeng Zhou <yuefengz@google.com>               2018-08-30 17:09:50 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-08-30 17:14:53 -0700
commit     11f970c61bc21ae81b76cdee58871f0509ec9f0f
tree       6de4bc8c2ac06737629b8aaf09cfbc20db630f32
parent     9d5b6e6627b04abbfecb5c7a95576e70f29dd7cf
Add `num_gpus_per_worker` argument to MirroredStrategy.
PiperOrigin-RevId: 211008923
-rw-r--r--  tensorflow/contrib/distribute/python/mirrored_strategy.py               | 14
-rw-r--r--  tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py |  3
2 files changed, 14 insertions, 3 deletions
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index b4233a5eed..d1235b7afb 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -340,6 +340,9 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
     num_gpus: number of GPUs. For local training, either specify `devices` or
       `num_gpus`. In distributed training, this must be specified as number of
       GPUs on each worker.
+    num_gpus_per_worker: number of GPUs per worker. This is the same as
+      `num_gpus` and only one of `num_gpus` and `num_gpus_per_worker` can be
+      specified.
     cross_tower_ops: optional, a descendant of `CrossTowerOps`. If this is not
       set, the `configure` method will try to find the best one.
     prefetch_on_device: optional boolean to specify whether to prefetch input
@@ -349,6 +352,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
   def __init__(self,
                devices=None,
                num_gpus=None,
+               num_gpus_per_worker=None,
                cross_tower_ops=None,
                prefetch_on_device=None):
     super(MirroredStrategy, self).__init__()
@@ -356,9 +360,15 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
     self._cross_tower_ops = cross_tower_ops
     self._prefetch_on_device = prefetch_on_device
     # Remember num GPUs which might be needed by `configure` method.
-    self._num_gpus = num_gpus
+    if num_gpus is not None and num_gpus_per_worker is not None:
+      raise ValueError(
+          "You cannot specify both `num_gpus` and `num_gpus_per_worker`.")
+    if num_gpus is not None:
+      self._num_gpus = num_gpus
+    else:
+      self._num_gpus = num_gpus_per_worker

-    self._initialize_local(num_gpus, devices)
+    self._initialize_local(self._num_gpus, devices)

   def _initialize_local(self, num_gpus, devices):
     """Initializes the object for local training."""
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
index 830681a4ce..c6894e9013 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
@@ -1385,7 +1385,8 @@ class MultiWorkerMirroredStrategyTestWithChief(
     cls._default_target = "grpc://" + cls._cluster_spec["chief"][0]

   def testMinimizeLossGraph(self):
-    strategy = mirrored_strategy.MirroredStrategy(num_gpus=context.num_gpus())
+    strategy = mirrored_strategy.MirroredStrategy(
+        num_gpus_per_worker=context.num_gpus())
     strategy.configure(cluster_spec=self._cluster_spec)
     self._test_minimize_loss_graph(strategy, learning_rate=0.05)
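For reference, a minimal usage sketch of the argument this commit adds. It assumes the contrib-era import path used by the test file above (TF 1.x); the GPU count of 2 is illustrative:

# Sketch of the constructor behavior after this change; the import path
# follows the test file in the diff (TF 1.x contrib API).
from tensorflow.contrib.distribute.python import mirrored_strategy

# `num_gpus_per_worker` is a per-worker alias for `num_gpus`; either one
# may be given, but not both.
strategy = mirrored_strategy.MirroredStrategy(num_gpus_per_worker=2)

# Passing both arguments raises the ValueError added in this commit.
try:
  mirrored_strategy.MirroredStrategy(num_gpus=2, num_gpus_per_worker=2)
except ValueError as e:
  print(e)  # You cannot specify both `num_gpus` and `num_gpus_per_worker`.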