diff options
author | Jianwei Xie <xiejw@google.com> | 2017-10-04 11:31:45 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-10-04 11:35:48 -0700 |
commit | af14ed3f37d52220394fb9ff902ae62fd915dbe8 (patch) | |
tree | 7a0230237b47a191175d66faf159dcab50d2b190 /tensorflow/python/estimator/training.py | |
parent | 6b90a65f6f0651464c402cd2401da488772ceb7b (diff) |
Some docstring tweaks and argument validations.
PiperOrigin-RevId: 171037949
Diffstat (limited to 'tensorflow/python/estimator/training.py')
-rw-r--r-- | tensorflow/python/estimator/training.py | 43 |
1 files changed, 25 insertions, 18 deletions
diff --git a/tensorflow/python/estimator/training.py b/tensorflow/python/estimator/training.py index df0b602309..166b7b20ed 100644 --- a/tensorflow/python/estimator/training.py +++ b/tensorflow/python/estimator/training.py @@ -75,6 +75,7 @@ def _validate_exporters(exporters): try: for exporter in exporters: if not isinstance(exporter, exporter_lib.Exporter): + # Error message will be printed out by the outer try/except. raise TypeError if not exporter.name: @@ -83,6 +84,10 @@ def _validate_exporters(exporters): ' empty. All exporter names:' ' {}'.format(full_list_of_names)) + if not isinstance(exporter.name, six.string_types): + raise ValueError('An Exporter must have a string name. Given: ' + '{}'.format(type(exporter.name))) + if exporter.name in unique_names: full_list_of_names = [e.name for e in exporters] raise ValueError( @@ -163,7 +168,7 @@ class TrainSpec( class EvalSpec( collections.namedtuple('EvalSpec', [ 'input_fn', 'steps', 'name', 'hooks', 'exporters', - 'delay_secs', 'throttle_secs' + 'start_delay_secs', 'throttle_secs' ])): """Configuration for the "eval" part for the `train_and_evaluate` call. @@ -179,7 +184,7 @@ class EvalSpec( name=None, hooks=None, exporters=None, - delay_secs=120, + start_delay_secs=120, throttle_secs=600): """Creates a validated `EvalSpec` instance. @@ -197,7 +202,8 @@ class EvalSpec( on all workers (including chief) during training. exporters: Iterable of `Exporter`s, or a single one, or `None`. `exporters` will be invoked after each evaluation. - delay_secs: Int. Start evaluating after waiting for this many seconds. + start_delay_secs: Int. Start evaluating after waiting for this many + seconds. throttle_secs: Int. Do not re-evaluate unless the last evaluation was started at least this many seconds ago. Of course, evaluation does not occur if no new checkpoints are available, hence, this is the minimum. @@ -226,10 +232,10 @@ class EvalSpec( # Validate exporters. 
exporters = _validate_exporters(exporters) - # Validate delay_secs. - if delay_secs < 0: - raise ValueError( - 'Must specify delay_secs >= 0, given: {}'.format(delay_secs)) + # Validate start_delay_secs. + if start_delay_secs < 0: + raise ValueError('Must specify start_delay_secs >= 0, given: {}'.format( + start_delay_secs)) # Validate throttle_secs. if throttle_secs < 0: @@ -243,7 +249,7 @@ class EvalSpec( name=name, hooks=hooks, exporters=exporters, - delay_secs=delay_secs, + start_delay_secs=start_delay_secs, throttle_secs=throttle_secs) @@ -606,15 +612,16 @@ class _TrainingExecutor(object): # Delay worker to start. For asynchronous training, this usually helps model # to converge faster. Chief starts the training immediately, so, worker # with task id x (0-based) should wait (x+1) * _DELAY_SECS_PER_WORKER. - delay_secs = 0 + start_delay_secs = 0 if config.task_type == run_config_lib.TaskType.WORKER: # TODO(xiejw): Replace the hard code logic (task_id + 1) with unique id in # training cluster. 
- delay_secs = min(_MAX_DELAY_SECS, - (config.task_id + 1) * _DELAY_SECS_PER_WORKER) - if delay_secs > 0: - logging.info('Waiting %d secs before starting training.', delay_secs) - time.sleep(delay_secs) + start_delay_secs = min(_MAX_DELAY_SECS, + (config.task_id + 1) * _DELAY_SECS_PER_WORKER) + if start_delay_secs > 0: + logging.info('Waiting %d secs before starting training.', + start_delay_secs) + time.sleep(start_delay_secs) self._estimator.train(input_fn=self._train_spec.input_fn, max_steps=self._train_spec.max_steps, @@ -623,10 +630,10 @@ class _TrainingExecutor(object): def _start_continuous_evaluation(self): """Repeatedly calls `Estimator` evaluate and export until training ends.""" - delay_secs = self._eval_spec.delay_secs - if delay_secs: - logging.info('Waiting %f secs before starting eval.', delay_secs) - time.sleep(delay_secs) + start_delay_secs = self._eval_spec.start_delay_secs + if start_delay_secs: + logging.info('Waiting %f secs before starting eval.', start_delay_secs) + time.sleep(start_delay_secs) latest_eval_result = None evaluator = _TrainingExecutor._Evaluator(self._estimator, self._eval_spec) |