diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-26 16:49:48 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-26 16:53:37 -0700 |
commit | 69650fff2b0f267162c987f35e2747be033a7d80 (patch) | |
tree | 7f8b5a90488c370c16c198b845ca0f394a882db1 /tensorflow/python/estimator | |
parent | 3b9c747d71f30c6a59f6529f8475d7f56a86a7c5 (diff) |
Automated rollback of commit 844074c2a8e61b744c3de2718e1c9ea7b1d2edc2
PiperOrigin-RevId: 214693201
Diffstat (limited to 'tensorflow/python/estimator')
-rw-r--r-- | tensorflow/python/estimator/estimator.py | 33 | ||||
-rw-r--r-- | tensorflow/python/estimator/estimator_test.py | 94 |
2 files changed, 2 insertions, 125 deletions
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index fd62a79c84..eec64ad452 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -1394,35 +1394,6 @@ class Estimator(object): # It is expected to have one CheckpointSaverHook. If multiple, we pick # up the first one to add listener. saver_hooks[0]._listeners.extend(saving_listeners) # pylint: disable=protected-access - - # Add summary hooks to worker 0 if we are running with a master, to ensure - # that summaries are written at correct intervals even with long-running - # evaluations. - save_summary_steps = self._config.save_summary_steps - log_step_count_steps = self._config.log_step_count_steps - if run_config.TaskType.MASTER in self._config.cluster_spec.jobs: - # Update config values to prevent the default hooks from being created on - # the master or other workers. - save_summary_steps = 0 - log_step_count_steps = None - - if (self._config.task_type == run_config.TaskType.WORKER and - self._config.task_id == 0): - if (self._config.save_summary_steps and - self._config.save_summary_steps > 0): - worker_hooks.append( - training.SummarySaverHook( - save_steps=self._config.save_summary_steps, - output_dir=self._config.model_dir, - scaffold=estimator_spec.scaffold)) - - if (self._config.log_step_count_steps and - self._config.log_step_count_steps > 0): - worker_hooks.append( - training.StepCounterHook( - every_n_steps=self._config.log_step_count_steps, - output_dir=self._config.model_dir)) - with training.MonitoredTrainingSession( master=self._config.master, is_chief=self._config.is_chief, @@ -1432,9 +1403,9 @@ class Estimator(object): chief_only_hooks=( tuple(chief_hooks) + tuple(estimator_spec.training_chief_hooks)), save_checkpoint_secs=0, # Saving is handled by a hook. - save_summaries_steps=save_summary_steps, + save_summaries_steps=self._config.save_summary_steps, config=self._session_config, - log_step_count_steps=log_step_count_steps) as mon_sess: + log_step_count_steps=self._config.log_step_count_steps) as mon_sess: loss = None while not mon_sess.should_stop(): _, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss]) diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py index 5962086aad..1ed5e30b0e 100644 --- a/tensorflow/python/estimator/estimator_test.py +++ b/tensorflow/python/estimator/estimator_test.py @@ -20,7 +20,6 @@ from __future__ import print_function import functools import glob -import json import os import tempfile @@ -970,99 +969,6 @@ class EstimatorTrainTest(test.TestCase): with self.assertRaisesRegexp(ValueError, 'train_and_evaluate'): est.train(dummy_input_fn, steps=1) - def test_master_distributed_hooks(self): - tf_config = json.dumps({ - 'cluster': { - run_config.TaskType.PS: ['localhost:1234'], - run_config.TaskType.WORKER: ['localhost:1235'], - run_config.TaskType.MASTER: ['localhost:1236'] - }, - 'task': { - 'type': run_config.TaskType.MASTER, - 'index': 0 - } - }) - with test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}): - est = estimator.Estimator( - model_fn=model_fn_global_step_incrementer, - config=run_config.RunConfig()) - - with test.mock.patch.object(training, - 'MonitoredTrainingSession') as mock_sess: - est.train(dummy_input_fn, steps=1) - self.assertFalse( - any( - isinstance(hook, basic_session_run_hooks.SummarySaverHook) - for hook in mock_sess.call_args[1]['hooks'])) - self.assertFalse( - any( - isinstance(hook, basic_session_run_hooks.StepCounterHook) - for hook in mock_sess.call_args[1]['hooks'])) - self.assertEqual(0, mock_sess.call_args[1]['save_summaries_steps']) - self.assertIsNone(mock_sess.call_args[1]['log_step_count_steps']) - - def test_master_distributed_hooks_for_worker_0(self): - tf_config = json.dumps({ - 'cluster': { - run_config.TaskType.PS: ['localhost:1234'], - run_config.TaskType.WORKER: ['localhost:1235'], - run_config.TaskType.MASTER: ['localhost:1236'] - }, - 'task': { - 'type': run_config.TaskType.WORKER, - 'index': 0 - } - }) - with test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}): - est = estimator.Estimator( - model_fn=model_fn_global_step_incrementer, - config=run_config.RunConfig()) - - with test.mock.patch.object(training, - 'MonitoredTrainingSession') as mock_sess: - est.train(dummy_input_fn, steps=1) - self.assertTrue( - any( - isinstance(hook, basic_session_run_hooks.SummarySaverHook) - for hook in mock_sess.call_args[1]['hooks'])) - self.assertTrue( - any( - isinstance(hook, basic_session_run_hooks.StepCounterHook) - for hook in mock_sess.call_args[1]['hooks'])) - self.assertEqual(0, mock_sess.call_args[1]['save_summaries_steps']) - self.assertIsNone(mock_sess.call_args[1]['log_step_count_steps']) - - def test_master_distributed_hooks_for_worker_nonzero(self): - tf_config = json.dumps({ - 'cluster': { - run_config.TaskType.PS: ['localhost:1234'], - run_config.TaskType.WORKER: ['localhost:1235', 'localhost:1237'], - run_config.TaskType.MASTER: ['localhost:1236'] - }, - 'task': { - 'type': run_config.TaskType.WORKER, - 'index': 1 - } - }) - with test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}): - est = estimator.Estimator( - model_fn=model_fn_global_step_incrementer, - config=run_config.RunConfig()) - - with test.mock.patch.object(training, - 'MonitoredTrainingSession') as mock_sess: - est.train(dummy_input_fn, steps=1) - self.assertFalse( - any( - isinstance(hook, basic_session_run_hooks.SummarySaverHook) - for hook in mock_sess.call_args[1]['hooks'])) - self.assertFalse( - any( - isinstance(hook, basic_session_run_hooks.StepCounterHook) - for hook in mock_sess.call_args[1]['hooks'])) - self.assertEqual(0, mock_sess.call_args[1]['save_summaries_steps']) - self.assertIsNone(mock_sess.call_args[1]['log_step_count_steps']) - def _model_fn_with_eval_metric_ops(features, labels, mode, params): _, _ = features, labels |