aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar Igor Saprykin <isaprykin@google.com>2017-10-04 15:13:33 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-04 15:17:39 -0700
commit89df2e336218f7f3ecf2c70f8478c64985345ded (patch)
tree2756c54c384716fe68639da5abe3e36472673a65 /tensorflow
parent4486b4f69b55633274f7903158d680bf2e9eabff (diff)
Add the 'is_the_final_export' signal to Exporters. Use them in training.
When training ends, the final export is performed via an `Exporter.export()` call, with the `is_the_final_export` parameter set to True. If `TrainSpec.max_steps` is `None`, then "when training ends" is undefined: training runs forever. In that case, `is_the_final_export` is always False. I added a note about it. PiperOrigin-RevId: 171070760
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/python/estimator/exporter.py26
-rw-r--r--tensorflow/python/estimator/exporter_test.py41
-rw-r--r--tensorflow/python/estimator/training.py37
-rw-r--r--tensorflow/python/estimator/training_test.py81
4 files changed, 169 insertions, 16 deletions
diff --git a/tensorflow/python/estimator/exporter.py b/tensorflow/python/estimator/exporter.py
index 505820dd93..2faca11f6e 100644
--- a/tensorflow/python/estimator/exporter.py
+++ b/tensorflow/python/estimator/exporter.py
@@ -40,7 +40,8 @@ class Exporter(object):
pass
@abc.abstractmethod
- def export(self, estimator, export_path, checkpoint_path, eval_result):
+ def export(self, estimator, export_path, checkpoint_path, eval_result,
+ is_the_final_export):
"""Exports the given `Estimator` to a specific format.
Args:
@@ -48,6 +49,13 @@ class Exporter(object):
export_path: A string containing a directory where to write the export.
checkpoint_path: The checkpoint path to export.
eval_result: The output of `Estimator.evaluate` on this checkpoint.
+ is_the_final_export: This boolean is True when this is the export at the
+ end of training. It is False for the intermediate exports during
+ training.
+
+ When passing `Exporter` to `tf.estimator.train_and_evaluate`,
+ `is_the_final_export` is always False if `TrainSpec.max_steps` is
+ `None`.
Returns:
The string path to the exported directory or `None` if export is skipped.
@@ -66,7 +74,8 @@ class LatestExporter(Exporter):
serving_input_fn,
assets_extra=None,
as_text=False,
- exports_to_keep=5):
+ exports_to_keep=5,
+ only_the_final_export=False):
"""Create an `Exporter` to use with `tf.estimator.EvalSpec`.
Args:
@@ -86,6 +95,8 @@ class LatestExporter(Exporter):
exports_to_keep: Number of exports to keep. Older exports will be
garbage-collected. Defaults to 5. Set to `None` to disable garbage
collection.
+ only_the_final_export: If True, only the final export at the end of
+ training is performed; intermediate exports are skipped.
Raises:
ValueError: if any arguments is invalid.
@@ -95,6 +106,8 @@ class LatestExporter(Exporter):
self._assets_extra = assets_extra
self._as_text = as_text
self._exports_to_keep = exports_to_keep
+ self._only_the_final_export = only_the_final_export
+
if exports_to_keep is not None and exports_to_keep <= 0:
raise ValueError(
'`exports_to_keep`, if provided, must be positive number')
@@ -103,7 +116,14 @@ class LatestExporter(Exporter):
def name(self):
return self._name
- def export(self, estimator, export_path, checkpoint_path, eval_result):
+ def export(self, estimator, export_path, checkpoint_path, eval_result,
+ is_the_final_export):
+ if not is_the_final_export and self._only_the_final_export:
+ return None
+
+ if is_the_final_export:
+ tf_logging.info('Performing the final export in the end of training.')
+
export_result = estimator.export_savedmodel(
export_path,
self._serving_input_fn,
diff --git a/tensorflow/python/estimator/exporter_test.py b/tensorflow/python/estimator/exporter_test.py
index 2ceff1bfd6..01582ac595 100644
--- a/tensorflow/python/estimator/exporter_test.py
+++ b/tensorflow/python/estimator/exporter_test.py
@@ -42,7 +42,7 @@ class LatestExporterTest(test.TestCase):
serving_input_fn=_serving_input_fn,
exports_to_keep=0)
- def test_saved_model_exporter(self):
+ def test_latest_exporter(self):
def _serving_input_fn():
pass
@@ -60,7 +60,42 @@ class LatestExporterTest(test.TestCase):
estimator.export_savedmodel.return_value = "export_result_path"
export_result = exporter.export(estimator, export_dir_base,
- "checkpoint_path", {})
+ "checkpoint_path", {}, False)
+
+ self.assertEqual("export_result_path", export_result)
+ estimator.export_savedmodel.assert_called_with(
+ export_dir_base,
+ _serving_input_fn,
+ assets_extra={"from/path": "to/path"},
+ as_text=False,
+ checkpoint_path="checkpoint_path")
+
+ def test_only_the_last_export_is_saved(self):
+
+ def _serving_input_fn():
+ pass
+
+ export_dir_base = tempfile.mkdtemp() + "export/"
+ gfile.MkDir(export_dir_base)
+
+ exporter = exporter_lib.LatestExporter(
+ name="latest_exporter",
+ serving_input_fn=_serving_input_fn,
+ assets_extra={"from/path": "to/path"},
+ as_text=False,
+ exports_to_keep=5,
+ only_the_final_export=True)
+ estimator = test.mock.Mock(spec=estimator_lib.Estimator)
+ estimator.export_savedmodel.return_value = "export_result_path"
+
+ export_result = exporter.export(estimator, export_dir_base,
+ "checkpoint_path", {}, False)
+
+ self.assertFalse(estimator.export_savedmodel.called)
+ self.assertEqual(None, export_result)
+
+ export_result = exporter.export(estimator, export_dir_base,
+ "checkpoint_path", {}, True)
self.assertEqual("export_result_path", export_result)
estimator.export_savedmodel.assert_called_with(
@@ -93,7 +128,7 @@ class LatestExporterTest(test.TestCase):
estimator = test.mock.Mock(spec=estimator_lib.Estimator)
# Garbage collect all but the most recent 2 exports,
# where recency is determined based on the timestamp directory names.
- exporter.export(estimator, export_dir_base, None, None)
+ exporter.export(estimator, export_dir_base, None, None, False)
self.assertFalse(gfile.Exists(export_dir_1))
self.assertFalse(gfile.Exists(export_dir_2))
diff --git a/tensorflow/python/estimator/training.py b/tensorflow/python/estimator/training.py
index 1bed19760b..0a558a67b9 100644
--- a/tensorflow/python/estimator/training.py
+++ b/tensorflow/python/estimator/training.py
@@ -519,8 +519,11 @@ class _TrainingExecutor(object):
class NewCheckpointListener(
basic_session_run_hooks.CheckpointSaverListener):
- def __init__(self, estimator, eval_spec):
- self._evaluator = _TrainingExecutor._Evaluator(estimator, eval_spec) # pylint: disable=protected-access
+ def __init__(self, estimator, eval_spec, max_training_steps):
+ # pylint: disable=protected-access
+ self._evaluator = _TrainingExecutor._Evaluator(estimator, eval_spec,
+ max_training_steps)
+ # pylint: enable=protected-access
def after_save(self, session, global_step_value):
del session, global_step_value
@@ -528,8 +531,10 @@ class _TrainingExecutor(object):
# When the underlying `Estimator` object saves a new checkpoint, we would
# like this callback to be called so that evaluation and export can trigger.
- saving_listeners = [NewCheckpointListener(self._estimator, self._eval_spec)]
-
+ saving_listeners = [
+ NewCheckpointListener(self._estimator, self._eval_spec,
+ self._train_spec.max_steps)
+ ]
return self._start_distributed_training(saving_listeners=saving_listeners)
def run_evaluator(self):
@@ -566,7 +571,8 @@ class _TrainingExecutor(object):
'after {} secs (eval_spec.throttle_secs) or training is '
'finished.'.format(self._eval_spec.throttle_secs))
- evaluator = _TrainingExecutor._Evaluator(self._estimator, self._eval_spec)
+ evaluator = _TrainingExecutor._Evaluator(self._estimator, self._eval_spec,
+ self._train_spec.max_steps)
while True:
self._estimator.train(
@@ -636,7 +642,8 @@ class _TrainingExecutor(object):
time.sleep(start_delay_secs)
latest_eval_result = None
- evaluator = _TrainingExecutor._Evaluator(self._estimator, self._eval_spec)
+ evaluator = _TrainingExecutor._Evaluator(self._estimator, self._eval_spec,
+ self._train_spec.max_steps)
while True:
if latest_eval_result:
@@ -663,11 +670,12 @@ class _TrainingExecutor(object):
class _Evaluator(object):
"""A helper class to call `Estimator.evaluate` and export model."""
- def __init__(self, estimator, eval_spec):
+ def __init__(self, estimator, eval_spec, max_training_steps):
self._estimator = estimator
self._eval_spec = eval_spec
self._previous_ckpt_path = None
self._last_warning_time = 0
+ self._max_training_steps = max_training_steps
def evaluate_and_export(self):
"""Evaluate and (maybe) export the current model.
@@ -712,7 +720,14 @@ class _TrainingExecutor(object):
'Internal error: `Estimator.evaluate` result should have '
'`global_step` in result. Given {}'.format(eval_result))
- self._export_eval_result(eval_result, latest_ckpt_path)
+ # TODO(isaprykin): There is a potential race condition here in the
+ # distributed setting. The worker job that performs training
+ # might stop at a later global step value than the evaluator job.
+ is_the_final_export = (eval_result[ops.GraphKeys.GLOBAL_STEP] >=
+ self._max_training_steps
+ if self._max_training_steps else False)
+ self._export_eval_result(eval_result, latest_ckpt_path,
+ is_the_final_export)
self._last_warning_time = 0
self._previous_ckpt_path = latest_ckpt_path
@@ -725,7 +740,8 @@ class _TrainingExecutor(object):
logging.warning(message)
self._last_warning_time = current_time
- def _export_eval_result(self, eval_result, checkpoint_path):
+ def _export_eval_result(self, eval_result, checkpoint_path,
+ is_the_final_export):
"""Export `eval_result` according to exporters in `EvalSpec`."""
export_dir_base = os.path.join(
compat.as_str_any(self._estimator.model_dir),
@@ -738,4 +754,5 @@ class _TrainingExecutor(object):
compat.as_str_any(export_dir_base),
compat.as_str_any(exporter.name)),
checkpoint_path=checkpoint_path,
- eval_result=eval_result)
+ eval_result=eval_result,
+ is_the_final_export=is_the_final_export)
diff --git a/tensorflow/python/estimator/training_test.py b/tensorflow/python/estimator/training_test.py
index e4c400ca7f..08d11d7d25 100644
--- a/tensorflow/python/estimator/training_test.py
+++ b/tensorflow/python/estimator/training_test.py
@@ -802,6 +802,46 @@ class TrainingExecutorRunEvaluatorTest(test.TestCase):
self.assertEqual(2, mock_est.evaluate.call_count)
self.assertEqual(2, exporter.export.call_count)
+ def test_final_export_is_true_in_the_end(self):
+ training_max_step = 200
+
+ mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
+ mock_est.model_dir = compat.as_bytes(test.get_temp_dir())
+ mock_est.evaluate.side_effect = [
+ {_GLOBAL_STEP_KEY: training_max_step // 2},
+ {_GLOBAL_STEP_KEY: training_max_step}
+ ]
+ mock_est.latest_checkpoint.side_effect = ['path_1', 'path_2']
+
+ mock_train_spec = test.mock.Mock(spec=training.TrainSpec)
+ mock_train_spec.max_steps = training_max_step
+
+ mock_est.times_export_fn_was_called = 0
+ mock_est.times_the_final_export_was_true = 0
+ def export(estimator, export_path, checkpoint_path, eval_result,
+ is_the_final_export):
+ del export_path, checkpoint_path, eval_result
+ estimator.times_export_fn_was_called += 1
+ if is_the_final_export:
+ estimator.times_the_final_export_was_true += 1
+
+ exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter)
+ exporter.name = 'see_how_many_times_export_is_called'
+ exporter.export = export
+
+ eval_spec = training.EvalSpec(
+ input_fn=lambda: 1,
+ start_delay_secs=0,
+ throttle_secs=0,
+ exporters=exporter)
+
+ executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec)
+ executor.run_evaluator()
+
+ self.assertEqual(2, mock_est.evaluate.call_count)
+ self.assertEqual(2, mock_est.times_export_fn_was_called)
+ self.assertEqual(1, mock_est.times_the_final_export_was_true)
+
def test_skip_evaluation_due_to_ckpt(self):
training_max_step = 200
mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
@@ -1134,6 +1174,47 @@ class TrainingExecutorRunLocalTest(test.TestCase):
with self.assertRaisesRegexp(RuntimeError, _STALE_CHECKPOINT_MSG):
executor.run_local()
+ def test_final_export_is_true_in_the_end(self):
+ mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/')
+ mock_est.latest_checkpoint = self.unique_checkpoint_every_time_fn
+
+ mock_est.times_export_fn_was_called = 0
+ mock_est.times_the_final_export_was_true = 0
+ def export(estimator, export_path, checkpoint_path, eval_result,
+ is_the_final_export):
+ del export_path, checkpoint_path, eval_result
+ estimator.times_export_fn_was_called += 1
+ if is_the_final_export:
+ estimator.times_the_final_export_was_true += 1
+
+ exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter)
+ exporter.name = 'see_how_many_times_export_is_called'
+ exporter.export = export
+
+ train_spec = training.TrainSpec(
+ input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()])
+ eval_spec = training.EvalSpec(
+ input_fn=lambda: 1,
+ hooks=[_FakeHook()],
+ throttle_secs=100,
+ exporters=exporter)
+ # should be called 3 times.
+ mock_est.evaluate.side_effect = [{
+ _GLOBAL_STEP_KEY: train_spec.max_steps - 100
+ }, {
+ _GLOBAL_STEP_KEY: train_spec.max_steps - 50
+ }, {
+ _GLOBAL_STEP_KEY: train_spec.max_steps
+ }]
+
+ executor = training._TrainingExecutor(mock_est, train_spec, eval_spec)
+ executor.run_local()
+
+ self.assertEqual(3, mock_est.train.call_count)
+ self.assertEqual(3, mock_est.evaluate.call_count)
+ self.assertEqual(3, mock_est.times_export_fn_was_called)
+ self.assertEqual(1, mock_est.times_the_final_export_was_true)
+
def test_train_and_evaluate_args(self):
mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/')
mock_est.latest_checkpoint.return_value = 'checkpoint_path/'