 tensorflow/contrib/distribute/python/BUILD                         |  13 +
 tensorflow/contrib/distribute/python/mirrored_strategy.py          |   1 +
 tensorflow/contrib/distribute/python/multi_worker_strategy.py      | 141 +++
 tensorflow/contrib/distribute/python/multi_worker_strategy_test.py |  64 ++
 tensorflow/contrib/distribute/python/one_device_strategy.py        |   1 +
 tensorflow/python/training/distribute.py                           |  20 +-
 6 files changed, 238 insertions(+), 2 deletions(-)
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index aaafc184bf..8dfcaf6032 100644
--- a/tensorflow/contrib/distribute/python/BUILD
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -87,6 +87,19 @@ py_library(
)
py_library(
+ name = "multi_worker_strategy",
+ srcs = ["multi_worker_strategy.py"],
+ visibility = ["//tensorflow:internal"],
+ deps = [
+ ":mirrored_strategy",
+ ":values",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:training",
+ "//tensorflow/python:util",
+ ],
+)
+
+py_library(
name = "one_device_strategy",
srcs = ["one_device_strategy.py"],
visibility = ["//tensorflow:internal"],
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index 2e57b02583..8237b23dbb 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -80,6 +80,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
dict((d, i) for i, d in enumerate(devices)))
self._cross_tower_ops = cross_tower_ops
self._prefetch_on_device = prefetch_on_device
+ # TODO(yuefengz): consider setting the default device.
def _create_variable(self, next_creator, *args, **kwargs):
"""Create a mirrored variable. See `DistributionStrategy.scope`."""
diff --git a/tensorflow/contrib/distribute/python/multi_worker_strategy.py b/tensorflow/contrib/distribute/python/multi_worker_strategy.py
new file mode 100644
index 0000000000..a552b370eb
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/multi_worker_strategy.py
@@ -0,0 +1,141 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Classes implementing a mirrored DistributionStrategy for multiple workers."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from functools import partial
+
+from tensorflow.contrib.distribute.python import values
+from tensorflow.contrib.distribute.python.mirrored_strategy import MirroredStrategy
+from tensorflow.core.protobuf import cluster_pb2
+from tensorflow.python.training import device_util
+from tensorflow.python.training import server_lib
+from tensorflow.python.util import nest
+
+
+# TODO(yuefengz): support between-graph replication.
+# TODO(yuefengz): merge this class into its base class.
+# TODO(yuefengz): in some cases, we probably want to use configure method to
+# configure this class.
+# TODO(yuefengz): MirroredStrategy.worker_devices may be confusing after the
+# class is introduced.
+class MultiWorkerMirroredStrategy(MirroredStrategy):
+ """Mirrored strategy that works on multiple workers with in-graph replication.
+
+ There are several important concepts in distributed TensorFlow, e.g.
+ `client`, `job`, `task`, `cluster`, `in-graph replication` and
+ `synchronous training`, which are defined in
+ [TensorFlow's documentation](https://www.tensorflow.org/deploy/distributed).
+ This distribution strategy inherits these concepts and additionally
+ clarifies two more:
+ * **In-graph replication**: the `client` creates a single `tf.Graph` that
+ specifies tasks for devices on all workers. The `client` then creates a
+ client session which will talk to the `master` service of a `worker`. Then
+ the `master` will partition the graph and distribute the work to all
+ participating workers.
+ * **Worker**: a `worker` is a TensorFlow `task` that usually maps to one
+ physical machine. There will be multiple `worker`s with different `task`
+ indices. They all do similar things, except that one worker also
+ checkpoints model variables, writes summaries, etc. in addition to its
+ ordinary work.
+
+ This class maps one tower to one device on a worker. It mirrors all model
+ variables on all towers. For example, if you have two `worker`s and each
+ `worker` has 4 GPUs, it will create 8 copies of the model variables on these 8
+ GPUs. Then, as in MirroredStrategy, each tower performs its computation with
+ its own copy of the variables, except in cross-tower mode where variable or
+ tensor reduction happens.
+ """
+
+ def __init__(self,
+ num_gpus_per_worker=1,
+ worker_job_name=None,
+ num_workers=None,
+ cluster=None,
+ cross_tower_ops=None,
+ prefetch_on_device=None):
+ """Initialize the strategy object.
+
+ Args:
+ num_gpus_per_worker: number of GPUs per worker. If it is zero, the local
+ CPU will be used.
+ worker_job_name: the job name for `worker`, typically just 'worker'.
+ num_workers: the number of workers. If it is 0, this degenerates to a
+ single-worker MirroredStrategy.
+ cluster: a `tf.train.ClusterSpec` object, or a dict or `tf.train.ClusterDef`
+ proto buffer from which a `tf.train.ClusterSpec` can be constructed. It
+ is an alternative way to initialize this object.
+ cross_tower_ops: the cross tower ops to use. If None, a default one will
+ be used. If the configure method is called, the best one for the
+ configuration will be chosen.
+ prefetch_on_device: a boolean specifying whether to prefetch input to
+ each worker's devices.
+
+ Raises:
+ ValueError: if an unexpected `cluster` is given.
+ """
+ if cluster is None:
+ self._workers = [
+ '/job:%s/task:%d' % (worker_job_name, task_index)
+ for task_index in range(num_workers)
+ ]
+ else:
+ if isinstance(cluster, (dict, cluster_pb2.ClusterDef)):
+ cluster_spec = server_lib.ClusterSpec(cluster)
+ elif isinstance(cluster, server_lib.ClusterSpec):
+ cluster_spec = cluster
+ else:
+ raise ValueError(
+ '`cluster` should be a dict, a `tf.train.ClusterSpec` or a '
+ '`tf.train.ClusterDef` object')
+
+ self._workers = []
+ for job in sorted(cluster_spec.jobs):
+ for task in range(cluster_spec.num_tasks(job)):
+ self._workers.append('/job:%s/task:%d' % (job, task))
+
+ self._num_gpus_per_worker = num_gpus_per_worker
+ if num_gpus_per_worker > 0:
+ self._worker_device_map = {
+ worker: [
+ device_util.canonicalize(worker + '/device:GPU:%d' % gpu)
+ for gpu in range(num_gpus_per_worker)
+ ] for worker in self._workers
+ }
+ else:
+ self._worker_device_map = {
+ worker: [device_util.canonicalize(worker, '/device:CPU:0')]
+ for worker in self._workers
+ }
+ self._devices = nest.flatten(self._worker_device_map.values())
+
+ super(MultiWorkerMirroredStrategy, self).__init__(
+ devices=self._devices, prefetch_on_device=prefetch_on_device)
+
+ # Setting `_default_device` will add a device scope in the
+ # distribution.scope. We set the default device to the first worker. When
+ # users specify a device under distribution.scope by
+ # with tf.device("/cpu:0"):
+ # ...
+ # their ops will end up on the CPU device of the first worker, e.g.
+ # "/job:worker/task:0/device:CPU:0". Note this is not used in tower mode.
+ self._default_device = self._workers[0]
+
+ def distribute_dataset(self, dataset_fn):
+ return values.MultiWorkerDataset(
+ partial(self._call_dataset_fn, dataset_fn), self._worker_device_map,
+ self._prefetch_on_device)
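
For orientation, a rough usage sketch of the new strategy follows. The worker addresses, GPU count, dataset function and variable names below are made-up placeholders, not part of this change; the sketch only illustrates the constructor and the `distribute_dataset` method added above.

# Illustrative sketch only: hostnames, GPU count and dataset_fn are hypothetical.
import tensorflow as tf
from tensorflow.contrib.distribute.python import multi_worker_strategy

def dataset_fn():
  # Per-worker input pipeline; a toy in-memory dataset for illustration.
  return tf.data.Dataset.range(100).batch(10)

strategy = multi_worker_strategy.MultiWorkerMirroredStrategy(
    cluster={'worker': ['host1:2222', 'host2:2222']},
    num_gpus_per_worker=2)

with strategy.scope():
  # Variables created here are mirrored onto all 2 x 2 = 4 worker GPUs, and
  # ops without an explicit device default to the first worker.
  v = tf.get_variable('v', initializer=1.0)

# Wraps dataset_fn so each worker's devices get their own (optionally
# prefetched) input.
multi_worker_dataset = strategy.distribute_dataset(dataset_fn)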
diff --git a/tensorflow/contrib/distribute/python/multi_worker_strategy_test.py b/tensorflow/contrib/distribute/python/multi_worker_strategy_test.py
new file mode 100644
index 0000000000..ee7588163e
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/multi_worker_strategy_test.py
@@ -0,0 +1,64 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for MultiWorkerMirroredStrategy."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.distribute.python import multi_worker_strategy
+from tensorflow.contrib.distribute.python import multi_worker_test_base
+from tensorflow.contrib.distribute.python import strategy_test_lib
+from tensorflow.python.eager import context
+from tensorflow.python.eager import test
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.training import server_lib
+
+
+@test_util.with_c_api
+class MultiWorkerStrategyTest(multi_worker_test_base.MultiWorkerTestBase,
+ strategy_test_lib.DistributionTestBase):
+
+ def _get_distribution_strategy(self):
+ return multi_worker_strategy.MultiWorkerMirroredStrategy(
+ cluster=server_lib.ClusterSpec({
+ 'worker': ['/job:worker/task:0', '/job:worker/task:1']
+ }),
+ num_gpus_per_worker=context.num_gpus())
+
+ def testMinimizeLossGraph(self):
+ self._test_minimize_loss_graph(self._get_distribution_strategy())
+
+
+class DeviceScopeTest(test.TestCase):
+ """Test the device scope of MultiWorkerMirroredStrategy."""
+
+ def testDeviceScope(self):
+ with context.graph_mode():
+ strategy = multi_worker_strategy.MultiWorkerMirroredStrategy(
+ cluster={'worker': ['/job:worker/task:0', '/job:worker/task:1']},
+ num_gpus_per_worker=context.num_gpus())
+ with strategy.scope():
+ a = constant_op.constant(1.)
+ with ops.device('/cpu:0'):
+ b = constant_op.constant(1.)
+ self.assertEqual(a.device, '/job:worker/task:0')
+ self.assertEqual(b.device, '/job:worker/task:0/device:CPU:0')
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py
index 64aa369201..09b6d4a515 100644
--- a/tensorflow/contrib/distribute/python/one_device_strategy.py
+++ b/tensorflow/contrib/distribute/python/one_device_strategy.py
@@ -40,6 +40,7 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy):
super(OneDeviceStrategy, self).__init__()
self._device = device
self._prefetch_on_device = prefetch_on_device
+ self._default_device = device
def _create_variable(self, next_creator, *args, **kwargs):
# No need to distinguish tower-local variables when not mirroring,
diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py
index c16b05102e..21f81ee187 100644
--- a/tensorflow/python/training/distribute.py
+++ b/tensorflow/python/training/distribute.py
@@ -290,19 +290,31 @@ def _require_distribution_strategy_scope(distribution_strategy):
class _CurrentDistributionContext(object):
"""Context manager for setting the `DistributionStrategy` and var creator."""
- def __init__(self, distribution_strategy, var_creator_scope, var_scope=None):
+ def __init__(self,
+ distribution_strategy,
+ var_creator_scope,
+ var_scope=None,
+ default_device=None):
self._context = _CrossTowerThreadMode(distribution_strategy)
self._var_creator_scope = var_creator_scope
self._var_scope = var_scope
+ if default_device:
+ self._device_scope = ops.device(default_device)
+ else:
+ self._device_scope = None
def __enter__(self):
_push_per_thread_mode(self._context)
if self._var_scope:
self._var_scope.__enter__()
self._var_creator_scope.__enter__()
+ if self._device_scope:
+ self._device_scope.__enter__()
return self._context.distribution_strategy
def __exit__(self, exception_type, exception_value, traceback):
+ if self._device_scope:
+ self._device_scope.__exit__(exception_type, exception_value, traceback)
self._var_creator_scope.__exit__(exception_type, exception_value, traceback)
if self._var_scope:
self._var_scope.__exit__(exception_type, exception_value, traceback)
@@ -557,6 +569,9 @@ class DistributionStrategy(object):
# TODO(josh11b): List of towers with their worker and parameter devices
# (where the parameter devices may overlap in the ps case).
+ def __init__(self):
+ self._default_device = None
+
def scope(self):
"""Returns a context manager selecting this DistributionStrategy as current.
@@ -587,7 +602,8 @@ class DistributionStrategy(object):
self, variable_scope.variable_creator_scope(creator_with_resource_vars),
variable_scope.variable_scope(
variable_scope.get_variable_scope(),
- custom_getter=disable_partitioned_variables))
+ custom_getter=disable_partitioned_variables),
+ self._default_device)
def _create_variable(self, next_creator, *args, **kwargs):
# Note: should support "colocate_with" argument.
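
As a side note on the ordering in `_CurrentDistributionContext` above: the device scope is entered last and exited first, so it nests innermost, inside the variable scope and variable-creator scope. A stripped-down, stand-alone sketch of that pattern follows; the `_Scope` and `_Context` names are stand-ins for illustration, not the real TensorFlow classes.

# Hypothetical stand-in classes illustrating the enter/exit ordering used by
# _CurrentDistributionContext; not the actual TensorFlow implementation.
class _Scope(object):

  def __init__(self, name):
    self._name = name

  def __enter__(self):
    print('enter %s' % self._name)
    return self

  def __exit__(self, *exc_info):
    print('exit %s' % self._name)


class _Context(object):

  def __init__(self, var_creator_scope, var_scope=None, device_scope=None):
    self._var_creator_scope = var_creator_scope
    self._var_scope = var_scope
    self._device_scope = device_scope

  def __enter__(self):
    if self._var_scope:
      self._var_scope.__enter__()
    self._var_creator_scope.__enter__()
    if self._device_scope:
      self._device_scope.__enter__()  # entered last ...
    return self

  def __exit__(self, *exc_info):
    if self._device_scope:
      self._device_scope.__exit__(*exc_info)  # ... exited first
    self._var_creator_scope.__exit__(*exc_info)
    if self._var_scope:
      self._var_scope.__exit__(*exc_info)


with _Context(_Scope('var_creator'), _Scope('var'), _Scope('device')):
  pass
# Prints: enter var, enter var_creator, enter device, then exits in reverse.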