diff options
author | 2017-10-04 15:13:33 -0700 | |
---|---|---|
committer | 2017-10-04 15:17:39 -0700 | |
commit | 89df2e336218f7f3ecf2c70f8478c64985345ded (patch) | |
tree | 2756c54c384716fe68639da5abe3e36472673a65 /tensorflow | |
parent | 4486b4f69b55633274f7903158d680bf2e9eabff (diff) |
Add the 'is_the_final_export' signal to Exporters. Use them in training.
When the training ends, the final export is performed via `Exporter.export()` call. That final export is going to have is_the_final_export parameter being set to true.
If `TrainSpec.max_steps` is `None`, then "when training ends" is undefined. We are going to train forever. In that case, `is_the_final_export` is going to be always False. I added a note about it.
PiperOrigin-RevId: 171070760
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/python/estimator/exporter.py | 26 | ||||
-rw-r--r-- | tensorflow/python/estimator/exporter_test.py | 41 | ||||
-rw-r--r-- | tensorflow/python/estimator/training.py | 37 | ||||
-rw-r--r-- | tensorflow/python/estimator/training_test.py | 81 |
4 files changed, 169 insertions, 16 deletions
diff --git a/tensorflow/python/estimator/exporter.py b/tensorflow/python/estimator/exporter.py index 505820dd93..2faca11f6e 100644 --- a/tensorflow/python/estimator/exporter.py +++ b/tensorflow/python/estimator/exporter.py @@ -40,7 +40,8 @@ class Exporter(object): pass @abc.abstractmethod - def export(self, estimator, export_path, checkpoint_path, eval_result): + def export(self, estimator, export_path, checkpoint_path, eval_result, + is_the_final_export): """Exports the given `Estimator` to a specific format. Args: @@ -48,6 +49,13 @@ class Exporter(object): export_path: A string containing a directory where to write the export. checkpoint_path: The checkpoint path to export. eval_result: The output of `Estimator.evaluate` on this checkpoint. + is_the_final_export: This boolean is True when this is an export in the + end of training. It is False for the intermediate exports during + the training. + + When passing `Exporter` to `tf.estimator.train_and_evaluate` + `is_the_final_export` is always False if `TrainSpec.max_steps` is + `None`. Returns: The string path to the exported directory or `None` if export is skipped. @@ -66,7 +74,8 @@ class LatestExporter(Exporter): serving_input_fn, assets_extra=None, as_text=False, - exports_to_keep=5): + exports_to_keep=5, + only_the_final_export=False): """Create an `Exporter` to use with `tf.estimator.EvalSpec`. Args: @@ -86,6 +95,8 @@ class LatestExporter(Exporter): exports_to_keep: Number of exports to keep. Older exports will be garbage-collected. Defaults to 5. Set to `None` to disable garbage collection. + only_the_final_export: Only the final export in the end of training will + happen if this is set to True. Raises: ValueError: if any arguments is invalid. @@ -95,6 +106,8 @@ class LatestExporter(Exporter): self._assets_extra = assets_extra self._as_text = as_text self._exports_to_keep = exports_to_keep + self._only_the_final_export = only_the_final_export + if exports_to_keep is not None and exports_to_keep <= 0: raise ValueError( '`exports_to_keep`, if provided, must be positive number') @@ -103,7 +116,14 @@ class LatestExporter(Exporter): def name(self): return self._name - def export(self, estimator, export_path, checkpoint_path, eval_result): + def export(self, estimator, export_path, checkpoint_path, eval_result, + is_the_final_export): + if not is_the_final_export and self._only_the_final_export: + return None + + if is_the_final_export: + tf_logging.info('Performing the final export in the end of training.') + export_result = estimator.export_savedmodel( export_path, self._serving_input_fn, diff --git a/tensorflow/python/estimator/exporter_test.py b/tensorflow/python/estimator/exporter_test.py index 2ceff1bfd6..01582ac595 100644 --- a/tensorflow/python/estimator/exporter_test.py +++ b/tensorflow/python/estimator/exporter_test.py @@ -42,7 +42,7 @@ class LatestExporterTest(test.TestCase): serving_input_fn=_serving_input_fn, exports_to_keep=0) - def test_saved_model_exporter(self): + def test_latest_exporter(self): def _serving_input_fn(): pass @@ -60,7 +60,42 @@ class LatestExporterTest(test.TestCase): estimator.export_savedmodel.return_value = "export_result_path" export_result = exporter.export(estimator, export_dir_base, - "checkpoint_path", {}) + "checkpoint_path", {}, False) + + self.assertEqual("export_result_path", export_result) + estimator.export_savedmodel.assert_called_with( + export_dir_base, + _serving_input_fn, + assets_extra={"from/path": "to/path"}, + as_text=False, + checkpoint_path="checkpoint_path") + + def test_only_the_last_export_is_saved(self): + + def _serving_input_fn(): + pass + + export_dir_base = tempfile.mkdtemp() + "export/" + gfile.MkDir(export_dir_base) + + exporter = exporter_lib.LatestExporter( + name="latest_exporter", + serving_input_fn=_serving_input_fn, + assets_extra={"from/path": "to/path"}, + as_text=False, + exports_to_keep=5, + only_the_final_export=True) + estimator = test.mock.Mock(spec=estimator_lib.Estimator) + estimator.export_savedmodel.return_value = "export_result_path" + + export_result = exporter.export(estimator, export_dir_base, + "checkpoint_path", {}, False) + + self.assertFalse(estimator.export_savedmodel.called) + self.assertEqual(None, export_result) + + export_result = exporter.export(estimator, export_dir_base, + "checkpoint_path", {}, True) self.assertEqual("export_result_path", export_result) estimator.export_savedmodel.assert_called_with( @@ -93,7 +128,7 @@ class LatestExporterTest(test.TestCase): estimator = test.mock.Mock(spec=estimator_lib.Estimator) # Garbage collect all but the most recent 2 exports, # where recency is determined based on the timestamp directory names. - exporter.export(estimator, export_dir_base, None, None) + exporter.export(estimator, export_dir_base, None, None, False) self.assertFalse(gfile.Exists(export_dir_1)) self.assertFalse(gfile.Exists(export_dir_2)) diff --git a/tensorflow/python/estimator/training.py b/tensorflow/python/estimator/training.py index 1bed19760b..0a558a67b9 100644 --- a/tensorflow/python/estimator/training.py +++ b/tensorflow/python/estimator/training.py @@ -519,8 +519,11 @@ class _TrainingExecutor(object): class NewCheckpointListener( basic_session_run_hooks.CheckpointSaverListener): - def __init__(self, estimator, eval_spec): - self._evaluator = _TrainingExecutor._Evaluator(estimator, eval_spec) # pylint: disable=protected-access + def __init__(self, estimator, eval_spec, max_training_steps): + # pylint: disable=protected-access + self._evaluator = _TrainingExecutor._Evaluator(estimator, eval_spec, + max_training_steps) + # pylint: enable=protected-access def after_save(self, session, global_step_value): del session, global_step_value @@ -528,8 +531,10 @@ class _TrainingExecutor(object): # When the underlying `Estimator` object saves a new checkpoint, we would # like this callback to be called so that evaluation and export can trigger. - saving_listeners = [NewCheckpointListener(self._estimator, self._eval_spec)] - + saving_listeners = [ + NewCheckpointListener(self._estimator, self._eval_spec, + self._train_spec.max_steps) + ] return self._start_distributed_training(saving_listeners=saving_listeners) def run_evaluator(self): @@ -566,7 +571,8 @@ class _TrainingExecutor(object): 'after {} secs (eval_spec.throttle_secs) or training is ' 'finished.'.format(self._eval_spec.throttle_secs)) - evaluator = _TrainingExecutor._Evaluator(self._estimator, self._eval_spec) + evaluator = _TrainingExecutor._Evaluator(self._estimator, self._eval_spec, + self._train_spec.max_steps) while True: self._estimator.train( @@ -636,7 +642,8 @@ class _TrainingExecutor(object): time.sleep(start_delay_secs) latest_eval_result = None - evaluator = _TrainingExecutor._Evaluator(self._estimator, self._eval_spec) + evaluator = _TrainingExecutor._Evaluator(self._estimator, self._eval_spec, + self._train_spec.max_steps) while True: if latest_eval_result: @@ -663,11 +670,12 @@ class _TrainingExecutor(object): class _Evaluator(object): """A helper class to call `Estimator.evaluate` and export model.""" - def __init__(self, estimator, eval_spec): + def __init__(self, estimator, eval_spec, max_training_steps): self._estimator = estimator self._eval_spec = eval_spec self._previous_ckpt_path = None self._last_warning_time = 0 + self._max_training_steps = max_training_steps def evaluate_and_export(self): """Evaluate and (maybe) export the current model. @@ -712,7 +720,14 @@ class _TrainingExecutor(object): 'Internal error: `Estimator.evaluate` result should have ' '`global_step` in result. Given {}'.format(eval_result)) - self._export_eval_result(eval_result, latest_ckpt_path) + # TODO(isaprykin): There is a potential race condition here in the + # distributed setting. The worker job that performs training + # might stop at a later global step value than the evalutor job. + is_the_final_export = (eval_result[ops.GraphKeys.GLOBAL_STEP] >= + self._max_training_steps + if self._max_training_steps else False) + self._export_eval_result(eval_result, latest_ckpt_path, + is_the_final_export) self._last_warning_time = 0 self._previous_ckpt_path = latest_ckpt_path @@ -725,7 +740,8 @@ class _TrainingExecutor(object): logging.warning(message) self._last_warning_time = current_time - def _export_eval_result(self, eval_result, checkpoint_path): + def _export_eval_result(self, eval_result, checkpoint_path, + is_the_final_export): """Export `eval_result` according to exporters in `EvalSpec`.""" export_dir_base = os.path.join( compat.as_str_any(self._estimator.model_dir), @@ -738,4 +754,5 @@ class _TrainingExecutor(object): compat.as_str_any(export_dir_base), compat.as_str_any(exporter.name)), checkpoint_path=checkpoint_path, - eval_result=eval_result) + eval_result=eval_result, + is_the_final_export=is_the_final_export) diff --git a/tensorflow/python/estimator/training_test.py b/tensorflow/python/estimator/training_test.py index e4c400ca7f..08d11d7d25 100644 --- a/tensorflow/python/estimator/training_test.py +++ b/tensorflow/python/estimator/training_test.py @@ -802,6 +802,46 @@ class TrainingExecutorRunEvaluatorTest(test.TestCase): self.assertEqual(2, mock_est.evaluate.call_count) self.assertEqual(2, exporter.export.call_count) + def test_final_export_is_true_in_the_end(self): + training_max_step = 200 + + mock_est = test.mock.Mock(spec=estimator_lib.Estimator) + mock_est.model_dir = compat.as_bytes(test.get_temp_dir()) + mock_est.evaluate.side_effect = [ + {_GLOBAL_STEP_KEY: training_max_step // 2}, + {_GLOBAL_STEP_KEY: training_max_step} + ] + mock_est.latest_checkpoint.side_effect = ['path_1', 'path_2'] + + mock_train_spec = test.mock.Mock(spec=training.TrainSpec) + mock_train_spec.max_steps = training_max_step + + mock_est.times_export_fn_was_called = 0 + mock_est.times_the_final_export_was_true = 0 + def export(estimator, export_path, checkpoint_path, eval_result, + is_the_final_export): + del export_path, checkpoint_path, eval_result + estimator.times_export_fn_was_called += 1 + if is_the_final_export: + estimator.times_the_final_export_was_true += 1 + + exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter) + exporter.name = 'see_how_many_times_export_is_called' + exporter.export = export + + eval_spec = training.EvalSpec( + input_fn=lambda: 1, + start_delay_secs=0, + throttle_secs=0, + exporters=exporter) + + executor = training._TrainingExecutor(mock_est, mock_train_spec, eval_spec) + executor.run_evaluator() + + self.assertEqual(2, mock_est.evaluate.call_count) + self.assertEqual(2, mock_est.times_export_fn_was_called) + self.assertEqual(1, mock_est.times_the_final_export_was_true) + def test_skip_evaluation_due_to_ckpt(self): training_max_step = 200 mock_est = test.mock.Mock(spec=estimator_lib.Estimator) @@ -1134,6 +1174,47 @@ class TrainingExecutorRunLocalTest(test.TestCase): with self.assertRaisesRegexp(RuntimeError, _STALE_CHECKPOINT_MSG): executor.run_local() + def test_final_export_is_true_in_the_end(self): + mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/') + mock_est.latest_checkpoint = self.unique_checkpoint_every_time_fn + + mock_est.times_export_fn_was_called = 0 + mock_est.times_the_final_export_was_true = 0 + def export(estimator, export_path, checkpoint_path, eval_result, + is_the_final_export): + del export_path, checkpoint_path, eval_result + estimator.times_export_fn_was_called += 1 + if is_the_final_export: + estimator.times_the_final_export_was_true += 1 + + exporter = test.mock.PropertyMock(spec=exporter_lib.Exporter) + exporter.name = 'see_how_many_times_export_is_called' + exporter.export = export + + train_spec = training.TrainSpec( + input_fn=lambda: 1, max_steps=300, hooks=[_FakeHook()]) + eval_spec = training.EvalSpec( + input_fn=lambda: 1, + hooks=[_FakeHook()], + throttle_secs=100, + exporters=exporter) + # should be called 3 times. + mock_est.evaluate.side_effect = [{ + _GLOBAL_STEP_KEY: train_spec.max_steps - 100 + }, { + _GLOBAL_STEP_KEY: train_spec.max_steps - 50 + }, { + _GLOBAL_STEP_KEY: train_spec.max_steps + }] + + executor = training._TrainingExecutor(mock_est, train_spec, eval_spec) + executor.run_local() + + self.assertEqual(3, mock_est.train.call_count) + self.assertEqual(3, mock_est.evaluate.call_count) + self.assertEqual(3, mock_est.times_export_fn_was_called) + self.assertEqual(1, mock_est.times_the_final_export_was_true) + def test_train_and_evaluate_args(self): mock_est = test.mock.Mock(spec=estimator_lib.Estimator, model_dir='path/') mock_est.latest_checkpoint.return_value = 'checkpoint_path/' |