aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator/exporter.py
diff options
context:
space:
mode:
authorGravatar Igor Saprykin <isaprykin@google.com>2017-10-04 15:13:33 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-10-04 15:17:39 -0700
commit89df2e336218f7f3ecf2c70f8478c64985345ded (patch)
tree2756c54c384716fe68639da5abe3e36472673a65 /tensorflow/python/estimator/exporter.py
parent4486b4f69b55633274f7903158d680bf2e9eabff (diff)
Add the 'is_the_final_export' signal to Exporters. Use them in training.
When the training ends, the final export is performed via `Exporter.export()` call. That final export is going to have is_the_final_export parameter being set to true. If `TrainSpec.max_steps` is `None`, then "when training ends" is undefined. We are going to train forever. In that case, `is_the_final_export` is going to be always False. I added a note about it. PiperOrigin-RevId: 171070760
Diffstat (limited to 'tensorflow/python/estimator/exporter.py')
-rw-r--r--tensorflow/python/estimator/exporter.py26
1 files changed, 23 insertions, 3 deletions
diff --git a/tensorflow/python/estimator/exporter.py b/tensorflow/python/estimator/exporter.py
index 505820dd93..2faca11f6e 100644
--- a/tensorflow/python/estimator/exporter.py
+++ b/tensorflow/python/estimator/exporter.py
@@ -40,7 +40,8 @@ class Exporter(object):
pass
@abc.abstractmethod
- def export(self, estimator, export_path, checkpoint_path, eval_result):
+ def export(self, estimator, export_path, checkpoint_path, eval_result,
+ is_the_final_export):
"""Exports the given `Estimator` to a specific format.
Args:
@@ -48,6 +49,13 @@ class Exporter(object):
export_path: A string containing a directory where to write the export.
checkpoint_path: The checkpoint path to export.
eval_result: The output of `Estimator.evaluate` on this checkpoint.
+ is_the_final_export: This boolean is True when this is an export in the
+ end of training. It is False for the intermediate exports during
+ the training.
+
+ When passing `Exporter` to `tf.estimator.train_and_evaluate`
+ `is_the_final_export` is always False if `TrainSpec.max_steps` is
+ `None`.
Returns:
The string path to the exported directory or `None` if export is skipped.
@@ -66,7 +74,8 @@ class LatestExporter(Exporter):
serving_input_fn,
assets_extra=None,
as_text=False,
- exports_to_keep=5):
+ exports_to_keep=5,
+ only_the_final_export=False):
"""Create an `Exporter` to use with `tf.estimator.EvalSpec`.
Args:
@@ -86,6 +95,8 @@ class LatestExporter(Exporter):
exports_to_keep: Number of exports to keep. Older exports will be
garbage-collected. Defaults to 5. Set to `None` to disable garbage
collection.
+ only_the_final_export: Only the final export in the end of training will
+ happen if this is set to True.
Raises:
ValueError: if any arguments is invalid.
@@ -95,6 +106,8 @@ class LatestExporter(Exporter):
self._assets_extra = assets_extra
self._as_text = as_text
self._exports_to_keep = exports_to_keep
+ self._only_the_final_export = only_the_final_export
+
if exports_to_keep is not None and exports_to_keep <= 0:
raise ValueError(
'`exports_to_keep`, if provided, must be positive number')
@@ -103,7 +116,14 @@ class LatestExporter(Exporter):
def name(self):
return self._name
- def export(self, estimator, export_path, checkpoint_path, eval_result):
+ def export(self, estimator, export_path, checkpoint_path, eval_result,
+ is_the_final_export):
+ if not is_the_final_export and self._only_the_final_export:
+ return None
+
+ if is_the_final_export:
+ tf_logging.info('Performing the final export in the end of training.')
+
export_result = estimator.export_savedmodel(
export_path,
self._serving_input_fn,