diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-26 15:19:20 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-26 15:28:19 -0700 |
commit | 844074c2a8e61b744c3de2718e1c9ea7b1d2edc2 (patch) | |
tree | b17a6ee138ec1d9aa69dfd9f06731b4b3871f1db /tensorflow/python/estimator | |
parent | ee9c6c17abce8450d08140750b857ad36b0508e8 (diff) |
Update hooks for distributed jobs with a master node, to ensure that
summaries are written at the correct interval for jobs with long-running
evaluations.
PiperOrigin-RevId: 214678483
Diffstat (limited to 'tensorflow/python/estimator')
-rw-r--r-- | tensorflow/python/estimator/estimator.py | 33 | ||||
-rw-r--r-- | tensorflow/python/estimator/estimator_test.py | 94 |
2 files changed, 125 insertions, 2 deletions
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index eec64ad452..fd62a79c84 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -1394,6 +1394,35 @@ class Estimator(object): # It is expected to have one CheckpointSaverHook. If multiple, we pick # up the first one to add listener. saver_hooks[0]._listeners.extend(saving_listeners) # pylint: disable=protected-access + + # Add summary hooks to worker 0 if we are running with a master, to ensure + # that summaries are written at correct intervals even with long-running + # evaluations. + save_summary_steps = self._config.save_summary_steps + log_step_count_steps = self._config.log_step_count_steps + if run_config.TaskType.MASTER in self._config.cluster_spec.jobs: + # Update config values to prevent the default hooks from being created on + # the master or other workers. + save_summary_steps = 0 + log_step_count_steps = None + + if (self._config.task_type == run_config.TaskType.WORKER and + self._config.task_id == 0): + if (self._config.save_summary_steps and + self._config.save_summary_steps > 0): + worker_hooks.append( + training.SummarySaverHook( + save_steps=self._config.save_summary_steps, + output_dir=self._config.model_dir, + scaffold=estimator_spec.scaffold)) + + if (self._config.log_step_count_steps and + self._config.log_step_count_steps > 0): + worker_hooks.append( + training.StepCounterHook( + every_n_steps=self._config.log_step_count_steps, + output_dir=self._config.model_dir)) + with training.MonitoredTrainingSession( master=self._config.master, is_chief=self._config.is_chief, @@ -1403,9 +1432,9 @@ class Estimator(object): chief_only_hooks=( tuple(chief_hooks) + tuple(estimator_spec.training_chief_hooks)), save_checkpoint_secs=0, # Saving is handled by a hook. - save_summaries_steps=self._config.save_summary_steps, + save_summaries_steps=save_summary_steps, config=self._session_config, - log_step_count_steps=self._config.log_step_count_steps) as mon_sess: + log_step_count_steps=log_step_count_steps) as mon_sess: loss = None while not mon_sess.should_stop(): _, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss]) diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py index 1ed5e30b0e..5962086aad 100644 --- a/tensorflow/python/estimator/estimator_test.py +++ b/tensorflow/python/estimator/estimator_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import functools import glob +import json import os import tempfile @@ -969,6 +970,99 @@ class EstimatorTrainTest(test.TestCase): with self.assertRaisesRegexp(ValueError, 'train_and_evaluate'): est.train(dummy_input_fn, steps=1) + def test_master_distributed_hooks(self): + tf_config = json.dumps({ + 'cluster': { + run_config.TaskType.PS: ['localhost:1234'], + run_config.TaskType.WORKER: ['localhost:1235'], + run_config.TaskType.MASTER: ['localhost:1236'] + }, + 'task': { + 'type': run_config.TaskType.MASTER, + 'index': 0 + } + }) + with test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}): + est = estimator.Estimator( + model_fn=model_fn_global_step_incrementer, + config=run_config.RunConfig()) + + with test.mock.patch.object(training, + 'MonitoredTrainingSession') as mock_sess: + est.train(dummy_input_fn, steps=1) + self.assertFalse( + any( + isinstance(hook, basic_session_run_hooks.SummarySaverHook) + for hook in mock_sess.call_args[1]['hooks'])) + self.assertFalse( + any( + isinstance(hook, basic_session_run_hooks.StepCounterHook) + for hook in mock_sess.call_args[1]['hooks'])) + self.assertEqual(0, mock_sess.call_args[1]['save_summaries_steps']) + self.assertIsNone(mock_sess.call_args[1]['log_step_count_steps']) + + def test_master_distributed_hooks_for_worker_0(self): + tf_config = json.dumps({ + 'cluster': { + run_config.TaskType.PS: ['localhost:1234'], + run_config.TaskType.WORKER: ['localhost:1235'], + run_config.TaskType.MASTER: ['localhost:1236'] + }, + 'task': { + 'type': run_config.TaskType.WORKER, + 'index': 0 + } + }) + with test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}): + est = estimator.Estimator( + model_fn=model_fn_global_step_incrementer, + config=run_config.RunConfig()) + + with test.mock.patch.object(training, + 'MonitoredTrainingSession') as mock_sess: + est.train(dummy_input_fn, steps=1) + self.assertTrue( + any( + isinstance(hook, basic_session_run_hooks.SummarySaverHook) + for hook in mock_sess.call_args[1]['hooks'])) + self.assertTrue( + any( + isinstance(hook, basic_session_run_hooks.StepCounterHook) + for hook in mock_sess.call_args[1]['hooks'])) + self.assertEqual(0, mock_sess.call_args[1]['save_summaries_steps']) + self.assertIsNone(mock_sess.call_args[1]['log_step_count_steps']) + + def test_master_distributed_hooks_for_worker_nonzero(self): + tf_config = json.dumps({ + 'cluster': { + run_config.TaskType.PS: ['localhost:1234'], + run_config.TaskType.WORKER: ['localhost:1235', 'localhost:1237'], + run_config.TaskType.MASTER: ['localhost:1236'] + }, + 'task': { + 'type': run_config.TaskType.WORKER, + 'index': 1 + } + }) + with test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}): + est = estimator.Estimator( + model_fn=model_fn_global_step_incrementer, + config=run_config.RunConfig()) + + with test.mock.patch.object(training, + 'MonitoredTrainingSession') as mock_sess: + est.train(dummy_input_fn, steps=1) + self.assertFalse( + any( + isinstance(hook, basic_session_run_hooks.SummarySaverHook) + for hook in mock_sess.call_args[1]['hooks'])) + self.assertFalse( + any( + isinstance(hook, basic_session_run_hooks.StepCounterHook) + for hook in mock_sess.call_args[1]['hooks'])) + self.assertEqual(0, mock_sess.call_args[1]['save_summaries_steps']) + self.assertIsNone(mock_sess.call_args[1]['log_step_count_steps']) + def _model_fn_with_eval_metric_ops(features, labels, mode, params): _, _ = features, labels |