diff options
author | Yuefeng Zhou <yuefengz@google.com> | 2018-09-21 11:12:21 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-21 11:17:52 -0700 |
commit | 61a9623ac31fd363aff8537df6c3b6073d721425 (patch) | |
tree | 266d3809cef361e1bc53599bb41af793d0952b8a /tensorflow/python/distribute | |
parent | fb1335fd7b6bb753080bbfedec4d70aeacc4218a (diff) |
In standalone client mode, only run hooks on one thread.
PiperOrigin-RevId: 214013965
Diffstat (limited to 'tensorflow/python/distribute')
-rw-r--r-- | tensorflow/python/distribute/estimator_training.py | 21 |
1 files changed, 18 insertions, 3 deletions
diff --git a/tensorflow/python/distribute/estimator_training.py b/tensorflow/python/distribute/estimator_training.py index e17a598123..8daa34c885 100644 --- a/tensorflow/python/distribute/estimator_training.py +++ b/tensorflow/python/distribute/estimator_training.py @@ -182,6 +182,7 @@ def should_run_distribute_coordinator(config): # pylint: disable=protected-access if (not hasattr(config, '_distribute_coordinator_mode') or config._distribute_coordinator_mode is None): + logging.info('Not using Distribute Coordinator.') return False if (not isinstance(config._distribute_coordinator_mode, six.string_types) or config._distribute_coordinator_mode not in [ @@ -221,15 +222,28 @@ def train_and_evaluate(estimator, train_spec, eval_spec, executor_cls): local_estimator = copy.deepcopy(estimator) # pylint: disable=protected-access local_estimator._config._train_distribute = strategy - _init_run_config_from_worker_context( - local_estimator._config, dc_context.get_current_worker_context()) + context = dc_context.get_current_worker_context() + _init_run_config_from_worker_context(local_estimator._config, context) + logging.info('Updated config: %s', str(vars(local_estimator._config))) local_estimator._train_distribution = strategy # pylint: enable=protected-access + # In the standalone client, we don't need to run hooks on all threads + # because logging hooks on all threads may be too much on the screen; also + # tensor passed to one hook can only be fetched with the graph where the + # tensor is defined. Other hooks such as checkpointing hooks will be added by + # MonitoredTrainingSession. + # TODO(yuefengz): Is there a hook that does need to run on all threads in + # standalone client mode? 
+ if (run_config._distribute_coordinator_mode == # pylint: disable=protected-access + dc.CoordinatorMode.INDEPENDENT_WORKER or context.is_chief): + hooks = list(train_spec.hooks) + else: + hooks = [] local_estimator.train( input_fn=train_spec.input_fn, max_steps=train_spec.max_steps, - hooks=list(train_spec.hooks)) + hooks=hooks) def _eval_fn(strategy): """Function for evaluator task.""" @@ -238,6 +252,7 @@ def train_and_evaluate(estimator, train_spec, eval_spec, executor_cls): local_estimator._config._eval_distribute = strategy _init_run_config_from_worker_context( local_estimator._config, dc_context.get_current_worker_context()) + logging.info('Updated config: %s', str(vars(local_estimator._config))) local_estimator._eval_distribution = strategy executor = executor_cls(local_estimator, train_spec, eval_spec) |