author    Jianwei Xie <xiejw@google.com>  2017-10-04 11:31:45 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-10-04 11:35:48 -0700
commit    af14ed3f37d52220394fb9ff902ae62fd915dbe8 (patch)
tree      7a0230237b47a191175d66faf159dcab50d2b190 /tensorflow/python/estimator/training.py
parent    6b90a65f6f0651464c402cd2401da488772ceb7b (diff)
Some docstring twists and argument validations.
PiperOrigin-RevId: 171037949
Diffstat (limited to 'tensorflow/python/estimator/training.py')
-rw-r--r--  tensorflow/python/estimator/training.py  43
1 file changed, 25 insertions(+), 18 deletions(-)
diff --git a/tensorflow/python/estimator/training.py b/tensorflow/python/estimator/training.py
index df0b602309..166b7b20ed 100644
--- a/tensorflow/python/estimator/training.py
+++ b/tensorflow/python/estimator/training.py
@@ -75,6 +75,7 @@ def _validate_exporters(exporters):
   try:
     for exporter in exporters:
       if not isinstance(exporter, exporter_lib.Exporter):
+        # Error message will be printed out by the outer try/except.
         raise TypeError

       if not exporter.name:
@@ -83,6 +84,10 @@ def _validate_exporters(exporters):
                          ' empty. All exporter names:'
                          ' {}'.format(full_list_of_names))

+      if not isinstance(exporter.name, six.string_types):
+        raise ValueError('An Exporter must have a string name. Given: '
+                         '{}'.format(type(exporter.name)))
+
       if exporter.name in unique_names:
         full_list_of_names = [e.name for e in exporters]
         raise ValueError(
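
To illustrate the tightened checks, here is a minimal sketch; `_StubExporter` is a hypothetical subclass written only for this example (its `export` signature is elided), and is not part of the change:

    # Hypothetical stub used only to exercise _validate_exporters.
    class _StubExporter(exporter_lib.Exporter):

      def __init__(self, name):
        self._name = name

      @property
      def name(self):
        return self._name

      def export(self, *args, **kwargs):  # Signature elided for brevity.
        pass

    _validate_exporters([_StubExporter('best')])  # OK: non-empty string name.
    _validate_exporters([_StubExporter(123)])     # Raises ValueError under the
                                                  # new string-name check.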
@@ -163,7 +168,7 @@ class TrainSpec(
 class EvalSpec(
     collections.namedtuple('EvalSpec', [
         'input_fn', 'steps', 'name', 'hooks', 'exporters',
-        'delay_secs', 'throttle_secs'
+        'start_delay_secs', 'throttle_secs'
     ])):
   """Configuration for the "eval" part for the `train_and_evaluate` call.
@@ -179,7 +184,7 @@ class EvalSpec(
               name=None,
               hooks=None,
               exporters=None,
-              delay_secs=120,
+              start_delay_secs=120,
               throttle_secs=600):
     """Creates a validated `EvalSpec` instance.
@@ -197,7 +202,8 @@ class EvalSpec(
           on all workers (including chief) during training.
       exporters: Iterable of `Exporter`s, or a single one, or `None`.
         `exporters` will be invoked after each evaluation.
-      delay_secs: Int. Start evaluating after waiting for this many seconds.
+      start_delay_secs: Int. Start evaluating after waiting for this many
+        seconds.
       throttle_secs: Int. Do not re-evaluate unless the last evaluation was
         started at least this many seconds ago. Of course, evaluation does not
         occur if no new checkpoints are available, hence, this is the minimum.
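
For example, with start_delay_secs=120 and throttle_secs=600, the first evaluation can begin 120 seconds after the evaluator starts, and each later evaluation begins no sooner than 600 seconds after the previous one started, and then only if a new checkpoint is available.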
@@ -226,10 +232,10 @@ class EvalSpec(
     # Validate exporters.
     exporters = _validate_exporters(exporters)

-    # Validate delay_secs.
-    if delay_secs < 0:
-      raise ValueError(
-          'Must specify delay_secs >= 0, given: {}'.format(delay_secs))
+    # Validate start_delay_secs.
+    if start_delay_secs < 0:
+      raise ValueError('Must specify start_delay_secs >= 0, given: {}'.format(
+          start_delay_secs))

     # Validate throttle_secs.
     if throttle_secs < 0:
@@ -243,7 +249,7 @@ class EvalSpec(
         name=name,
         hooks=hooks,
         exporters=exporters,
-        delay_secs=delay_secs,
+        start_delay_secs=start_delay_secs,
         throttle_secs=throttle_secs)
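
A short usage sketch of the renamed argument, assuming a hypothetical `eval_input_fn`:

    # Hypothetical input_fn; any callable producing eval features/labels works.
    def eval_input_fn():
      pass

    eval_spec = EvalSpec(input_fn=eval_input_fn,
                         steps=100,
                         start_delay_secs=120,  # Formerly `delay_secs`.
                         throttle_secs=600)

    EvalSpec(input_fn=eval_input_fn, start_delay_secs=-1)  # Raises ValueError.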
@@ -606,15 +612,16 @@ class _TrainingExecutor(object):
     # Delay worker to start. For asynchronous training, this usually helps model
     # to converge faster. Chief starts the training immediately, so, worker
     # with task id x (0-based) should wait (x+1) * _DELAY_SECS_PER_WORKER.
-    delay_secs = 0
+    start_delay_secs = 0
     if config.task_type == run_config_lib.TaskType.WORKER:
       # TODO(xiejw): Replace the hard code logic (task_id + 1) with unique id in
       # training cluster.
-      delay_secs = min(_MAX_DELAY_SECS,
-                       (config.task_id + 1) * _DELAY_SECS_PER_WORKER)
-    if delay_secs > 0:
-      logging.info('Waiting %d secs before starting training.', delay_secs)
-      time.sleep(delay_secs)
+      start_delay_secs = min(_MAX_DELAY_SECS,
+                             (config.task_id + 1) * _DELAY_SECS_PER_WORKER)
+    if start_delay_secs > 0:
+      logging.info('Waiting %d secs before starting training.',
+                   start_delay_secs)
+      time.sleep(start_delay_secs)

     self._estimator.train(input_fn=self._train_spec.input_fn,
                           max_steps=self._train_spec.max_steps,
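
Concretely, the staggered start computes per-worker waits as in this standalone sketch; the constant values 5 and 60 are assumptions for illustration, not taken from this diff:

    # Sketch of the staggered-start rule; constant values are assumed.
    _DELAY_SECS_PER_WORKER = 5
    _MAX_DELAY_SECS = 60

    def worker_start_delay_secs(task_id):
      # Chief starts immediately; worker with 0-based task id x waits
      # (x + 1) * _DELAY_SECS_PER_WORKER, capped at _MAX_DELAY_SECS.
      return min(_MAX_DELAY_SECS, (task_id + 1) * _DELAY_SECS_PER_WORKER)

    assert worker_start_delay_secs(0) == 5    # First worker waits 5 secs.
    assert worker_start_delay_secs(3) == 20   # task_id 3 waits 20 secs.
    assert worker_start_delay_secs(50) == 60  # Capped at _MAX_DELAY_SECS.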
@@ -623,10 +630,10 @@ class _TrainingExecutor(object):

   def _start_continuous_evaluation(self):
     """Repeatedly calls `Estimator` evaluate and export until training ends."""
-    delay_secs = self._eval_spec.delay_secs
-    if delay_secs:
-      logging.info('Waiting %f secs before starting eval.', delay_secs)
-      time.sleep(delay_secs)
+    start_delay_secs = self._eval_spec.start_delay_secs
+    if start_delay_secs:
+      logging.info('Waiting %f secs before starting eval.', start_delay_secs)
+      time.sleep(start_delay_secs)

     latest_eval_result = None
     evaluator = _TrainingExecutor._Evaluator(self._estimator, self._eval_spec)
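
The rest of the loop falls outside this hunk; the following simplified, self-contained sketch shows how `start_delay_secs` and `throttle_secs` interact, with hypothetical `training_is_done` and `has_new_checkpoint` callables standing in for the executor's real bookkeeping:

    import time

    # Simplified sketch of the continuous-evaluation control flow; not the
    # actual implementation.
    def continuous_eval(evaluate_and_export, start_delay_secs, throttle_secs,
                        training_is_done, has_new_checkpoint):
      # Wait once before the very first evaluation.
      if start_delay_secs:
        time.sleep(start_delay_secs)
      while not training_is_done():
        start = time.time()
        if has_new_checkpoint():
          evaluate_and_export()
        elapsed = time.time() - start
        if elapsed < throttle_secs:
          # Enforce the minimum interval between evaluation starts.
          time.sleep(throttle_secs - elapsed)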