diff options
author | 2017-10-04 15:13:33 -0700 | |
---|---|---|
committer | 2017-10-04 15:17:39 -0700 | |
commit | 89df2e336218f7f3ecf2c70f8478c64985345ded (patch) | |
tree | 2756c54c384716fe68639da5abe3e36472673a65 /tensorflow/python/estimator/exporter.py | |
parent | 4486b4f69b55633274f7903158d680bf2e9eabff (diff) |
Add the 'is_the_final_export' signal to Exporters. Use them in training.
When the training ends, the final export is performed via `Exporter.export()` call. That final export is going to have is_the_final_export parameter being set to true.
If `TrainSpec.max_steps` is `None`, then "when training ends" is undefined. We are going to train forever. In that case, `is_the_final_export` is going to be always False. I added a note about it.
PiperOrigin-RevId: 171070760
Diffstat (limited to 'tensorflow/python/estimator/exporter.py')
-rw-r--r-- | tensorflow/python/estimator/exporter.py | 26 |
1 files changed, 23 insertions, 3 deletions
diff --git a/tensorflow/python/estimator/exporter.py b/tensorflow/python/estimator/exporter.py index 505820dd93..2faca11f6e 100644 --- a/tensorflow/python/estimator/exporter.py +++ b/tensorflow/python/estimator/exporter.py @@ -40,7 +40,8 @@ class Exporter(object): pass @abc.abstractmethod - def export(self, estimator, export_path, checkpoint_path, eval_result): + def export(self, estimator, export_path, checkpoint_path, eval_result, + is_the_final_export): """Exports the given `Estimator` to a specific format. Args: @@ -48,6 +49,13 @@ class Exporter(object): export_path: A string containing a directory where to write the export. checkpoint_path: The checkpoint path to export. eval_result: The output of `Estimator.evaluate` on this checkpoint. + is_the_final_export: This boolean is True when this is an export in the + end of training. It is False for the intermediate exports during + the training. + + When passing `Exporter` to `tf.estimator.train_and_evaluate` + `is_the_final_export` is always False if `TrainSpec.max_steps` is + `None`. Returns: The string path to the exported directory or `None` if export is skipped. @@ -66,7 +74,8 @@ class LatestExporter(Exporter): serving_input_fn, assets_extra=None, as_text=False, - exports_to_keep=5): + exports_to_keep=5, + only_the_final_export=False): """Create an `Exporter` to use with `tf.estimator.EvalSpec`. Args: @@ -86,6 +95,8 @@ class LatestExporter(Exporter): exports_to_keep: Number of exports to keep. Older exports will be garbage-collected. Defaults to 5. Set to `None` to disable garbage collection. + only_the_final_export: Only the final export in the end of training will + happen if this is set to True. Raises: ValueError: if any arguments is invalid. @@ -95,6 +106,8 @@ class LatestExporter(Exporter): self._assets_extra = assets_extra self._as_text = as_text self._exports_to_keep = exports_to_keep + self._only_the_final_export = only_the_final_export + if exports_to_keep is not None and exports_to_keep <= 0: raise ValueError( '`exports_to_keep`, if provided, must be positive number') @@ -103,7 +116,14 @@ class LatestExporter(Exporter): def name(self): return self._name - def export(self, estimator, export_path, checkpoint_path, eval_result): + def export(self, estimator, export_path, checkpoint_path, eval_result, + is_the_final_export): + if not is_the_final_export and self._only_the_final_export: + return None + + if is_the_final_export: + tf_logging.info('Performing the final export in the end of training.') + export_result = estimator.export_savedmodel( export_path, self._serving_input_fn, |