author    Yuefeng Zhou <yuefengz@google.com>  2018-08-16 21:31:39 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-08-16 21:35:13 -0700
commit    439d8c4809139b163853fe87e8c5cdaba5d832eb (patch)
tree      b599479c27353cd3c5a537c3fddc388d03fec2f9 /tensorflow/contrib/distribute/python/mirrored_strategy.py
parent    a1606d5e0f667fddd7f3f5705bda3aee5b3c2554 (diff)
Merge MultiWorkerMirroredStrategy into MirroredStrategy
PiperOrigin-RevId: 209099475
Diffstat (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy.py')
-rw-r--r--  tensorflow/contrib/distribute/python/mirrored_strategy.py | 134
1 file changed, 119 insertions, 15 deletions
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index edd5c6d17a..6981449a4c 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -19,11 +19,13 @@ from __future__ import division
from __future__ import print_function
import contextlib
+from functools import partial
import threading
from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib
from tensorflow.contrib.distribute.python import shared_variable_creator
from tensorflow.contrib.distribute.python import values
+from tensorflow.core.protobuf import cluster_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.eager import tape
@@ -37,6 +39,7 @@ from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.training import coordinator
from tensorflow.python.training import device_util
from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import server_lib
from tensorflow.python.util import nest
@@ -291,24 +294,112 @@ def _create_mirrored_variable(devices, real_mirrored_creator, *args, **kwargs):
class MirroredStrategy(distribute_lib.DistributionStrategy):
- """Mirrors vars to distribute across multiple devices on a single machine.
+ """Mirrors vars to distribute across multiple devices and machines.
+
+ This strategy uses one tower per device and sync replication for its multi-GPU
+ version.
+
+ When `cluster_spec` is given, it turns into the multi-worker version that
+ works on multiple workers with in-graph replication.
+
+ There are several important concepts for distributed TensorFlow, e.g.
+ `client`, `job`, `task`, `cluster`, `in-graph replication` and
+ `synchronous training`, which are defined in the
+ [TensorFlow documentation](https://www.tensorflow.org/deploy/distributed).
+ This distribution strategy inherits these concepts and, in addition,
+ clarifies several more:
+ * **In-graph replication**: the `client` creates a single `tf.Graph` that
+ specifies tasks for devices on all workers. The `client` then creates a
+ client session which will talk to the `master` service of a `worker`. Then
+ the `master` will partition the graph and distribute the work to all
+ participating workers.
+ * **Worker**: A `worker` is a TensorFlow `task` that usually maps to one
+ physical machine. There will be multiple `worker`s with different `task`
+ indices. They all do similar work, except that one worker additionally
+ checkpoints model variables, writes summaries, etc.
+
+ The multi-worker version of this class maps one tower to one device on a
+ worker. It mirrors all model variables on all towers. For example, if you
+ have two `worker`s and each `worker` has 4 GPUs, it will create 8 copies of
+ the model variables on these 8 GPUs. Then, as in the single-machine
+ MirroredStrategy, each tower performs its computation with its own copy of
+ the variables, except in cross-tower mode where variable or tensor reduction
+ happens.
- This strategy uses one tower per device and sync replication.
+ Args:
+ devices: a list of device strings.
+ num_gpus: number of GPUs. For local training, either specify `devices` or
+ `num_gpus`. In distributed training, this must be specified as the number
+ of GPUs on each worker.
+ cluster_spec: if this is set, it turns into the multi-worker version;
+ `devices` must not be set but `num_gpus` must be set.
+ cross_tower_ops: optional, a descendant of `CrossTowerOps`. If this is not
+ set, the `configure` method will try to find the best one.
+ prefetch_on_device: optional boolean to specify whether to prefetch input
+ data to devices.
"""
def __init__(self,
devices=None,
num_gpus=None,
+ cluster_spec=None,
cross_tower_ops=None,
prefetch_on_device=None):
super(MirroredStrategy, self).__init__()
- # Convert `num_gpus` into `devices`, shouldn't specify both.
- if devices is None:
+
+ if cluster_spec:
+ if devices is not None:
+ raise ValueError("Specifying devices when `cluster_spec` is also given "
+ "is not supported in MirroredStrategy.")
+
+ # TODO(yuefengz): use the utility method to normalize cluster_spec.
+ if isinstance(cluster_spec, (dict, cluster_pb2.ClusterDef)):
+ cluster_spec = server_lib.ClusterSpec(cluster_spec)
+ elif not isinstance(cluster_spec, server_lib.ClusterSpec):
+ raise ValueError(
+ "`cluster_spec' should be dict or a `tf.train.ClusterSpec` or a "
+ "`tf.train.ClusterDef` object")
+ self._cluster_spec = cluster_spec
+
+ self._workers = []
+ for job in sorted(cluster_spec.jobs):
+ for task in range(cluster_spec.num_tasks(job)):
+ self._workers.append("/job:%s/task:%d" % (job, task))
+
if num_gpus is None:
- num_gpus = context.num_gpus()
- devices = ["/device:GPU:%d" % d for d in range(num_gpus)]
- elif num_gpus is not None:
- raise ValueError("Must only specify one of `devices` and `num_gpus`.")
+ raise ValueError("`num_gpus` is required if `cluster_spec` is given.")
+ self._num_gpus = num_gpus
+ if num_gpus > 0:
+ self._worker_device_map = {
+ worker: [
+ device_util.canonicalize(worker + "/device:GPU:%d" % gpu)
+ for gpu in range(num_gpus)
+ ] for worker in self._workers
+ }
+ else:
+ self._worker_device_map = {
+ worker: [device_util.canonicalize(worker, "/device:CPU:0")]
+ for worker in self._workers
+ }
+ devices = nest.flatten(self._worker_device_map)
+
+ # Setting `_default_device` will add a device scope in the
+ # distribution.scope. We set the default device to the first worker. When
+ # users specify a device under distribution.scope, e.g.
+ # with tf.device("/cpu:0"):
+ # ...
+ # their ops will end up on the CPU device of the first worker, e.g.
+ # "/job:worker/task:0/device:CPU:0". Note this is not used in tower mode.
+ self._default_device = self._workers[0]
+ else:
+ self._cluster_spec = None
+ # Convert `num_gpus` into `devices`, shouldn't specify both.
+ if devices is None:
+ if num_gpus is None:
+ num_gpus = context.num_gpus()
+ devices = ["/device:GPU:%d" % d for d in range(num_gpus)]
+ elif num_gpus is not None:
+ raise ValueError("Must only specify one of `devices` and `num_gpus`.")
+ # TODO(yuefengz): consider setting the default device.
assert devices, "Must specify at least one device."
assert len(set(devices)) == len(devices), (
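
For reference, a minimal usage sketch of the multi-worker constructor path added in this hunk. This is illustrative only and assumes a hypothetical two-worker cluster; the addresses and the `strategy` variable are not part of the change.

# Hypothetical example: multi-worker MirroredStrategy with a cluster_spec dict.
from tensorflow.contrib.distribute.python import mirrored_strategy

cluster_spec = {
    "worker": ["10.0.0.1:2222", "10.0.0.2:2222"],  # two worker tasks
}

# `num_gpus` is required when `cluster_spec` is given; `devices` must be left unset.
strategy = mirrored_strategy.MirroredStrategy(num_gpus=2, cluster_spec=cluster_spec)

# Per the __init__ logic above, variables are then mirrored on
# /job:worker/task:0/device:GPU:0-1 and /job:worker/task:1/device:GPU:0-1.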
@@ -320,7 +411,6 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
{d: i for i, d in enumerate(devices)})
self._cross_tower_ops = cross_tower_ops
self._prefetch_on_device = prefetch_on_device
- # TODO(yuefengz): consider setting the default device.
def _create_variable(self, next_creator, *args, **kwargs):
"""Create a mirrored variable. See `DistributionStrategy.scope`."""
@@ -357,9 +447,14 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
**kwargs)
def distribute_dataset(self, dataset_fn):
- return values.PerDeviceDataset(
- self._call_dataset_fn(dataset_fn), self._devices,
- self._prefetch_on_device)
+ if self._cluster_spec:
+ return values.MultiWorkerDataset(
+ partial(self._call_dataset_fn, dataset_fn), self._worker_device_map,
+ self._prefetch_on_device)
+ else:
+ return values.PerDeviceDataset(
+ self._call_dataset_fn(dataset_fn), self._devices,
+ self._prefetch_on_device)
# TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
def _run_steps_on_dataset(self, fn, iterator, iterations,
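
A hedged sketch of how the new distribute_dataset branch is exercised; `strategy` is assumed to be a MirroredStrategy instance (constructed with or without `cluster_spec`) and the dataset below is illustrative.

import tensorflow as tf

def dataset_fn():
  # Built lazily so that, in the multi-worker case, each worker can
  # construct its own copy of the input pipeline.
  return tf.data.Dataset.from_tensor_slices(list(range(1024))).batch(32)

# Returns a values.MultiWorkerDataset when `cluster_spec` was given,
# otherwise a values.PerDeviceDataset, matching the branch above.
dist_dataset = strategy.distribute_dataset(dataset_fn)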
@@ -444,10 +539,19 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
# in addition to PerDevice data.
return values.PerDevice({k: values.MapOutput(v) for k, v in index.items()})
- def configure(self, session_config=None):
+ def configure(self,
+ session_config=None,
+ cluster_spec=None,
+ task_type=None,
+ task_id=None):
+ del cluster_spec, task_type, task_id
if self._cross_tower_ops is None:
- self._cross_tower_ops = cross_tower_ops_lib.choose_the_best(
- self._devices, session_config=session_config)
+ if self._cluster_spec:
+ self._cross_tower_ops = cross_tower_ops_lib.MultiWorkerAllReduce(
+ self._workers, self._num_gpus)
+ else:
+ self._cross_tower_ops = cross_tower_ops_lib.choose_the_best(
+ self._devices, session_config=session_config)
def _get_cross_tower_ops(self):
if self._cross_tower_ops is None:
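
Finally, an illustrative sketch of supplying cross-tower ops up front instead of relying on configure() to choose them. MultiWorkerAllReduce is called with the same positional arguments used in configure() above; the cluster addresses are hypothetical.

from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib
from tensorflow.contrib.distribute.python import mirrored_strategy

workers = ["/job:worker/task:0", "/job:worker/task:1"]
all_reduce = cross_tower_ops_lib.MultiWorkerAllReduce(workers, 2)  # 2 GPUs per worker

strategy = mirrored_strategy.MirroredStrategy(
    num_gpus=2,
    cluster_spec={"worker": ["10.0.0.1:2222", "10.0.0.2:2222"]},
    cross_tower_ops=all_reduce)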