diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-28 14:10:19 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-28 14:36:51 -0700 |
commit | 17d73444f332490c733d37063710e72dc69d1141 (patch) | |
tree | 88de51eb19b8ada823aa833d84039845820ca15f | |
parent | f83da5b0aa37ba55c1b2eaa093e6d043b73f5982 (diff) |
Update hooks for distributed jobs with a master node, to ensure that
summaries are written at the correct interval for jobs with long-running
evaluations.
PiperOrigin-RevId: 214993119
-rw-r--r-- | tensorflow/python/estimator/estimator.py | 34 | ||||
-rw-r--r-- | tensorflow/python/estimator/estimator_test.py | 94 |
2 files changed, 126 insertions, 2 deletions
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index b933cedb99..34faf03bb0 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -1414,6 +1414,36 @@ class Estimator(object): # It is expected to have one CheckpointSaverHook. If multiple, we pick # up the first one to add listener. saver_hooks[0]._listeners.extend(saving_listeners) # pylint: disable=protected-access + + # Add summary hooks to worker 0 if we are running with a master, to ensure + # that summaries are written at correct intervals even with long-running + # evaluations. + save_summary_steps = self._config.save_summary_steps + log_step_count_steps = self._config.log_step_count_steps + if (self._config.cluster_spec and self._config.cluster_spec.jobs and + (run_config.TaskType.MASTER in self._config.cluster_spec.jobs)): + # Update config values to prevent the default hooks from being created on + # the master or other workers. + save_summary_steps = 0 + log_step_count_steps = None + + if (self._config.task_type == run_config.TaskType.WORKER and + self._config.task_id == 0): + if (self._config.save_summary_steps and + self._config.save_summary_steps > 0): + worker_hooks.append( + training.SummarySaverHook( + save_steps=self._config.save_summary_steps, + output_dir=self._config.model_dir, + scaffold=estimator_spec.scaffold)) + + if (self._config.log_step_count_steps and + self._config.log_step_count_steps > 0): + worker_hooks.append( + training.StepCounterHook( + every_n_steps=self._config.log_step_count_steps, + output_dir=self._config.model_dir)) + with training.MonitoredTrainingSession( master=self._config.master, is_chief=self._config.is_chief, @@ -1423,9 +1453,9 @@ class Estimator(object): chief_only_hooks=( tuple(chief_hooks) + tuple(estimator_spec.training_chief_hooks)), save_checkpoint_secs=0, # Saving is handled by a hook. - save_summaries_steps=self._config.save_summary_steps, + save_summaries_steps=save_summary_steps, config=self._session_config, - log_step_count_steps=self._config.log_step_count_steps) as mon_sess: + log_step_count_steps=log_step_count_steps) as mon_sess: loss = None while not mon_sess.should_stop(): _, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss]) diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py index bc2504ca19..246dfb1a4b 100644 --- a/tensorflow/python/estimator/estimator_test.py +++ b/tensorflow/python/estimator/estimator_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import functools import glob +import json import os import tempfile @@ -969,6 +970,99 @@ class EstimatorTrainTest(test.TestCase): with self.assertRaisesRegexp(ValueError, 'train_and_evaluate'): est.train(dummy_input_fn, steps=1) + def test_master_distributed_hooks(self): + tf_config = json.dumps({ + 'cluster': { + run_config.TaskType.PS: ['localhost:1234'], + run_config.TaskType.WORKER: ['localhost:1235'], + run_config.TaskType.MASTER: ['localhost:1236'] + }, + 'task': { + 'type': run_config.TaskType.MASTER, + 'index': 0 + } + }) + with test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}): + est = estimator.Estimator( + model_fn=model_fn_global_step_incrementer, + config=run_config.RunConfig()) + + with test.mock.patch.object(training, + 'MonitoredTrainingSession') as mock_sess: + est.train(dummy_input_fn, steps=1) + self.assertFalse( + any( + isinstance(hook, basic_session_run_hooks.SummarySaverHook) + for hook in mock_sess.call_args[1]['hooks'])) + self.assertFalse( + any( + isinstance(hook, basic_session_run_hooks.StepCounterHook) + for hook in mock_sess.call_args[1]['hooks'])) + self.assertEqual(0, mock_sess.call_args[1]['save_summaries_steps']) + self.assertIsNone(mock_sess.call_args[1]['log_step_count_steps']) + + def test_master_distributed_hooks_for_worker_0(self): + tf_config = json.dumps({ + 'cluster': { + run_config.TaskType.PS: ['localhost:1234'], + run_config.TaskType.WORKER: ['localhost:1235'], + run_config.TaskType.MASTER: ['localhost:1236'] + }, + 'task': { + 'type': run_config.TaskType.WORKER, + 'index': 0 + } + }) + with test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}): + est = estimator.Estimator( + model_fn=model_fn_global_step_incrementer, + config=run_config.RunConfig()) + + with test.mock.patch.object(training, + 'MonitoredTrainingSession') as mock_sess: + est.train(dummy_input_fn, steps=1) + self.assertTrue( + any( + isinstance(hook, basic_session_run_hooks.SummarySaverHook) + for hook in mock_sess.call_args[1]['hooks'])) + self.assertTrue( + any( + isinstance(hook, basic_session_run_hooks.StepCounterHook) + for hook in mock_sess.call_args[1]['hooks'])) + self.assertEqual(0, mock_sess.call_args[1]['save_summaries_steps']) + self.assertIsNone(mock_sess.call_args[1]['log_step_count_steps']) + + def test_master_distributed_hooks_for_worker_nonzero(self): + tf_config = json.dumps({ + 'cluster': { + run_config.TaskType.PS: ['localhost:1234'], + run_config.TaskType.WORKER: ['localhost:1235', 'localhost:1237'], + run_config.TaskType.MASTER: ['localhost:1236'] + }, + 'task': { + 'type': run_config.TaskType.WORKER, + 'index': 1 + } + }) + with test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}): + est = estimator.Estimator( + model_fn=model_fn_global_step_incrementer, + config=run_config.RunConfig()) + + with test.mock.patch.object(training, + 'MonitoredTrainingSession') as mock_sess: + est.train(dummy_input_fn, steps=1) + self.assertFalse( + any( + isinstance(hook, basic_session_run_hooks.SummarySaverHook) + for hook in mock_sess.call_args[1]['hooks'])) + self.assertFalse( + any( + isinstance(hook, basic_session_run_hooks.StepCounterHook) + for hook in mock_sess.call_args[1]['hooks'])) + self.assertEqual(0, mock_sess.call_args[1]['save_summaries_steps']) + self.assertIsNone(mock_sess.call_args[1]['log_step_count_steps']) + def _model_fn_with_eval_metric_ops(features, labels, mode, params): _, _ = features, labels |