diff options
Diffstat (limited to 'tensorflow/python/training/distribution_strategy_context.py')
-rw-r--r-- | tensorflow/python/training/distribution_strategy_context.py | 203 |
1 files changed, 203 insertions, 0 deletions
diff --git a/tensorflow/python/training/distribution_strategy_context.py b/tensorflow/python/training/distribution_strategy_context.py new file mode 100644 index 0000000000..998b5c35ce --- /dev/null +++ b/tensorflow/python/training/distribution_strategy_context.py @@ -0,0 +1,203 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utility to get distribution strategy related contexts.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.util.lazy_loader import LazyLoader + + +# There is a circular dependency between this and `distribute` module. So we +# load it lazily to workaround this. +distribute_lib = LazyLoader( + "distribute_lib", globals(), + "tensorflow.python.training.distribute") + +# ------------------------------------------------------------------------------ +# Internal API for setting the current thread mode as being either in a +# tower or cross-tower context for a particular distribution strategy. + + +class _ThreadMode(object): + + def __init__(self, dist, cross, tower): + self.distribution_strategy = dist + self.cross_tower_context = cross + self.tower_context = tower + + +class _CrossTowerThreadMode(_ThreadMode): + + def __init__(self, distribution_strategy): + _ThreadMode.__init__( + self, distribution_strategy, distribution_strategy, None) + + +class _InTowerThreadMode(_ThreadMode): + + def __init__(self, tower_ctx): + _ThreadMode.__init__( + self, tower_ctx.distribution_strategy, None, tower_ctx) + + +def _push_per_thread_mode(context): + ops.get_default_graph()._distribution_strategy_stack.append(context) # pylint: disable=protected-access + + +def _pop_per_thread_mode(): + ops.get_default_graph()._distribution_strategy_stack.pop(-1) # pylint: disable=protected-access + + +class _DefaultTowerThreadMode(_ThreadMode): + """Type of default value returned by `_get_per_thread_mode()`. + + Used when the thread-local stack is empty. + """ + + def __init__(self): + _ThreadMode.__init__(self, _get_default_distribution_strategy(), None, + _get_default_tower_context()) + + +def _get_per_thread_mode(): + try: + return ops.get_default_graph()._distribution_strategy_stack[-1] # pylint: disable=protected-access + except (AttributeError, IndexError): + return _get_default_tower_mode() + + +# ------------------------------------------------------------------------------ +# Public API for accessing the current thread mode + + +def get_tower_context(): + """Returns the current TowerContext or None if in a cross-tower context. + + Note that execution: + 1. starts in the default (single-tower) tower context (this function + will return the default TowerContext object); + 2. switches to cross-tower context (in which case this will return + None) when entering a `with DistributionStrategy.scope():` block; + 3. switches to a (non-default) tower context inside + `call_for_each_tower(fn, ...)`; + 4. if `fn` calls `get_tower_context()->merge_call(merge_fn, ...)`, then + inside `merge_fn` you are back in the cross-tower context (and again + this function will return None). + + Note that you can also go directly from step 1 to 4 to switch to a + cross-tower context for the default `DistributionStrategy`. You may + also switch from the cross-tower context of 4 to a tower context by + calling `call_for_each_tower()`, jumping back to step 3. + + Most `DistributionStrategy` methods may only be executed in + a cross-tower context, in a tower context you should use the + `TowerContext` API instead. + + Returns: + The current `TowerContext` object when in a tower context scope, else None. + + Exactly one of `get_tower_context()` and `get_cross_tower_context()` + will return None in a particular block. + """ + return _get_per_thread_mode().tower_context + + +def get_cross_tower_context(): + """Returns the current DistributionStrategy if in a cross-tower context. + + Note that execution: + 1. starts in the default (single-tower) tower context; + 2. switches to cross-tower context when entering a + `with DistributionStrategy.scope():` block; + 3. switches to a (non-default) tower context inside + `call_for_each_tower(fn, ...)`; + 4. if `fn` calls `get_tower_context()->merge_call(merge_fn, ...)`, then + inside `merge_fn` you are back in the cross-tower context. + + Note that you can also go directly from step 1 to 4 to switch to a + cross-tower context for the default `DistributionStrategy`. You may + also switch from the cross-tower context of 4 to a tower context by + calling `call_for_each_tower()`, jumping back to step 3. + + Most `DistributionStrategy` methods may only be executed in + a cross-tower context. + + Returns: + Returns the current `DistributionStrategy` object in a cross-tower + context, or None. + + Exactly one of `get_tower_context()` and `get_cross_tower_context()` + will return None in a particular block. + """ + return _get_per_thread_mode().cross_tower_context + + +def get_distribution_strategy(): + """Returns the current `DistributionStrategy` object. + + Prefer to use `get_tower_context()` or `get_cross_tower_context()` + instead when possible. + + Returns: + A `DistributionStrategy` object. Inside a + `with distribution_strategy.scope()` block, it returns + `distribution_strategy`, otherwise it returns the default + (single-tower) `DistributionStrategy` object. + """ + return _get_per_thread_mode().distribution_strategy + + +def has_distribution_strategy(): + """Return if there is a current non-default `DistributionStrategy`. + + Returns: + True if inside a `with distribution_strategy.scope():`. + """ + return get_distribution_strategy() is not _get_default_distribution_strategy() + + +# ------------------------------------------------------------------------------ +# Defaults that are used when no distribution strategy is explicitly created. +# We create them lazily in a function so that we can workaround the circular +# dependency on distribute_lib. See lazy loader at the top of this file. + +_defaults = { + "distribution_strategy": None, + "tower_context": None, + "tower_mode": None +} + + +def _get_default_distribution_strategy(): + if _defaults["distribution_strategy"] is None: + _defaults["distribution_strategy"] = ( + distribute_lib._DefaultDistributionStrategy()) # pylint: disable=protected-access + return _defaults["distribution_strategy"] + + +def _get_default_tower_context(): + if _defaults["tower_context"] is None: + _defaults["tower_context"] = distribute_lib.TowerContext( + _get_default_distribution_strategy(), tower_id=0) + return _defaults["tower_context"] + + +def _get_default_tower_mode(): + if _defaults["tower_mode"] is None: + _defaults["tower_mode"] = _DefaultTowerThreadMode() + return _defaults["tower_mode"] |