aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator/training.py
diff options
context:
space:
mode:
authorGravatar Yifei Feng <yifeif@google.com>2018-01-29 10:42:32 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-29 10:46:04 -0800
commitfd63d4e30a01cf860baf60b990b223cd54bc895c (patch)
treefcea79b1e89bcf30ac80d087edf051c3711d06b1 /tensorflow/python/estimator/training.py
parent730071d0dca35a9e08f3bdc49661ae34d109da74 (diff)
Add C0326 bad-whitespace error to pylint sanity check.
PiperOrigin-RevId: 183689499
Diffstat (limited to 'tensorflow/python/estimator/training.py')
-rw-r--r--tensorflow/python/estimator/training.py56
1 files changed, 22 insertions, 34 deletions
diff --git a/tensorflow/python/estimator/training.py b/tensorflow/python/estimator/training.py
index 52fb1d39ae..2e84c5014f 100644
--- a/tensorflow/python/estimator/training.py
+++ b/tensorflow/python/estimator/training.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""Classes and functions related to train_and_evaluate."""
from __future__ import absolute_import
@@ -37,7 +36,6 @@ from tensorflow.python.training import server_lib
from tensorflow.python.training import session_run_hook
from tensorflow.python.util import compat
-
_MAX_DELAY_SECS = 60
_DELAY_SECS_PER_WORKER = 5
_TF_CONFIG_ENV = 'TF_CONFIG'
@@ -50,8 +48,7 @@ _TRAINER_JOBS = (run_config_lib.TaskType.CHIEF, run_config_lib.TaskType.MASTER,
def _validate_input_fn(input_fn):
"""Validates the `input_fn`."""
if not callable(input_fn):
- raise TypeError(
- '`input_fn` must be callable, given: {}'.format(input_fn))
+ raise TypeError('`input_fn` must be callable, given: {}'.format(input_fn))
def _validate_hooks(hooks):
@@ -125,10 +122,7 @@ class TrainSpec(
duration. Optional hooks run at various stages of training.
"""
- def __new__(cls,
- input_fn,
- max_steps=None,
- hooks=None):
+ def __new__(cls, input_fn, max_steps=None, hooks=None):
"""Creates a validated `TrainSpec` instance.
Args:
@@ -161,16 +155,13 @@ class TrainSpec(
hooks = _validate_hooks(hooks)
return super(TrainSpec, cls).__new__(
- cls,
- input_fn=input_fn,
- max_steps=max_steps,
- hooks=hooks)
+ cls, input_fn=input_fn, max_steps=max_steps, hooks=hooks)
class EvalSpec(
collections.namedtuple('EvalSpec', [
- 'input_fn', 'steps', 'name', 'hooks', 'exporters',
- 'start_delay_secs', 'throttle_secs'
+ 'input_fn', 'steps', 'name', 'hooks', 'exporters', 'start_delay_secs',
+ 'throttle_secs'
])):
"""Configuration for the "eval" part for the `train_and_evaluate` call.
@@ -417,8 +408,8 @@ def train_and_evaluate(estimator, train_spec, eval_spec):
Raises:
ValueError: if environment variable `TF_CONFIG` is incorrectly set.
"""
- executor = _TrainingExecutor(estimator=estimator, train_spec=train_spec,
- eval_spec=eval_spec)
+ executor = _TrainingExecutor(
+ estimator=estimator, train_spec=train_spec, eval_spec=eval_spec)
config = estimator.config
if (config.task_type == run_config_lib.TaskType.EVALUATOR and
@@ -561,9 +552,8 @@ class _TrainingExecutor(object):
self._timer.update_last_triggered_step(global_step_value)
self._evaluator.evaluate_and_export()
else:
- logging.info(
- 'Skip the current checkpoint eval due to throttle secs '
- '({} secs).'.format(self._eval_throttle_secs))
+ logging.info('Skip the current checkpoint eval due to throttle secs '
+ '({} secs).'.format(self._eval_throttle_secs))
# Final export signal: For any eval result with global_step >= train
# max_steps, the evaluator will send the final export signal. There is a
@@ -576,8 +566,8 @@ class _TrainingExecutor(object):
#
# But here, throttle_secs will skip the next intermediate checkpoint and,
# so, the double final export chance is very small.
- evaluator = _TrainingExecutor._Evaluator(
- self._estimator, self._eval_spec, self._train_spec.max_steps)
+ evaluator = _TrainingExecutor._Evaluator(self._estimator, self._eval_spec,
+ self._train_spec.max_steps)
# When the underlying `Estimator` object saves a new checkpoint, we would
# like this callback to be called so that evaluation and export can trigger.
@@ -617,8 +607,7 @@ class _TrainingExecutor(object):
raise ValueError('eval_spec.throttle_secs should be positive, given: {}.'
'It is used do determine how long each training '
'iteration should go when train and evaluate '
- 'locally.'.format(
- self._eval_spec.throttle_secs))
+ 'locally.'.format(self._eval_spec.throttle_secs))
stop_hook = _StopAtSecsHook(self._eval_spec.throttle_secs)
train_hooks = (
@@ -663,8 +652,9 @@ class _TrainingExecutor(object):
if not config.master:
jobs = config.cluster_spec.jobs
- if (len(jobs) == 1 and len(config.cluster_spec.job_tasks(jobs[0])) == 1
- and config.task_type in _TRAINER_JOBS):
+ if (len(jobs) == 1 and
+ len(config.cluster_spec.job_tasks(jobs[0])) == 1 and
+ config.task_type in _TRAINER_JOBS):
# For distributed training, config.master is empty if and only if it has
# a single node in the cluster spec. In this case, we should not start
# the server.
@@ -679,9 +669,9 @@ class _TrainingExecutor(object):
logging.info('Start Tensorflow server.')
if config.session_config is None:
- session_config=config_pb2.ConfigProto(log_device_placement=False)
+ session_config = config_pb2.ConfigProto(log_device_placement=False)
else:
- session_config=config_pb2.ConfigProto(
+ session_config = config_pb2.ConfigProto(
log_device_placement=False,
gpu_options=config.session_config.gpu_options)
@@ -744,8 +734,7 @@ class _TrainingExecutor(object):
global_step >= self._train_spec.max_steps):
logging.info(
'Exiting evaluation, global_step=%s >= train max_steps=%s',
- global_step,
- self._train_spec.max_steps)
+ global_step, self._train_spec.max_steps)
return
latest_eval_result, should_early_stop = self._execute_evaluator_once(
@@ -781,10 +770,9 @@ class _TrainingExecutor(object):
# Throttle if necessary.
elapsed_time = time.time() - start
- difference = throttle_secs - elapsed_time
+ difference = throttle_secs - elapsed_time
if difference > 0:
- logging.info('Waiting %f secs before starting next eval run.',
- difference)
+ logging.info('Waiting %f secs before starting next eval run.', difference)
time.sleep(difference)
return (eval_result, should_early_stop)
@@ -929,8 +917,8 @@ class _EvalResult(
if checkpoint_path:
raise ValueError(
'checkpoint must be `None` if status is not {}; got status {}, '
- 'checkpoint_path {}'.format(
- _EvalStatus.EVALUATED, status, checkpoint_path))
+ 'checkpoint_path {}'.format(_EvalStatus.EVALUATED, status,
+ checkpoint_path))
return super(_EvalResult, cls).__new__(cls, status, metrics,
checkpoint_path)