diff options
Diffstat (limited to 'tensorflow/python/training/distribute.py')
-rw-r--r-- | tensorflow/python/training/distribute.py | 232 |
1 files changed, 57 insertions, 175 deletions
diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py index 581db45e80..28c60ad809 100644 --- a/tensorflow/python/training/distribute.py +++ b/tensorflow/python/training/distribute.py @@ -21,7 +21,7 @@ from __future__ import print_function import threading from tensorflow.python.data.ops import dataset_ops -from tensorflow.python.eager import context +from tensorflow.python.eager import context as eager_context from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops @@ -31,71 +31,11 @@ from tensorflow.python.ops import variable_scope from tensorflow.python.ops.losses import losses_impl from tensorflow.python.platform import tf_logging from tensorflow.python.training import device_util +from tensorflow.python.training import distribution_strategy_context from tensorflow.python.util import nest # ------------------------------------------------------------------------------ -# Internal API for setting the current thread mode as being either in a -# tower or cross-tower context for a particular distribution strategy. - - -class _ThreadMode(object): - - def __init__(self, dist, cross, tower): - self.distribution_strategy = dist - self.cross_tower_context = cross - self.tower_context = tower - - -class _CrossTowerThreadMode(_ThreadMode): - - def __init__(self, distribution_strategy): - _ThreadMode.__init__( - self, distribution_strategy, distribution_strategy, None) - - -class _InTowerThreadMode(_ThreadMode): - - def __init__(self, tower_ctx): - _ThreadMode.__init__( - self, tower_ctx.distribution_strategy, None, tower_ctx) - - -_per_thread_mode = threading.local() - - -def _push_per_thread_mode(context): - if not hasattr(_per_thread_mode, "stack"): - _per_thread_mode.stack = [] - _per_thread_mode.stack.append(context) - - -def _pop_per_thread_mode(): - _per_thread_mode.stack.pop(-1) - - -class _DefaultTowerThreadMode(_ThreadMode): - """Type of default value returned by `_get_per_thread_mode()`. - - Used when the thread-local stack is empty. - """ - - def __init__(self): - # _default_distribution_strategy and _default_tower_context are - # defined at the bottom of this file. - _ThreadMode.__init__( - self, _default_distribution_strategy, None, _default_tower_context) - - -def _get_per_thread_mode(): - try: - return _per_thread_mode.stack[-1] - except (AttributeError, IndexError): - # _default_tower_mode is defined at the bottom of this file. - return _default_tower_mode - - -# ------------------------------------------------------------------------------ # Context tracking whether in a distribution.update() or .update_non_slot() # call. @@ -128,96 +68,6 @@ class UpdateContext(object): # ------------------------------------------------------------------------------ -# Public API for accessing the current thread mode - - -def get_tower_context(): - """Returns the current TowerContext or None if in a cross-tower context. - - Note that execution: - 1. starts in the default (single-tower) tower context (this function - will return the default TowerContext object); - 2. switches to cross-tower context (in which case this will return - None) when entering a `with DistributionStrategy.scope():` block; - 3. switches to a (non-default) tower context inside - `call_for_each_tower(fn, ...)`; - 4. if `fn` calls `get_tower_context()->merge_call(merge_fn, ...)`, then - inside `merge_fn` you are back in the cross-tower context (and again - this function will return None). - - Note that you can also go directly from step 1 to 4 to switch to a - cross-tower context for the default `DistributionStrategy`. You may - also switch from the cross-tower context of 4 to a tower context by - calling `call_for_each_tower()`, jumping back to step 3. - - Most `DistributionStrategy` methods may only be executed in - a cross-tower context, in a tower context you should use the - `TowerContext` API instead. - - Returns: - The current `TowerContext` object when in a tower context scope, else None. - - Exactly one of `get_tower_context()` and `get_cross_tower_context()` - will return None in a particular block. - """ - return _get_per_thread_mode().tower_context - - -def get_cross_tower_context(): - """Returns the current DistributionStrategy if in a cross-tower context. - - Note that execution: - 1. starts in the default (single-tower) tower context; - 2. switches to cross-tower context when entering a - `with DistributionStrategy.scope():` block; - 3. switches to a (non-default) tower context inside - `call_for_each_tower(fn, ...)`; - 4. if `fn` calls `get_tower_context()->merge_call(merge_fn, ...)`, then - inside `merge_fn` you are back in the cross-tower context. - - Note that you can also go directly from step 1 to 4 to switch to a - cross-tower context for the default `DistributionStrategy`. You may - also switch from the cross-tower context of 4 to a tower context by - calling `call_for_each_tower()`, jumping back to step 3. - - Most `DistributionStrategy` methods may only be executed in - a cross-tower context. - - Returns: - Returns the current `DistributionStrategy` object in a cross-tower - context, or None. - - Exactly one of `get_tower_context()` and `get_cross_tower_context()` - will return None in a particular block. - """ - return _get_per_thread_mode().cross_tower_context - - -def get_distribution_strategy(): - """Returns the current `DistributionStrategy` object. - - Prefer to use `get_tower_context()` or `get_cross_tower_context()` - instead when possible. - - Returns: - A `DistributionStrategy` object. Inside a - `with distribution_strategy.scope()` block, it returns - `distribution_strategy`, otherwise it returns the default - (single-tower) `DistributionStrategy` object. - """ - return _get_per_thread_mode().distribution_strategy - - -def has_distribution_strategy(): - """Return if there is a current non-default `DistributionStrategy`. - - Returns: - True if inside a `with distribution_strategy.scope():`. - """ - return get_distribution_strategy() is not _default_distribution_strategy - - -# ------------------------------------------------------------------------------ # Public utility functions. @@ -239,7 +89,8 @@ def _require_cross_tower_context(distribution_strategy): if context.cross_tower_context is distribution_strategy: return # We have an error to report, figure out the right message. if context.distribution_strategy is not distribution_strategy: - if context.distribution_strategy is _default_distribution_strategy: + if (context.distribution_strategy is + distribution_strategy_context._get_default_distribution_strategy()): # pylint: disable=protected-access raise RuntimeError( 'Need to be inside "with distribution_strategy.scope()" for %s' % (distribution_strategy,)) @@ -272,7 +123,8 @@ def _require_distribution_strategy_scope(distribution_strategy): context = _get_per_thread_mode() if context.distribution_strategy is distribution_strategy: return # We have an error to report, figure out the right message. - if context.distribution_strategy is _default_distribution_strategy: + if (context.distribution_strategy is + distribution_strategy_context._get_default_distribution_strategy()): # pylint: disable=protected-access raise RuntimeError( 'Need to be inside "with distribution_strategy.scope()" for %s' % (distribution_strategy,)) @@ -295,7 +147,8 @@ class _CurrentDistributionContext(object): var_creator_scope, var_scope=None, default_device=None): - self._context = _CrossTowerThreadMode(distribution_strategy) + self._context = distribution_strategy_context._CrossTowerThreadMode( # pylint: disable=protected-access + distribution_strategy) self._var_creator_scope = var_creator_scope self._var_scope = var_scope if default_device: @@ -588,7 +441,7 @@ class DistributionStrategy(object): Returns: A context manager. """ - if has_distribution_strategy(): + if distribution_strategy_context.has_distribution_strategy(): _require_cross_tower_context(self) return _SameScopeAgainContext(self) @@ -740,7 +593,7 @@ class DistributionStrategy(object): In eager mode, returns `None`. In graph mode, a list of ops to execute. Empty list if nothing to be done. """ - if context.executing_eagerly(): + if eager_context.executing_eagerly(): return else: return [] @@ -757,7 +610,7 @@ class DistributionStrategy(object): In eager mode, returns `None`. In graph mode, a list of ops to execute. Empty list if nothing to be done. """ - if context.executing_eagerly(): + if eager_context.executing_eagerly(): return else: return [] @@ -1077,9 +930,37 @@ class DistributionStrategy(object): def _worker_device_index(self): raise NotImplementedError("must be implemented in descendants") - def configure(self, session_config=None): - """Find the best configuration given a tensorflow session config.""" - del session_config + @property + def between_graph(self): + """Whether the strategy uses between-graph replication or not. + + This is expected to return a constant value that will not be changed + throughout its life cycle. + """ + raise NotImplementedError("must be implemented in descendants") + + def configure(self, + session_config=None, + cluster_spec=None, + task_type=None, + task_id=None): + """Configures the strategy class.""" + del session_config, cluster_spec, task_type, task_id + + @property + def should_init(self): + """Whether initialization is needed.""" + raise NotImplementedError("must be implemented in descendants") + + @property + def should_checkpoint(self): + """Whether checkpointing is needed.""" + raise NotImplementedError("must be implemented in descendants") + + @property + def should_save_summary(self): + """Whether saving summaries is needed.""" + raise NotImplementedError("must be implemented in descendants") # A note about the difference between the context managers @@ -1106,7 +987,8 @@ class TowerContext(object): def __init__(self, distribution_strategy, tower_id): self._distribution_strategy = distribution_strategy - self._thread_context = _InTowerThreadMode(self) + self._thread_context = distribution_strategy_context._InTowerThreadMode( # pylint: disable=protected-access + self) self._tower_id = tower_id def __enter__(self): @@ -1149,7 +1031,8 @@ class TowerContext(object): def _merge_call(self, merge_fn, *args, **kwargs): """Default implementation for single tower.""" _push_per_thread_mode( # thread-local, so not needed with multiple threads - _CrossTowerThreadMode(self._distribution_strategy)) + distribution_strategy_context._CrossTowerThreadMode( # pylint: disable=protected-access + self._distribution_strategy)) try: return merge_fn(self._distribution_strategy, *args, **kwargs) finally: @@ -1196,7 +1079,7 @@ class _DefaultDistributionStrategy(DistributionStrategy): def scope(self): """Context manager setting a variable creator and `self` as current.""" - if has_distribution_strategy(): + if distribution_strategy_context.has_distribution_strategy(): raise RuntimeError("Must not nest DistributionStrategy scopes.") def creator(next_creator, *args, **kwargs): @@ -1277,6 +1160,7 @@ class _DefaultDistributionStrategy(DistributionStrategy): raise RuntimeError("worker_device_index() method unsupported by " "_DefaultDistributionStrategy.") + # ------------------------------------------------------------------------------ # Common operations @@ -1292,20 +1176,11 @@ def increment_var(v, amount=1): def merge_fn(dist, vm): return dist.group(dist.update(vm, update)) - tower_context = get_tower_context() + tower_context = distribution_strategy_context.get_tower_context() return tower_context.merge_call(merge_fn, v) # ------------------------------------------------------------------------------ -# Singletons - -_default_distribution_strategy = _DefaultDistributionStrategy() -_default_tower_context = TowerContext( - _default_distribution_strategy, tower_id=0) -_default_tower_mode = _DefaultTowerThreadMode() - - -# ------------------------------------------------------------------------------ # We haven't yet implemented deserialization for DistributedVariables. # So here we catch any attempts to deserialize variables # when using distribution strategies. @@ -1314,7 +1189,7 @@ _original_from_proto = resource_variable_ops._from_proto_fn def _from_proto_fn(v, import_scope=None): - if has_distribution_strategy(): + if distribution_strategy_context.has_distribution_strategy(): raise NotImplementedError( "Deserialization of variables is not yet supported when using" "distributed strategies.") @@ -1323,3 +1198,10 @@ def _from_proto_fn(v, import_scope=None): resource_variable_ops._from_proto_fn = _from_proto_fn # pylint: enable=protected-access + + +#------------------------------------------------------------------------------- +# Shorthand for some methods from distribution_strategy_context. +_push_per_thread_mode = distribution_strategy_context._push_per_thread_mode # pylint: disable=protected-access +_get_per_thread_mode = distribution_strategy_context._get_per_thread_mode # pylint: disable=protected-access +_pop_per_thread_mode = distribution_strategy_context._pop_per_thread_mode # pylint: disable=protected-access |