author Yuefeng Zhou <yuefengz@google.com> 2018-07-27 02:12:17 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-07-27 02:16:07 -0700
commit dc437d53a395438070739c3d509fa0c21b3bffbb (patch)
tree dce8640d2a19a7facfbd42890eb6cf37f4c81220 /tensorflow/contrib/distribute/python/mirrored_strategy.py
parent 4af0e25bfe0791f0ec3e9262c8d5051415bf026e (diff)
Add parameter server distribution.
PiperOrigin-RevId: 206289143
Diffstat (limited to 'tensorflow/contrib/distribute/python/mirrored_strategy.py')
-rw-r--r--  tensorflow/contrib/distribute/python/mirrored_strategy.py | 297
1 file changed, 155 insertions(+), 142 deletions(-)
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index dcbc6b0878..eb2d102012 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -20,7 +20,6 @@ from __future__ import print_function
import contextlib
import threading
-import six
from tensorflow.contrib.distribute.python import cross_tower_ops as cross_tower_ops_lib
from tensorflow.contrib.distribute.python import shared_variable_creator
@@ -60,6 +59,156 @@ class _RequestedStop(Exception):
pass
+# _call_for_each_tower and _reduce_non_distributed_value are defined outside
+# MirroredStrategy so that they cannot use anything specific to
+# MirroredStrategy and can therefore be shared with other distribution
+# strategies.
+
+
+# TODO(yuefengz): maybe create a common class for those who need to call this
+# _call_for_each_tower.
+def _call_for_each_tower(distribution, fn, *args, **kwargs):
+ """Run `fn` in separate threads, once per tower/worker device.
+
+ Args:
+ distribution: the DistributionStrategy object.
+ fn: function to run (will be run once per device, each in its own thread).
+ *args: positional arguments for `fn`.
+ **kwargs: keyword arguments for `fn`.
+ `"run_concurrently"`: Boolean indicating whether executions of `fn`
+ can be run concurrently (under eager execution only), defaults to
+ `True`.
+
+ Returns:
+ Merged return value of `fn` across all towers.
+
+ Raises:
+ RuntimeError: If fn() calls get_tower_context().merge_call() a different
+ number of times on different devices.
+ """
+ run_concurrently = kwargs.pop("run_concurrently", True)
+ if not context.executing_eagerly():
+ # Lots of TF library code isn't thread-safe in graph mode, and
+ # there is little to be gained by turning on multithreading when
+ # constructing a graph.
+ run_concurrently = False
+ # Needed for per-thread device, etc. contexts in graph mode.
+ ops.get_default_graph().switch_to_thread_local()
+ elif run_concurrently is None:
+ run_concurrently = True
+
+ coord = coordinator.Coordinator(clean_stop_exception_types=(_RequestedStop,))
+
+ shared_variable_store = {}
+
+ # TODO(isaprykin): Create these threads once instead of during every run()
+ # call.
+ threads = []
+ for index, d in enumerate(distribution.worker_devices):
+ variable_creator_fn = shared_variable_creator.make_fn(
+ shared_variable_store, index)
+ t = MirroredStrategy._MirroredTowerThread( # pylint: disable=protected-access
+ distribution, coord, d, variable_creator_fn, fn,
+ *values.select_device(d, args), **values.select_device(d, kwargs))
+ threads.append(t)
+
+ for t in threads:
+ t.start()
+
+ # When `fn` starts, the `should_run` event is set on the
+ # _MirroredTowerThread (`MTT`) threads. The execution waits until
+ # `MTT.has_paused` is set, which indicates that either `fn` is
+ # complete or `get_tower_context().merge_call()` has been called. If `fn`
+ # is complete, then `MTT.done` is set to True. Otherwise, the arguments of
+ # `get_tower_context().merge_call` from all paused threads are grouped and
+ # `merge_fn` is invoked on them. The result of that merge is then set on
+ # each thread's `MTT.merge_result`. Each `get_tower_context().merge_call`
+ # call returns the `MTT.merge_result` for its thread once the
+ # `MTT.should_run` event is set again, and execution of `fn` resumes.
+
+ try:
+ with coord.stop_on_exception():
+ all_done = False
+ while not all_done and not coord.should_stop():
+ done = []
+ if run_concurrently:
+ for t in threads:
+ t.should_run.set()
+ for t in threads:
+ t.has_paused.wait()
+ t.has_paused.clear()
+ if coord.should_stop():
+ return None
+ done.append(t.done)
+ else:
+ for t in threads:
+ t.should_run.set()
+ t.has_paused.wait()
+ t.has_paused.clear()
+ if coord.should_stop():
+ return None
+ done.append(t.done)
+ if coord.should_stop():
+ return None
+ all_done = all(done)
+ if not all_done:
+ if any(done):
+ raise RuntimeError("Some towers made a different number of "
+ "tower_context().merge_call() calls.")
+ # get_tower_context().merge_call() case
+ merge_args = values.regroup({t.device: t.merge_args for t in threads})
+ merge_kwargs = values.regroup(
+ {t.device: t.merge_kwargs for t in threads})
+ # We capture the name_scope of the MTT when we call merge_fn
+ # to ensure that if we have opened a name scope in the MTT,
+ # it will be respected when executing the merge function. We only
+ # capture the name_scope from the first MTT and assume it is
+ # the same for all other MTTs.
+ mtt_captured_name_scope = threads[0].captured_name_scope
+ with ops.name_scope(mtt_captured_name_scope):
+ merge_result = threads[0].merge_fn(distribution, *merge_args,
+ **merge_kwargs)
+ for t in threads:
+ t.merge_result = values.select_device(t.device, merge_result)
+ finally:
+ for t in threads:
+ t.should_run.set()
+ coord.join(threads)
+
+ return values.regroup({t.device: t.main_result for t in threads})
+
+
+def _reduce_non_distributed_value(distribution, aggregation, value,
+ destinations):
+ """Reduce a non-DistributedValue `value` to `destinations`."""
+ if isinstance(value, values.DistributedValues):
+ raise ValueError("You are passing a `DistributedValue` to "
+ "`_reduce_non_distributed_value`, which is not allowed.")
+
+ if value == 0:
+ return 0
+ if aggregation == variable_scope.VariableAggregation.MEAN:
+ return distribution.broadcast(value, destinations)
+
+ cross_tower_ops_lib.validate_destinations(destinations)
+ if (len(distribution.worker_devices) != 1 or
+ not cross_tower_ops_lib.check_destinations(destinations)):
+ raise ValueError("A non-DistributedValues value cannot be reduced with the "
+ "given aggregation.")
+ # TODO(anjalisridhar): Move these methods to a device utility file?
+ devices = cross_tower_ops_lib.get_devices_from(destinations)
+ if len(devices) == 1:
+ with ops.device(devices[0]):
+ return array_ops.identity(value)
+ else:
+ value_updates = {}
+ for d in devices:
+ with ops.device(d):
+ value_updates[d] = array_ops.identity(value)
+ return values.Mirrored(value_updates)
+
+
class MirroredStrategy(distribute_lib.DistributionStrategy):
"""Mirrors vars to distribute across multiple devices on a single machine.
@@ -198,116 +347,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
self._devices)
def _call_for_each_tower(self, fn, *args, **kwargs):
- """Run `fn` in separate threads, once per tower/worker device.
-
- Args:
- fn: function to run (will be run once per device, each in its own thread).
- *args: positional arguments for `fn`
- **kwargs: keyword arguments for `fn`.
- `"run_concurrently"`: Boolean indicating whether executions of `fn`
- can be run concurrently (under eager execution only), defaults to
- `True`.
-
- Returns:
- Merged return value of `fn` across all towers.
-
- Raises:
- RuntimeError: If fn() calls get_tower_context().merge_call() a different
- number of times for when called for different devices.
- """
- run_concurrently = kwargs.pop("run_concurrently", True)
- if not context.executing_eagerly():
- # Lots of TF library code isn't thread-safe in graph mode, and
- # there is little to be gained by turning on multithreading when
- # constructing a graph.
- run_concurrently = False
- # Needed for per-thread device, etc. contexts in graph mode.
- ops.get_default_graph().switch_to_thread_local()
- elif run_concurrently is None:
- run_concurrently = True
-
- coord = coordinator.Coordinator(
- clean_stop_exception_types=(_RequestedStop,))
-
- shared_variable_store = {}
-
- # TODO(isaprykin): Create these threads once instead of during every run()
- # call.
- threads = []
- for index, d in enumerate(self._devices):
- variable_creator_fn = shared_variable_creator.make_fn(
- shared_variable_store, index)
- t = MirroredStrategy._MirroredTowerThread(
- self, coord, d, variable_creator_fn, fn,
- *values.select_device(d, args), **values.select_device(d, kwargs))
- threads.append(t)
-
- for t in threads:
- t.start()
-
- # When `fn` starts `should_run` event is set on _MirroredTowerThread
- # (`MTT`) threads. The execution waits until
- # `MTT.has_paused` is set, which indicates that either `fn` is
- # complete or a `get_tower_context().merge_call()` is called. If `fn` is
- # complete, then `MTT.done` is set to True. Otherwise, arguments
- # of `get_tower_context().merge_call` from all paused threads are grouped
- # and the `merge_fn` is performed. Results of the
- # `get_tower_context().merge_call` are then set to `MTT.merge_result`.
- # Each such `get_tower_context().merge_call` call returns the
- # `MTT.merge_result` for that thread when `MTT.should_run` event
- # is reset again. Execution of `fn` resumes.
-
- try:
- with coord.stop_on_exception():
- all_done = False
- while not all_done and not coord.should_stop():
- done = []
- if run_concurrently:
- for t in threads:
- t.should_run.set()
- for t in threads:
- t.has_paused.wait()
- t.has_paused.clear()
- if coord.should_stop():
- return None
- done.append(t.done)
- else:
- for t in threads:
- t.should_run.set()
- t.has_paused.wait()
- t.has_paused.clear()
- if coord.should_stop():
- return None
- done.append(t.done)
- if coord.should_stop():
- return None
- all_done = all(done)
- if not all_done:
- if any(done):
- raise RuntimeError("Some towers made a different number of "
- "tower_context().merge_call() calls.")
- # get_tower_context().merge_call() case
- merge_args = values.regroup(
- {t.device: t.merge_args for t in threads})
- merge_kwargs = values.regroup(
- {t.device: t.merge_kwargs for t in threads})
- # We capture the name_scope of the MTT when we call merge_fn
- # to ensure that if we have opened a name scope in the MTT,
- # it will be respected when executing the merge function. We only
- # capture the name_scope from the first MTT and assume it is
- # the same for all other MTTs.
- mtt_captured_name_scope = threads[0].captured_name_scope
- with ops.name_scope(mtt_captured_name_scope):
- merge_result = threads[0].merge_fn(
- self, *merge_args, **merge_kwargs)
- for t in threads:
- t.merge_result = values.select_device(t.device, merge_result)
- finally:
- for t in threads:
- t.should_run.set()
- coord.join(threads)
-
- return values.regroup({t.device: t.main_result for t in threads})
+ return _call_for_each_tower(self, fn, *args, **kwargs)
def map(self, map_over, fn, *args, **kwargs):
# TODO(josh11b): In eager mode, use one thread per device.
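
For context, a caller reaches the helper above through the public call_for_each_tower method. A hedged usage sketch under eager execution follows; the device names are illustrative, and the import path simply mirrors the file path in this diff:

import tensorflow as tf
from tensorflow.contrib.distribute.python import mirrored_strategy

tf.enable_eager_execution()

# Device names are illustrative; substitute devices available on your machine.
strategy = mirrored_strategy.MirroredStrategy(devices=["/cpu:0", "/gpu:0"])

def tower_fn(x):
  return x * 2.0

# run_concurrently only matters under eager execution; graph mode always
# runs the towers serially, as _call_for_each_tower enforces above.
result = strategy.call_for_each_tower(
    tower_fn, tf.constant(3.0), run_concurrently=False)
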
@@ -337,29 +377,9 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
def _reduce(self, aggregation, value, destinations):
assert not isinstance(value, values.Mirrored)
- if not isinstance(value, values.PerDevice):
- if value == 0:
- return 0
- if aggregation == variable_scope.VariableAggregation.MEAN:
- return self._broadcast(value, destinations)
-
- cross_tower_ops_lib.validate_destinations(destinations)
- if len(self._devices) == 1:
- if destinations:
- # TODO(anjalisridhar): Moves these methods to a device utility file?
- devices = cross_tower_ops_lib.get_devices_from(destinations)
- if len(devices) == 1:
- with ops.device(devices[0]):
- return array_ops.identity(value)
- else:
- value_updates = {}
- for d in devices:
- with ops.device(d):
- value_updates[d] = array_ops.identity(value)
- return values.Mirrored(value_updates)
- raise ValueError("A non PerDevice value cannot be reduced with the given "
- "aggregation.")
-
+ if not isinstance(value, values.DistributedValues):
+ return _reduce_non_distributed_value(self, aggregation, value,
+ destinations)
return self._get_cross_tower_ops().reduce(
aggregation, value, destinations=destinations)
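
The non-distributed branch handled by _reduce_non_distributed_value ultimately pins one identity of the value on each destination device. Below is a standalone sketch of that pattern using plain TensorFlow ops; the device names are illustrative, and the real helper wraps the per-device dict in values.Mirrored rather than returning it directly.

import tensorflow as tf

def mirror_value(value, devices):
  """Places one tf.identity copy of `value` on each device."""
  if len(devices) == 1:
    with tf.device(devices[0]):
      return tf.identity(value)
  value_updates = {}
  for d in devices:
    with tf.device(d):
      value_updates[d] = tf.identity(value)
  return value_updates  # the real code returns values.Mirrored(value_updates)

mirrored = mirror_value(tf.constant(1.0), ["/cpu:0", "/gpu:0"])
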
@@ -433,15 +453,8 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
def _get_devices_from(self, colocate_with=None):
if colocate_with is None:
return self._devices
- elif isinstance(colocate_with, values.DistributedValues):
- # pylint: disable=protected-access
- return list(colocate_with._index.keys())
- elif isinstance(colocate_with, six.string_types):
- return [device_util.resolve(colocate_with)]
- elif isinstance(colocate_with, list):
- return [device_util.resolve(d) for d in colocate_with]
else:
- return colocate_with
+ return cross_tower_ops_lib.get_devices_from(colocate_with)
class _MirroredTowerThread(threading.Thread):
"""A thread that runs() a function on a device."""