aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-26 16:49:48 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-26 16:53:37 -0700
commit69650fff2b0f267162c987f35e2747be033a7d80 (patch)
tree7f8b5a90488c370c16c198b845ca0f394a882db1 /tensorflow/python/estimator
parent3b9c747d71f30c6a59f6529f8475d7f56a86a7c5 (diff)
Automated rollback of commit 844074c2a8e61b744c3de2718e1c9ea7b1d2edc2
PiperOrigin-RevId: 214693201
Diffstat (limited to 'tensorflow/python/estimator')
-rw-r--r--tensorflow/python/estimator/estimator.py33
-rw-r--r--tensorflow/python/estimator/estimator_test.py94
2 files changed, 2 insertions, 125 deletions
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index fd62a79c84..eec64ad452 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -1394,35 +1394,6 @@ class Estimator(object):
# It is expected to have one CheckpointSaverHook. If multiple, we pick
# up the first one to add listener.
saver_hooks[0]._listeners.extend(saving_listeners) # pylint: disable=protected-access
-
- # Add summary hooks to worker 0 if we are running with a master, to ensure
- # that summaries are written at correct intervals even with long-running
- # evaluations.
- save_summary_steps = self._config.save_summary_steps
- log_step_count_steps = self._config.log_step_count_steps
- if run_config.TaskType.MASTER in self._config.cluster_spec.jobs:
- # Update config values to prevent the default hooks from being created on
- # the master or other workers.
- save_summary_steps = 0
- log_step_count_steps = None
-
- if (self._config.task_type == run_config.TaskType.WORKER and
- self._config.task_id == 0):
- if (self._config.save_summary_steps and
- self._config.save_summary_steps > 0):
- worker_hooks.append(
- training.SummarySaverHook(
- save_steps=self._config.save_summary_steps,
- output_dir=self._config.model_dir,
- scaffold=estimator_spec.scaffold))
-
- if (self._config.log_step_count_steps and
- self._config.log_step_count_steps > 0):
- worker_hooks.append(
- training.StepCounterHook(
- every_n_steps=self._config.log_step_count_steps,
- output_dir=self._config.model_dir))
-
with training.MonitoredTrainingSession(
master=self._config.master,
is_chief=self._config.is_chief,
@@ -1432,9 +1403,9 @@ class Estimator(object):
chief_only_hooks=(
tuple(chief_hooks) + tuple(estimator_spec.training_chief_hooks)),
save_checkpoint_secs=0, # Saving is handled by a hook.
- save_summaries_steps=save_summary_steps,
+ save_summaries_steps=self._config.save_summary_steps,
config=self._session_config,
- log_step_count_steps=log_step_count_steps) as mon_sess:
+ log_step_count_steps=self._config.log_step_count_steps) as mon_sess:
loss = None
while not mon_sess.should_stop():
_, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py
index 5962086aad..1ed5e30b0e 100644
--- a/tensorflow/python/estimator/estimator_test.py
+++ b/tensorflow/python/estimator/estimator_test.py
@@ -20,7 +20,6 @@ from __future__ import print_function
import functools
import glob
-import json
import os
import tempfile
@@ -970,99 +969,6 @@ class EstimatorTrainTest(test.TestCase):
with self.assertRaisesRegexp(ValueError, 'train_and_evaluate'):
est.train(dummy_input_fn, steps=1)
- def test_master_distributed_hooks(self):
- tf_config = json.dumps({
- 'cluster': {
- run_config.TaskType.PS: ['localhost:1234'],
- run_config.TaskType.WORKER: ['localhost:1235'],
- run_config.TaskType.MASTER: ['localhost:1236']
- },
- 'task': {
- 'type': run_config.TaskType.MASTER,
- 'index': 0
- }
- })
- with test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}):
- est = estimator.Estimator(
- model_fn=model_fn_global_step_incrementer,
- config=run_config.RunConfig())
-
- with test.mock.patch.object(training,
- 'MonitoredTrainingSession') as mock_sess:
- est.train(dummy_input_fn, steps=1)
- self.assertFalse(
- any(
- isinstance(hook, basic_session_run_hooks.SummarySaverHook)
- for hook in mock_sess.call_args[1]['hooks']))
- self.assertFalse(
- any(
- isinstance(hook, basic_session_run_hooks.StepCounterHook)
- for hook in mock_sess.call_args[1]['hooks']))
- self.assertEqual(0, mock_sess.call_args[1]['save_summaries_steps'])
- self.assertIsNone(mock_sess.call_args[1]['log_step_count_steps'])
-
- def test_master_distributed_hooks_for_worker_0(self):
- tf_config = json.dumps({
- 'cluster': {
- run_config.TaskType.PS: ['localhost:1234'],
- run_config.TaskType.WORKER: ['localhost:1235'],
- run_config.TaskType.MASTER: ['localhost:1236']
- },
- 'task': {
- 'type': run_config.TaskType.WORKER,
- 'index': 0
- }
- })
- with test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}):
- est = estimator.Estimator(
- model_fn=model_fn_global_step_incrementer,
- config=run_config.RunConfig())
-
- with test.mock.patch.object(training,
- 'MonitoredTrainingSession') as mock_sess:
- est.train(dummy_input_fn, steps=1)
- self.assertTrue(
- any(
- isinstance(hook, basic_session_run_hooks.SummarySaverHook)
- for hook in mock_sess.call_args[1]['hooks']))
- self.assertTrue(
- any(
- isinstance(hook, basic_session_run_hooks.StepCounterHook)
- for hook in mock_sess.call_args[1]['hooks']))
- self.assertEqual(0, mock_sess.call_args[1]['save_summaries_steps'])
- self.assertIsNone(mock_sess.call_args[1]['log_step_count_steps'])
-
- def test_master_distributed_hooks_for_worker_nonzero(self):
- tf_config = json.dumps({
- 'cluster': {
- run_config.TaskType.PS: ['localhost:1234'],
- run_config.TaskType.WORKER: ['localhost:1235', 'localhost:1237'],
- run_config.TaskType.MASTER: ['localhost:1236']
- },
- 'task': {
- 'type': run_config.TaskType.WORKER,
- 'index': 1
- }
- })
- with test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}):
- est = estimator.Estimator(
- model_fn=model_fn_global_step_incrementer,
- config=run_config.RunConfig())
-
- with test.mock.patch.object(training,
- 'MonitoredTrainingSession') as mock_sess:
- est.train(dummy_input_fn, steps=1)
- self.assertFalse(
- any(
- isinstance(hook, basic_session_run_hooks.SummarySaverHook)
- for hook in mock_sess.call_args[1]['hooks']))
- self.assertFalse(
- any(
- isinstance(hook, basic_session_run_hooks.StepCounterHook)
- for hook in mock_sess.call_args[1]['hooks']))
- self.assertEqual(0, mock_sess.call_args[1]['save_summaries_steps'])
- self.assertIsNone(mock_sess.call_args[1]['log_step_count_steps'])
-
def _model_fn_with_eval_metric_ops(features, labels, mode, params):
_, _ = features, labels