aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/estimator/estimator.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/estimator/estimator.py')
-rw-r--r--tensorflow/python/estimator/estimator.py11
1 files changed, 9 insertions, 2 deletions
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index 9891ae4eaf..36918af552 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -266,7 +266,11 @@ class Estimator(object):
checkpoint_path=checkpoint_path,
name=name)
- def predict(self, input_fn, predict_keys=None, hooks=None):
+ def predict(self,
+ input_fn,
+ predict_keys=None,
+ hooks=None,
+ checkpoint_path=None):
"""Returns predictions for given features.
Args:
@@ -281,6 +285,8 @@ class Estimator(object):
`None`, returns all.
hooks: List of `SessionRunHook` subclass instances. Used for callbacks
inside the prediction call.
+ checkpoint_path: Path of a specific checkpoint to predict. If `None`, the
+ latest checkpoint in `model_dir` is used.
Yields:
Evaluated values of `predictions` tensors.
@@ -294,7 +300,8 @@ class Estimator(object):
"""
hooks = _check_hooks_type(hooks)
# Check that model has been trained.
- checkpoint_path = saver.latest_checkpoint(self._model_dir)
+ if not checkpoint_path:
+ checkpoint_path = saver.latest_checkpoint(self._model_dir)
if not checkpoint_path:
raise ValueError('Could not find trained model in model_dir: {}.'.format(
self._model_dir))