aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/training/distribute.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/training/distribute.py')
-rw-r--r--tensorflow/python/training/distribute.py232
1 files changed, 57 insertions, 175 deletions
diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py
index 581db45e80..28c60ad809 100644
--- a/tensorflow/python/training/distribute.py
+++ b/tensorflow/python/training/distribute.py
@@ -21,7 +21,7 @@ from __future__ import print_function
import threading
from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.eager import context
+from tensorflow.python.eager import context as eager_context
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
@@ -31,71 +31,11 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.losses import losses_impl
from tensorflow.python.platform import tf_logging
from tensorflow.python.training import device_util
+from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.util import nest
# ------------------------------------------------------------------------------
-# Internal API for setting the current thread mode as being either in a
-# tower or cross-tower context for a particular distribution strategy.
-
-
-class _ThreadMode(object):
-
- def __init__(self, dist, cross, tower):
- self.distribution_strategy = dist
- self.cross_tower_context = cross
- self.tower_context = tower
-
-
-class _CrossTowerThreadMode(_ThreadMode):
-
- def __init__(self, distribution_strategy):
- _ThreadMode.__init__(
- self, distribution_strategy, distribution_strategy, None)
-
-
-class _InTowerThreadMode(_ThreadMode):
-
- def __init__(self, tower_ctx):
- _ThreadMode.__init__(
- self, tower_ctx.distribution_strategy, None, tower_ctx)
-
-
-_per_thread_mode = threading.local()
-
-
-def _push_per_thread_mode(context):
- if not hasattr(_per_thread_mode, "stack"):
- _per_thread_mode.stack = []
- _per_thread_mode.stack.append(context)
-
-
-def _pop_per_thread_mode():
- _per_thread_mode.stack.pop(-1)
-
-
-class _DefaultTowerThreadMode(_ThreadMode):
- """Type of default value returned by `_get_per_thread_mode()`.
-
- Used when the thread-local stack is empty.
- """
-
- def __init__(self):
- # _default_distribution_strategy and _default_tower_context are
- # defined at the bottom of this file.
- _ThreadMode.__init__(
- self, _default_distribution_strategy, None, _default_tower_context)
-
-
-def _get_per_thread_mode():
- try:
- return _per_thread_mode.stack[-1]
- except (AttributeError, IndexError):
- # _default_tower_mode is defined at the bottom of this file.
- return _default_tower_mode
-
-
-# ------------------------------------------------------------------------------
# Context tracking whether in a distribution.update() or .update_non_slot()
# call.
@@ -128,96 +68,6 @@ class UpdateContext(object):
# ------------------------------------------------------------------------------
-# Public API for accessing the current thread mode
-
-
-def get_tower_context():
- """Returns the current TowerContext or None if in a cross-tower context.
-
- Note that execution:
- 1. starts in the default (single-tower) tower context (this function
- will return the default TowerContext object);
- 2. switches to cross-tower context (in which case this will return
- None) when entering a `with DistributionStrategy.scope():` block;
- 3. switches to a (non-default) tower context inside
- `call_for_each_tower(fn, ...)`;
- 4. if `fn` calls `get_tower_context()->merge_call(merge_fn, ...)`, then
- inside `merge_fn` you are back in the cross-tower context (and again
- this function will return None).
-
- Note that you can also go directly from step 1 to 4 to switch to a
- cross-tower context for the default `DistributionStrategy`. You may
- also switch from the cross-tower context of 4 to a tower context by
- calling `call_for_each_tower()`, jumping back to step 3.
-
- Most `DistributionStrategy` methods may only be executed in
- a cross-tower context, in a tower context you should use the
- `TowerContext` API instead.
-
- Returns:
- The current `TowerContext` object when in a tower context scope, else None.
-
- Exactly one of `get_tower_context()` and `get_cross_tower_context()`
- will return None in a particular block.
- """
- return _get_per_thread_mode().tower_context
-
-
-def get_cross_tower_context():
- """Returns the current DistributionStrategy if in a cross-tower context.
-
- Note that execution:
- 1. starts in the default (single-tower) tower context;
- 2. switches to cross-tower context when entering a
- `with DistributionStrategy.scope():` block;
- 3. switches to a (non-default) tower context inside
- `call_for_each_tower(fn, ...)`;
- 4. if `fn` calls `get_tower_context()->merge_call(merge_fn, ...)`, then
- inside `merge_fn` you are back in the cross-tower context.
-
- Note that you can also go directly from step 1 to 4 to switch to a
- cross-tower context for the default `DistributionStrategy`. You may
- also switch from the cross-tower context of 4 to a tower context by
- calling `call_for_each_tower()`, jumping back to step 3.
-
- Most `DistributionStrategy` methods may only be executed in
- a cross-tower context.
-
- Returns:
- Returns the current `DistributionStrategy` object in a cross-tower
- context, or None.
-
- Exactly one of `get_tower_context()` and `get_cross_tower_context()`
- will return None in a particular block.
- """
- return _get_per_thread_mode().cross_tower_context
-
-
-def get_distribution_strategy():
- """Returns the current `DistributionStrategy` object.
-
- Prefer to use `get_tower_context()` or `get_cross_tower_context()`
- instead when possible.
-
- Returns:
- A `DistributionStrategy` object. Inside a
- `with distribution_strategy.scope()` block, it returns
- `distribution_strategy`, otherwise it returns the default
- (single-tower) `DistributionStrategy` object.
- """
- return _get_per_thread_mode().distribution_strategy
-
-
-def has_distribution_strategy():
- """Return if there is a current non-default `DistributionStrategy`.
-
- Returns:
- True if inside a `with distribution_strategy.scope():`.
- """
- return get_distribution_strategy() is not _default_distribution_strategy
-
-
-# ------------------------------------------------------------------------------
# Public utility functions.
@@ -239,7 +89,8 @@ def _require_cross_tower_context(distribution_strategy):
if context.cross_tower_context is distribution_strategy: return
# We have an error to report, figure out the right message.
if context.distribution_strategy is not distribution_strategy:
- if context.distribution_strategy is _default_distribution_strategy:
+ if (context.distribution_strategy is
+ distribution_strategy_context._get_default_distribution_strategy()): # pylint: disable=protected-access
raise RuntimeError(
'Need to be inside "with distribution_strategy.scope()" for %s' %
(distribution_strategy,))
@@ -272,7 +123,8 @@ def _require_distribution_strategy_scope(distribution_strategy):
context = _get_per_thread_mode()
if context.distribution_strategy is distribution_strategy: return
# We have an error to report, figure out the right message.
- if context.distribution_strategy is _default_distribution_strategy:
+ if (context.distribution_strategy is
+ distribution_strategy_context._get_default_distribution_strategy()): # pylint: disable=protected-access
raise RuntimeError(
'Need to be inside "with distribution_strategy.scope()" for %s' %
(distribution_strategy,))
@@ -295,7 +147,8 @@ class _CurrentDistributionContext(object):
var_creator_scope,
var_scope=None,
default_device=None):
- self._context = _CrossTowerThreadMode(distribution_strategy)
+ self._context = distribution_strategy_context._CrossTowerThreadMode( # pylint: disable=protected-access
+ distribution_strategy)
self._var_creator_scope = var_creator_scope
self._var_scope = var_scope
if default_device:
@@ -588,7 +441,7 @@ class DistributionStrategy(object):
Returns:
A context manager.
"""
- if has_distribution_strategy():
+ if distribution_strategy_context.has_distribution_strategy():
_require_cross_tower_context(self)
return _SameScopeAgainContext(self)
@@ -740,7 +593,7 @@ class DistributionStrategy(object):
In eager mode, returns `None`.
In graph mode, a list of ops to execute. Empty list if nothing to be done.
"""
- if context.executing_eagerly():
+ if eager_context.executing_eagerly():
return
else:
return []
@@ -757,7 +610,7 @@ class DistributionStrategy(object):
In eager mode, returns `None`.
In graph mode, a list of ops to execute. Empty list if nothing to be done.
"""
- if context.executing_eagerly():
+ if eager_context.executing_eagerly():
return
else:
return []
@@ -1077,9 +930,37 @@ class DistributionStrategy(object):
def _worker_device_index(self):
raise NotImplementedError("must be implemented in descendants")
- def configure(self, session_config=None):
- """Find the best configuration given a tensorflow session config."""
- del session_config
+ @property
+ def between_graph(self):
+ """Whether the strategy uses between-graph replication or not.
+
+ This is expected to return a constant value that will not be changed
+ throughout its life cycle.
+ """
+ raise NotImplementedError("must be implemented in descendants")
+
+ def configure(self,
+ session_config=None,
+ cluster_spec=None,
+ task_type=None,
+ task_id=None):
+ """Configures the strategy class."""
+ del session_config, cluster_spec, task_type, task_id
+
+ @property
+ def should_init(self):
+ """Whether initialization is needed."""
+ raise NotImplementedError("must be implemented in descendants")
+
+ @property
+ def should_checkpoint(self):
+ """Whether checkpointing is needed."""
+ raise NotImplementedError("must be implemented in descendants")
+
+ @property
+ def should_save_summary(self):
+ """Whether saving summaries is needed."""
+ raise NotImplementedError("must be implemented in descendants")
# A note about the difference between the context managers
@@ -1106,7 +987,8 @@ class TowerContext(object):
def __init__(self, distribution_strategy, tower_id):
self._distribution_strategy = distribution_strategy
- self._thread_context = _InTowerThreadMode(self)
+ self._thread_context = distribution_strategy_context._InTowerThreadMode( # pylint: disable=protected-access
+ self)
self._tower_id = tower_id
def __enter__(self):
@@ -1149,7 +1031,8 @@ class TowerContext(object):
def _merge_call(self, merge_fn, *args, **kwargs):
"""Default implementation for single tower."""
_push_per_thread_mode( # thread-local, so not needed with multiple threads
- _CrossTowerThreadMode(self._distribution_strategy))
+ distribution_strategy_context._CrossTowerThreadMode( # pylint: disable=protected-access
+ self._distribution_strategy))
try:
return merge_fn(self._distribution_strategy, *args, **kwargs)
finally:
@@ -1196,7 +1079,7 @@ class _DefaultDistributionStrategy(DistributionStrategy):
def scope(self):
"""Context manager setting a variable creator and `self` as current."""
- if has_distribution_strategy():
+ if distribution_strategy_context.has_distribution_strategy():
raise RuntimeError("Must not nest DistributionStrategy scopes.")
def creator(next_creator, *args, **kwargs):
@@ -1277,6 +1160,7 @@ class _DefaultDistributionStrategy(DistributionStrategy):
raise RuntimeError("worker_device_index() method unsupported by "
"_DefaultDistributionStrategy.")
+
# ------------------------------------------------------------------------------
# Common operations
@@ -1292,20 +1176,11 @@ def increment_var(v, amount=1):
def merge_fn(dist, vm):
return dist.group(dist.update(vm, update))
- tower_context = get_tower_context()
+ tower_context = distribution_strategy_context.get_tower_context()
return tower_context.merge_call(merge_fn, v)
# ------------------------------------------------------------------------------
-# Singletons
-
-_default_distribution_strategy = _DefaultDistributionStrategy()
-_default_tower_context = TowerContext(
- _default_distribution_strategy, tower_id=0)
-_default_tower_mode = _DefaultTowerThreadMode()
-
-
-# ------------------------------------------------------------------------------
# We haven't yet implemented deserialization for DistributedVariables.
# So here we catch any attempts to deserialize variables
# when using distribution strategies.
@@ -1314,7 +1189,7 @@ _original_from_proto = resource_variable_ops._from_proto_fn
def _from_proto_fn(v, import_scope=None):
- if has_distribution_strategy():
+ if distribution_strategy_context.has_distribution_strategy():
raise NotImplementedError(
"Deserialization of variables is not yet supported when using"
"distributed strategies.")
@@ -1323,3 +1198,10 @@ def _from_proto_fn(v, import_scope=None):
resource_variable_ops._from_proto_fn = _from_proto_fn
# pylint: enable=protected-access
+
+
+#-------------------------------------------------------------------------------
+# Shorthand for some methods from distribution_strategy_context.
+_push_per_thread_mode = distribution_strategy_context._push_per_thread_mode # pylint: disable=protected-access
+_get_per_thread_mode = distribution_strategy_context._get_per_thread_mode # pylint: disable=protected-access
+_pop_per_thread_mode = distribution_strategy_context._pop_per_thread_mode # pylint: disable=protected-access