author    Yuefeng Zhou <yuefengz@google.com>    2018-08-30 17:09:50 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-08-30 17:14:53 -0700
commit    11f970c61bc21ae81b76cdee58871f0509ec9f0f (patch)
tree      6de4bc8c2ac06737629b8aaf09cfbc20db630f32
parent    9d5b6e6627b04abbfecb5c7a95576e70f29dd7cf (diff)
Add `num_gpus_per_worker` argument to MirroredStrategy.
PiperOrigin-RevId: 211008923
-rw-r--r--  tensorflow/contrib/distribute/python/mirrored_strategy.py  14
-rw-r--r--  tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py  3
2 files changed, 14 insertions(+), 3 deletions(-)
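
For context, a minimal usage sketch of the new argument (not part of this commit; it assumes a TF 1.x environment at this revision, where the strategy is exposed as `tf.contrib.distribute.MirroredStrategy`):

    import tensorflow as tf

    # Multi-worker training: give the number of GPUs on each worker instead of
    # an explicit device list.
    strategy = tf.contrib.distribute.MirroredStrategy(num_gpus_per_worker=2)

    # `num_gpus` and `num_gpus_per_worker` are mutually exclusive; passing both
    # raises the ValueError introduced in this change:
    # tf.contrib.distribute.MirroredStrategy(num_gpus=2, num_gpus_per_worker=2)
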
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index b4233a5eed..d1235b7afb 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -340,6 +340,9 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
num_gpus: number of GPUs. For local training, either specify `devices` or
`num_gpus`. In distributed training, this must be specified as number of
GPUs on each worker.
+ num_gpus_per_worker: number of GPUs per worker. This is the same as
+ `num_gpus`; only one of `num_gpus` and `num_gpus_per_worker` can be
+ specified.
cross_tower_ops: optional, a descendant of `CrossTowerOps`. If this is not
set, the `configure` method will try to find the best one.
prefetch_on_device: optional boolean to specify whether to prefetch input
@@ -349,6 +352,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
def __init__(self,
devices=None,
num_gpus=None,
+ num_gpus_per_worker=None,
cross_tower_ops=None,
prefetch_on_device=None):
super(MirroredStrategy, self).__init__()
@@ -356,9 +360,15 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
self._cross_tower_ops = cross_tower_ops
self._prefetch_on_device = prefetch_on_device
# Remember num GPUs, which might be needed by the `configure` method.
- self._num_gpus = num_gpus
+ if num_gpus is not None and num_gpus_per_worker is not None:
+ raise ValueError(
+ "You cannot specify both `num_gpus` and `num_gpus_per_worker`.")
+ if num_gpus is not None:
+ self._num_gpus = num_gpus
+ else:
+ self._num_gpus = num_gpus_per_worker
- self._initialize_local(num_gpus, devices)
+ self._initialize_local(self._num_gpus, devices)
def _initialize_local(self, num_gpus, devices):
"""Initializes the object for local training."""
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
index 830681a4ce..c6894e9013 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
@@ -1385,7 +1385,8 @@ class MultiWorkerMirroredStrategyTestWithChief(
cls._default_target = "grpc://" + cls._cluster_spec["chief"][0]
def testMinimizeLossGraph(self):
- strategy = mirrored_strategy.MirroredStrategy(num_gpus=context.num_gpus())
+ strategy = mirrored_strategy.MirroredStrategy(
+ num_gpus_per_worker=context.num_gpus())
strategy.configure(cluster_spec=self._cluster_spec)
self._test_minimize_loss_graph(strategy, learning_rate=0.05)
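
For reference, a sketch of the multi-worker pattern the updated test exercises (not part of this commit; the two-task `cluster_spec` and its addresses are illustrative placeholders, assuming a TF 1.x environment):

    from tensorflow.contrib.distribute.python import mirrored_strategy

    # Illustrative cluster: one chief and one worker.
    cluster_spec = {
        "chief": ["localhost:2222"],
        "worker": ["localhost:2223"],
    }

    # Each task mirrors the model across its own local GPUs.
    strategy = mirrored_strategy.MirroredStrategy(num_gpus_per_worker=2)
    strategy.configure(cluster_spec=cluster_spec)
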