about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/python/estimator
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2018-09-28 14:10:19 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-09-28 14:36:51 -0700
commit: 17d73444f332490c733d37063710e72dc69d1141 (patch)
tree: 88de51eb19b8ada823aa833d84039845820ca15f /tensorflow/python/estimator
parent: f83da5b0aa37ba55c1b2eaa093e6d043b73f5982 (diff)
Update hooks for distributed jobs with a master node, to ensure that
summaries are written at the correct interval for jobs with long-running evaluations.

PiperOrigin-RevId: 214993119
Diffstat (limited to 'tensorflow/python/estimator')
-rw-r--r-- tensorflow/python/estimator/estimator.py 34
-rw-r--r-- tensorflow/python/estimator/estimator_test.py 94
2 files changed, 126 insertions, 2 deletions
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index b933cedb99..34faf03bb0 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -1414,6 +1414,36 @@ class Estimator(object):
# It is expected to have one CheckpointSaverHook. If multiple, we pick
# up the first one to add listener.
saver_hooks[0]._listeners.extend(saving_listeners) # pylint: disable=protected-access
+
+ # Add summary hooks to worker 0 if we are running with a master, to ensure
+ # that summaries are written at correct intervals even with long-running
+ # evaluations.
+ save_summary_steps = self._config.save_summary_steps
+ log_step_count_steps = self._config.log_step_count_steps
+ if (self._config.cluster_spec and self._config.cluster_spec.jobs and
+ (run_config.TaskType.MASTER in self._config.cluster_spec.jobs)):
+ # Update config values to prevent the default hooks from being created on
+ # the master or other workers.
+ save_summary_steps = 0
+ log_step_count_steps = None
+
+ if (self._config.task_type == run_config.TaskType.WORKER and
+ self._config.task_id == 0):
+ if (self._config.save_summary_steps and
+ self._config.save_summary_steps > 0):
+ worker_hooks.append(
+ training.SummarySaverHook(
+ save_steps=self._config.save_summary_steps,
+ output_dir=self._config.model_dir,
+ scaffold=estimator_spec.scaffold))
+
+ if (self._config.log_step_count_steps and
+ self._config.log_step_count_steps > 0):
+ worker_hooks.append(
+ training.StepCounterHook(
+ every_n_steps=self._config.log_step_count_steps,
+ output_dir=self._config.model_dir))
+
with training.MonitoredTrainingSession(
master=self._config.master,
is_chief=self._config.is_chief,
@@ -1423,9 +1453,9 @@ class Estimator(object):
chief_only_hooks=(
tuple(chief_hooks) + tuple(estimator_spec.training_chief_hooks)),
save_checkpoint_secs=0, # Saving is handled by a hook.
- save_summaries_steps=self._config.save_summary_steps,
+ save_summaries_steps=save_summary_steps,
config=self._session_config,
- log_step_count_steps=self._config.log_step_count_steps) as mon_sess:
+ log_step_count_steps=log_step_count_steps) as mon_sess:
loss = None
while not mon_sess.should_stop():
_, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py
index bc2504ca19..246dfb1a4b 100644
--- a/tensorflow/python/estimator/estimator_test.py
+++ b/tensorflow/python/estimator/estimator_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import functools
import glob
+import json
import os
import tempfile
@@ -969,6 +970,99 @@ class EstimatorTrainTest(test.TestCase):
with self.assertRaisesRegexp(ValueError, 'train_and_evaluate'):
est.train(dummy_input_fn, steps=1)
+ def test_master_distributed_hooks(self):
+ tf_config = json.dumps({
+ 'cluster': {
+ run_config.TaskType.PS: ['localhost:1234'],
+ run_config.TaskType.WORKER: ['localhost:1235'],
+ run_config.TaskType.MASTER: ['localhost:1236']
+ },
+ 'task': {
+ 'type': run_config.TaskType.MASTER,
+ 'index': 0
+ }
+ })
+ with test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}):
+ est = estimator.Estimator(
+ model_fn=model_fn_global_step_incrementer,
+ config=run_config.RunConfig())
+
+ with test.mock.patch.object(training,
+ 'MonitoredTrainingSession') as mock_sess:
+ est.train(dummy_input_fn, steps=1)
+ self.assertFalse(
+ any(
+ isinstance(hook, basic_session_run_hooks.SummarySaverHook)
+ for hook in mock_sess.call_args[1]['hooks']))
+ self.assertFalse(
+ any(
+ isinstance(hook, basic_session_run_hooks.StepCounterHook)
+ for hook in mock_sess.call_args[1]['hooks']))
+ self.assertEqual(0, mock_sess.call_args[1]['save_summaries_steps'])
+ self.assertIsNone(mock_sess.call_args[1]['log_step_count_steps'])
+
+ def test_master_distributed_hooks_for_worker_0(self):
+ tf_config = json.dumps({
+ 'cluster': {
+ run_config.TaskType.PS: ['localhost:1234'],
+ run_config.TaskType.WORKER: ['localhost:1235'],
+ run_config.TaskType.MASTER: ['localhost:1236']
+ },
+ 'task': {
+ 'type': run_config.TaskType.WORKER,
+ 'index': 0
+ }
+ })
+ with test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}):
+ est = estimator.Estimator(
+ model_fn=model_fn_global_step_incrementer,
+ config=run_config.RunConfig())
+
+ with test.mock.patch.object(training,
+ 'MonitoredTrainingSession') as mock_sess:
+ est.train(dummy_input_fn, steps=1)
+ self.assertTrue(
+ any(
+ isinstance(hook, basic_session_run_hooks.SummarySaverHook)
+ for hook in mock_sess.call_args[1]['hooks']))
+ self.assertTrue(
+ any(
+ isinstance(hook, basic_session_run_hooks.StepCounterHook)
+ for hook in mock_sess.call_args[1]['hooks']))
+ self.assertEqual(0, mock_sess.call_args[1]['save_summaries_steps'])
+ self.assertIsNone(mock_sess.call_args[1]['log_step_count_steps'])
+
+ def test_master_distributed_hooks_for_worker_nonzero(self):
+ tf_config = json.dumps({
+ 'cluster': {
+ run_config.TaskType.PS: ['localhost:1234'],
+ run_config.TaskType.WORKER: ['localhost:1235', 'localhost:1237'],
+ run_config.TaskType.MASTER: ['localhost:1236']
+ },
+ 'task': {
+ 'type': run_config.TaskType.WORKER,
+ 'index': 1
+ }
+ })
+ with test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}):
+ est = estimator.Estimator(
+ model_fn=model_fn_global_step_incrementer,
+ config=run_config.RunConfig())
+
+ with test.mock.patch.object(training,
+ 'MonitoredTrainingSession') as mock_sess:
+ est.train(dummy_input_fn, steps=1)
+ self.assertFalse(
+ any(
+ isinstance(hook, basic_session_run_hooks.SummarySaverHook)
+ for hook in mock_sess.call_args[1]['hooks']))
+ self.assertFalse(
+ any(
+ isinstance(hook, basic_session_run_hooks.StepCounterHook)
+ for hook in mock_sess.call_args[1]['hooks']))
+ self.assertEqual(0, mock_sess.call_args[1]['save_summaries_steps'])
+ self.assertIsNone(mock_sess.call_args[1]['log_step_count_steps'])
+
def _model_fn_with_eval_metric_ops(features, labels, mode, params):
_, _ = features, labels