aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python
diff options
context:
space:
mode:
authorGravatar Jianwei Xie <xiejw@google.com>2017-10-04 11:31:45 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-04 11:35:48 -0700
commitaf14ed3f37d52220394fb9ff902ae62fd915dbe8 (patch)
tree7a0230237b47a191175d66faf159dcab50d2b190 /tensorflow/python
parent6b90a65f6f0651464c402cd2401da488772ceb7b (diff)
Some docstring twists and argument validations.
PiperOrigin-RevId: 171037949
Diffstat (limited to 'tensorflow/python')
-rw-r--r--tensorflow/python/estimator/exporter.py29
-rw-r--r--tensorflow/python/estimator/exporter_test.py9
-rw-r--r--tensorflow/python/estimator/training.py43
-rw-r--r--tensorflow/python/estimator/training_test.py75
4 files changed, 91 insertions, 65 deletions
diff --git a/tensorflow/python/estimator/exporter.py b/tensorflow/python/estimator/exporter.py
index 62dcbd894b..621dece119 100644
--- a/tensorflow/python/estimator/exporter.py
+++ b/tensorflow/python/estimator/exporter.py
@@ -35,7 +35,7 @@ class Exporter(object):
"""Directory name.
A directory name under the export base directory where exports of
- this type are written. Should not be `None`.
+ this type are written. Should not be `None` nor empty.
"""
pass
@@ -58,7 +58,7 @@ class Exporter(object):
class SavedModelExporter(Exporter):
"""This class exports the serving graph and checkpoints.
- In addition, the class also garbage collects stale exports.
+ In addition, the class also garbage collects stale exports.
"""
def __init__(self,
@@ -74,23 +74,30 @@ class SavedModelExporter(Exporter):
export path.
serving_input_fn: a function that takes no arguments and returns an
`ServingInputReceiver`.
- assets_extra: A dict specifying how to populate the assets.extra directory
- within the exported SavedModel. Each key should give the destination
- path (including the filename) relative to the assets.extra directory.
- The corresponding value gives the full path of the source file to be
- copied. For example, the simple case of copying a single file without
- renaming it is specified as
+ assets_extra: An optional dict specifying how to populate the assets.extra
+ directory within the exported SavedModel. Each key should give the
+ destination path (including the filename) relative to the assets.extra
+ directory. The corresponding value gives the full path of the source
+ file to be copied. For example, the simple case of copying a single
+ file without renaming it is specified as
`{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`.
- as_text: whether to write the SavedModel proto in text format.
+ as_text: whether to write the SavedModel proto in text format. Defaults to
+ `False`.
exports_to_keep: Number of exports to keep. Older exports will be
- garbage-collected. Defaults to 5. Set to None to disable garbage
+ garbage-collected. Defaults to 5. Set to `None` to disable garbage
collection.
+
+ Raises:
+ ValueError: if any arguments is invalid.
"""
self._name = name
self._serving_input_fn = serving_input_fn
self._assets_extra = assets_extra
self._as_text = as_text
self._exports_to_keep = exports_to_keep
+ if exports_to_keep is not None and exports_to_keep <= 0:
+ raise ValueError(
+ '`exports_to_keep`, if provided, must be positive number')
@property
def name(self):
@@ -127,6 +134,7 @@ class SavedModelExporter(Exporter):
return None
return path._replace(export_version=int(filename))
+ # pylint: disable=protected-access
keep_filter = gc._largest_export_versions(self._exports_to_keep)
delete_filter = gc._negation(keep_filter)
for p in delete_filter(
@@ -135,3 +143,4 @@ class SavedModelExporter(Exporter):
gfile.DeleteRecursively(p.path)
except errors_impl.NotFoundError as e:
tf_logging.warn('Can not delete %s recursively: %s', p.path, e)
+ # pylint: enable=protected-access
diff --git a/tensorflow/python/estimator/exporter_test.py b/tensorflow/python/estimator/exporter_test.py
index 4d09467f10..106202c9c2 100644
--- a/tensorflow/python/estimator/exporter_test.py
+++ b/tensorflow/python/estimator/exporter_test.py
@@ -32,6 +32,15 @@ from tensorflow.python.util import compat
class SavedModelExporterTest(test.TestCase):
+ def test_error_out_if_exports_to_keep_is_zero(self):
+ def _serving_input_fn():
+ pass
+ with self.assertRaisesRegexp(ValueError, "positive number"):
+ exporter_lib.SavedModelExporter(
+ name="saved_model_exporter",
+ serving_input_fn=_serving_input_fn,
+ exports_to_keep=0)
+
def test_saved_model_exporter(self):
def _serving_input_fn():
diff --git a/tensorflow/python/estimator/training.py b/tensorflow/python/estimator/training.py
index df0b602309..166b7b20ed 100644
--- a/tensorflow/python/estimator/training.py
+++ b/tensorflow/python/estimator/training.py
@@ -75,6 +75,7 @@ def _validate_exporters(exporters):
try:
for exporter in exporters:
if not isinstance(exporter, exporter_lib.Exporter):
+ # Error message will be printed out by the outer try/except.
raise TypeError
if not exporter.name:
@@ -83,6 +84,10 @@ def _validate_exporters(exporters):
' empty. All exporter names:'
' {}'.format(full_list_of_names))
+ if not isinstance(exporter.name, six.string_types):
+ raise ValueError('An Exporter must have a string name. Given: '
+ '{}'.format(type(exporter.name)))
+
if exporter.name in unique_names:
full_list_of_names = [e.name for e in exporters]
raise ValueError(
@@ -163,7 +168,7 @@ class TrainSpec(
class EvalSpec(
collections.namedtuple('EvalSpec', [
'input_fn', 'steps', 'name', 'hooks', 'exporters',
- 'delay_secs', 'throttle_secs'
+ 'start_delay_secs', 'throttle_secs'
])):
"""Configuration for the "eval" part for the `train_and_evaluate` call.
@@ -179,7 +184,7 @@ class EvalSpec(
name=None,
hooks=None,
exporters=None,
- delay_secs=120,
+ start_delay_secs=120,
throttle_secs=600):
"""Creates a validated `EvalSpec` instance.
@@ -197,7 +202,8 @@ class EvalSpec(
on all workers (including chief) during training.
exporters: Iterable of `Exporter`s, or a single one, or `None`.
`exporters` will be invoked after each evaluation.
- delay_secs: Int. Start evaluating after waiting for this many seconds.
+ start_delay_secs: Int. Start evaluating after waiting for this many
+ seconds.
throttle_secs: Int. Do not re-evaluate unless the last evaluation was
started at least this many seconds ago. Of course, evaluation does not
occur if no new checkpoints are available, hence, this is the minimum.
@@ -226,10 +232,10 @@ class EvalSpec(
# Validate exporters.
exporters = _validate_exporters(exporters)
- # Validate delay_secs.
- if delay_secs < 0:
- raise ValueError(
- 'Must specify delay_secs >= 0, given: {}'.format(delay_secs))
+ # Validate start_delay_secs.
+ if start_delay_secs < 0:
+ raise ValueError('Must specify start_delay_secs >= 0, given: {}'.format(
+ start_delay_secs))
# Validate throttle_secs.
if throttle_secs < 0:
@@ -243,7 +249,7 @@ class EvalSpec(
name=name,
hooks=hooks,
exporters=exporters,
- delay_secs=delay_secs,
+ start_delay_secs=start_delay_secs,
throttle_secs=throttle_secs)
@@ -606,15 +612,16 @@ class _TrainingExecutor(object):
# Delay worker to start. For asynchronous training, this usually helps model
# to converge faster. Chief starts the training immediately, so, worker
# with task id x (0-based) should wait (x+1) * _DELAY_SECS_PER_WORKER.
- delay_secs = 0
+ start_delay_secs = 0
if config.task_type == run_config_lib.TaskType.WORKER:
# TODO(xiejw): Replace the hard code logic (task_id + 1) with unique id in
# training cluster.
- delay_secs = min(_MAX_DELAY_SECS,
- (config.task_id + 1) * _DELAY_SECS_PER_WORKER)
- if delay_secs > 0:
- logging.info('Waiting %d secs before starting training.', delay_secs)
- time.sleep(delay_secs)
+ start_delay_secs = min(_MAX_DELAY_SECS,
+ (config.task_id + 1) * _DELAY_SECS_PER_WORKER)
+ if start_delay_secs > 0:
+ logging.info('Waiting %d secs before starting training.',
+ start_delay_secs)
+ time.sleep(start_delay_secs)
self._estimator.train(input_fn=self._train_spec.input_fn,
max_steps=self._train_spec.max_steps,
@@ -623,10 +630,10 @@ class _TrainingExecutor(object):
def _start_continuous_evaluation(self):
"""Repeatedly calls `Estimator` evaluate and export until training ends."""
- delay_secs = self._eval_spec.delay_secs
- if delay_secs:
- logging.info('Waiting %f secs before starting eval.', delay_secs)
- time.sleep(delay_secs)
+ start_delay_secs = self._eval_spec.start_delay_secs
+ if start_delay_secs:
+ logging.info('Waiting %f secs before starting eval.', start_delay_secs)
+ time.sleep(start_delay_secs)
latest_eval_result = None
evaluator = _TrainingExecutor._Evaluator(self._estimator, self._eval_spec)
diff --git a/tensorflow/python/estimator/training_test.py b/tensorflow/python/estimator/training_test.py
index 5d6b01b7f0..c474004dab 100644
--- a/tensorflow/python/estimator/training_test.py
+++ b/tensorflow/python/estimator/training_test.py
@@ -47,11 +47,12 @@ _INVALID_HOOK_MSG = 'All hooks must be `SessionRunHook` instances'
_INVALID_MAX_STEPS_MSG = 'Must specify max_steps > 0'
_INVALID_STEPS_MSG = 'Must specify steps > 0'
_INVALID_NAME_MSG = '`name` must be string'
-_INVALID_EVAL_DELAY_SECS_MSG = 'Must specify delay_secs >= 0'
+_INVALID_EVAL_DELAY_SECS_MSG = 'Must specify start_delay_secs >= 0'
_INVALID_EVAL_THROTTLE_SECS_MSG = 'Must specify throttle_secs >= 0'
_INVALID_ESTIMATOR_MSG = '`estimator` must have type `tf.estimator.Estimator`'
_STALE_CHECKPOINT_MSG = 'There was no new checkpoint after the training.'
_INVALID_EXPORTER_MSG = '`exporters` must be an Exporter'
+_INVALID_EXPORTER_NAME_TYPE_MSG = 'An Exporter must have a string name'
_DUPLICATE_EXPORTER_NAMES_MSG = '`exporters` must have unique names.'
_NONE_EXPORTER_NAME_MSG = (
'An Exporter cannot have a name that is `None` or empty.')
@@ -205,7 +206,7 @@ class EvalSpecTest(test.TestCase):
self.assertIsNone(spec.name)
self.assertEqual(0, len(spec.hooks))
self.assertEqual(0, len(spec.exporters))
- self.assertEqual(_DEFAULT_EVAL_DELAY_SECS, spec.delay_secs)
+ self.assertEqual(_DEFAULT_EVAL_DELAY_SECS, spec.start_delay_secs)
self.assertEqual(_DEFAULT_EVAL_THROTTLE_SECS, spec.throttle_secs)
def testAllArgumentsSet(self):
@@ -219,14 +220,14 @@ class EvalSpecTest(test.TestCase):
name='name',
hooks=hooks,
exporters=exporter,
- delay_secs=3,
+ start_delay_secs=3,
throttle_secs=4)
self.assertEqual(1, spec.input_fn())
self.assertEqual(2, spec.steps)
self.assertEqual('name', spec.name)
self.assertEqual(tuple(hooks), spec.hooks)
self.assertEqual((exporter,), spec.exporters)
- self.assertEqual(3, spec.delay_secs)
+ self.assertEqual(3, spec.start_delay_secs)
self.assertEqual(4, spec.throttle_secs)
def testListOfExporters(self):
@@ -255,7 +256,7 @@ class EvalSpecTest(test.TestCase):
def testInvalidDelaySecs(self):
with self.assertRaisesRegexp(ValueError, _INVALID_EVAL_DELAY_SECS_MSG):
- training.EvalSpec(input_fn=lambda: 1, delay_secs=-1)
+ training.EvalSpec(input_fn=lambda: 1, start_delay_secs=-1)
def testInvalidThrottleSecs(self):
with self.assertRaisesRegexp(ValueError, _INVALID_EVAL_THROTTLE_SECS_MSG):
@@ -271,6 +272,11 @@ class EvalSpecTest(test.TestCase):
with self.assertRaisesRegexp(TypeError, _INVALID_EXPORTER_MSG):
training.EvalSpec(input_fn=lambda: 1, exporters=_FakeHook())
+ def testInvalidTypeOfExporterName(self):
+ with self.assertRaisesRegexp(ValueError, _INVALID_EXPORTER_NAME_TYPE_MSG):
+ training.EvalSpec(input_fn=lambda: 1,
+ exporters=_create_exporter(name=123))
+
def testMultipleExportersWithTheSameName(self):
with self.assertRaisesRegexp(ValueError, _DUPLICATE_EXPORTER_NAMES_MSG):
training.EvalSpec(
@@ -699,10 +705,9 @@ class TrainingExecutorRunMasterTest(_TrainingExecutorTrainingTest,
del args, kwargs
estimator.export_was_called = True
- exporter = test.mock.Mock(
- spec=exporter_lib.Exporter,
- name='see_whether_export_is_called',
- export=export)
+ exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter)
+ exporter.name = 'see_whether_export_is_called'
+ exporter.export = export
train_spec = training.TrainSpec(input_fn=lambda: 1, max_steps=300)
eval_spec = training.EvalSpec(
@@ -739,7 +744,7 @@ class TrainingExecutorRunEvaluatorTest(test.TestCase):
eval_spec = training.EvalSpec(
input_fn=lambda: 1, steps=2, hooks=[_FakeHook()], name='cont_eval',
- delay_secs=0, throttle_secs=0)
+ start_delay_secs=0, throttle_secs=0)
executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec)
executor.run_evaluator()
@@ -766,13 +771,12 @@ class TrainingExecutorRunEvaluatorTest(test.TestCase):
mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
mock_train_spec.max_steps = training_max_step
- exporter = test.mock.Mock(
- spec=exporter_lib.Exporter,
- name='see_how_many_times_export_is_called')
+ exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter)
+ exporter.name = 'see_how_many_times_export_is_called'
eval_spec = training.EvalSpec(
input_fn=lambda: 1,
- delay_secs=0,
+ start_delay_secs=0,
throttle_secs=0,
exporters=exporter)
@@ -800,7 +804,7 @@ class TrainingExecutorRunEvaluatorTest(test.TestCase):
]
eval_spec = training.EvalSpec(
- input_fn=lambda: 1, delay_secs=0, throttle_secs=0)
+ input_fn=lambda: 1, start_delay_secs=0, throttle_secs=0)
executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec)
with test.mock.patch.object(logging, 'warning') as mock_log:
@@ -814,9 +818,9 @@ class TrainingExecutorRunEvaluatorTest(test.TestCase):
# successuful evaluation)
self.assertEqual(2, mock_log.call_count)
- def test_sleep_delay_secs(self):
+ def test_sleep_start_delay_secs(self):
training_max_step = 200
- delay_secs = 123
+ start_delay_secs = 123
mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
mock_est.evaluate.return_value = {_GLOBAL_STEP_KEY: training_max_step}
@@ -826,12 +830,12 @@ class TrainingExecutorRunEvaluatorTest(test.TestCase):
eval_spec = training.EvalSpec(
input_fn=lambda: 1, steps=2, hooks=[_FakeHook()], name='cont_eval',
- delay_secs=delay_secs, throttle_secs=0)
+ start_delay_secs=start_delay_secs, throttle_secs=0)
executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec)
with test.mock.patch.object(time, 'sleep') as mock_sleep:
executor.run_evaluator()
- mock_sleep.assert_called_with(delay_secs)
+ mock_sleep.assert_called_with(start_delay_secs)
self.assertTrue(mock_est.evaluate.called)
@test.mock.patch.object(time, 'time')
@@ -845,7 +849,7 @@ class TrainingExecutorRunEvaluatorTest(test.TestCase):
self._set_up_mock_est_to_train_and_evaluate_once(mock_est, mock_train_spec)
eval_spec = training.EvalSpec(
- input_fn=lambda: 1, delay_secs=0, throttle_secs=throttle_secs)
+ input_fn=lambda: 1, start_delay_secs=0, throttle_secs=throttle_secs)
mock_time.side_effect = [921, 921 + operation_secs]
@@ -865,15 +869,14 @@ class TrainingExecutorRunEvaluatorTest(test.TestCase):
del args, kwargs
estimator.export_was_called = True
- exporter = test.mock.Mock(
- spec=exporter_lib.Exporter,
- name='see_whether_export_is_called',
- export=export)
+ exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter)
+ exporter.name = 'see_whether_export_is_called'
+ exporter.export = export
eval_spec = training.EvalSpec(
input_fn=lambda: 1,
steps=2,
- delay_secs=0,
+ start_delay_secs=0,
throttle_secs=0,
exporters=exporter)
@@ -887,7 +890,7 @@ class TrainingExecutorRunEvaluatorTest(test.TestCase):
mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
train_spec = training.TrainSpec(input_fn=lambda: 1)
eval_spec = training.EvalSpec(input_fn=(lambda: 1),
- delay_secs=0, throttle_secs=0)
+ start_delay_secs=0, throttle_secs=0)
mock_est.evaluate.return_value = {}
executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
@@ -898,7 +901,7 @@ class TrainingExecutorRunEvaluatorTest(test.TestCase):
mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
train_spec = training.TrainSpec(input_fn=lambda: 1)
eval_spec = training.EvalSpec(input_fn=(lambda: 1),
- delay_secs=0, throttle_secs=0)
+ start_delay_secs=0, throttle_secs=0)
mock_est.evaluate.return_value = 123
executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
@@ -909,7 +912,7 @@ class TrainingExecutorRunEvaluatorTest(test.TestCase):
mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
train_spec = training.TrainSpec(input_fn=lambda: 1)
eval_spec = training.EvalSpec(input_fn=(lambda: 1),
- delay_secs=0, throttle_secs=0)
+ start_delay_secs=0, throttle_secs=0)
mock_est.evaluate.return_value = {'loss': 123}
executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
@@ -1067,10 +1070,9 @@ class TrainingExecutorRunLocalTest(test.TestCase):
del args, kwargs
estimator.times_export_was_called += 1
- exporter = test.mock.Mock(
- spec=exporter_lib.Exporter,
- name='see_how_many_times_export_is_called',
- export=export)
+ exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter)
+ exporter.name = 'see_how_many_times_export_is_called'
+ exporter.export = export
train_spec = training.TrainSpec(
input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()])
@@ -1164,15 +1166,14 @@ class TrainingExecutorRunLocalTest(test.TestCase):
del args, kwargs
estimator.export_was_called = True
- exporter = test.mock.Mock(
- spec=exporter_lib.Exporter,
- name='see_whether_export_is_called',
- export=export)
+ exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter)
+ exporter.name = 'see_whether_export_is_called'
+ exporter.export = export
eval_spec = training.EvalSpec(
input_fn=lambda: 1,
steps=2,
- delay_secs=0,
+ start_delay_secs=0,
throttle_secs=213,
exporters=exporter)