author    Yuefeng Zhou <yuefengz@google.com>    2018-09-21 11:12:21 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-09-21 11:17:52 -0700
commit    61a9623ac31fd363aff8537df6c3b6073d721425 (patch)
tree      266d3809cef361e1bc53599bb41af793d0952b8a /tensorflow/python/distribute
parent    fb1335fd7b6bb753080bbfedec4d70aeacc4218a (diff)
In standalone client mode, only run hooks on one thread.
PiperOrigin-RevId: 214013965
Diffstat (limited to 'tensorflow/python/distribute')
-rw-r--r--  tensorflow/python/distribute/estimator_training.py  |  21
1 file changed, 18 insertions(+), 3 deletions(-)
diff --git a/tensorflow/python/distribute/estimator_training.py b/tensorflow/python/distribute/estimator_training.py
index e17a598123..8daa34c885 100644
--- a/tensorflow/python/distribute/estimator_training.py
+++ b/tensorflow/python/distribute/estimator_training.py
@@ -182,6 +182,7 @@ def should_run_distribute_coordinator(config):
# pylint: disable=protected-access
if (not hasattr(config, '_distribute_coordinator_mode') or
config._distribute_coordinator_mode is None):
+ logging.info('Not using Distribute Coordinator.')
return False
if (not isinstance(config._distribute_coordinator_mode, six.string_types) or
config._distribute_coordinator_mode not in [
@@ -221,15 +222,28 @@ def train_and_evaluate(estimator, train_spec, eval_spec, executor_cls):
local_estimator = copy.deepcopy(estimator)
# pylint: disable=protected-access
local_estimator._config._train_distribute = strategy
- _init_run_config_from_worker_context(
- local_estimator._config, dc_context.get_current_worker_context())
+ context = dc_context.get_current_worker_context()
+ _init_run_config_from_worker_context(local_estimator._config, context)
+ logging.info('Updated config: %s', str(vars(local_estimator._config)))
local_estimator._train_distribution = strategy
# pylint: enable=protected-access
+ # In standalone client mode, we don't need to run hooks on all threads
+ # because running logging hooks on every thread may produce too much output;
+ # also, a tensor passed to a hook can only be fetched from the graph in which
+ # that tensor is defined. Other hooks, such as checkpointing hooks, will be
+ # added by MonitoredTrainingSession.
+ # TODO(yuefengz): Is there a hook that does need to run on all threads in
+ # standalone client mode?
+ if (run_config._distribute_coordinator_mode == # pylint: disable=protected-access
+ dc.CoordinatorMode.INDEPENDENT_WORKER or context.is_chief):
+ hooks = list(train_spec.hooks)
+ else:
+ hooks = []
local_estimator.train(
input_fn=train_spec.input_fn,
max_steps=train_spec.max_steps,
- hooks=list(train_spec.hooks))
+ hooks=hooks)
def _eval_fn(strategy):
"""Function for evaluator task."""
@@ -238,6 +252,7 @@ def train_and_evaluate(estimator, train_spec, eval_spec, executor_cls):
local_estimator._config._eval_distribute = strategy
_init_run_config_from_worker_context(
local_estimator._config, dc_context.get_current_worker_context())
+ logging.info('Updated config: %s', str(vars(local_estimator._config)))
local_estimator._eval_distribution = strategy
executor = executor_cls(local_estimator, train_spec, eval_spec)
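
The core of this change is the gating condition on train_spec.hooks: hooks are kept only in independent-worker mode or on the chief thread in standalone client mode. Below is a minimal, self-contained sketch of that decision in isolation; CoordinatorMode, FakeRunConfig, FakeWorkerContext, and select_hooks are hypothetical stand-ins for the real distribute_coordinator objects used in the patch.

# Sketch of the hook-gating logic introduced by this commit (hypothetical
# stand-in classes; the real objects come from distribute_coordinator and
# RunConfig).


class CoordinatorMode(object):
  STANDALONE_CLIENT = 'standalone_client'
  INDEPENDENT_WORKER = 'independent_worker'


class FakeRunConfig(object):

  def __init__(self, mode):
    self._distribute_coordinator_mode = mode


class FakeWorkerContext(object):

  def __init__(self, is_chief):
    self.is_chief = is_chief


def select_hooks(run_config, context, train_hooks):
  """Returns the hooks to pass to Estimator.train on this thread.

  In independent-worker mode each worker is its own process, so every worker
  keeps its hooks. In standalone client mode all workers run as threads inside
  one client process, so only the chief thread keeps the hooks; other threads
  get an empty list to avoid duplicated logging and cross-graph tensor fetches.
  """
  if (run_config._distribute_coordinator_mode ==  # pylint: disable=protected-access
      CoordinatorMode.INDEPENDENT_WORKER or context.is_chief):
    return list(train_hooks)
  return []


# Example: in standalone client mode, only the chief thread keeps the hooks.
hooks = ['logging_hook']
config = FakeRunConfig(CoordinatorMode.STANDALONE_CLIENT)
assert select_hooks(config, FakeWorkerContext(is_chief=True), hooks) == hooks
assert select_hooks(config, FakeWorkerContext(is_chief=False), hooks) == []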