diff options
Diffstat (limited to 'tensorflow/python/estimator/estimator.py')
-rw-r--r-- | tensorflow/python/estimator/estimator.py | 11 |
1 files changed, 9 insertions, 2 deletions
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index 9891ae4eaf..36918af552 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -266,7 +266,11 @@ class Estimator(object): checkpoint_path=checkpoint_path, name=name) - def predict(self, input_fn, predict_keys=None, hooks=None): + def predict(self, + input_fn, + predict_keys=None, + hooks=None, + checkpoint_path=None): """Returns predictions for given features. Args: @@ -281,6 +285,8 @@ class Estimator(object): `None`, returns all. hooks: List of `SessionRunHook` subclass instances. Used for callbacks inside the prediction call. + checkpoint_path: Path of a specific checkpoint to predict. If `None`, the + latest checkpoint in `model_dir` is used. Yields: Evaluated values of `predictions` tensors. @@ -294,7 +300,8 @@ class Estimator(object): """ hooks = _check_hooks_type(hooks) # Check that model has been trained. - checkpoint_path = saver.latest_checkpoint(self._model_dir) + if not checkpoint_path: + checkpoint_path = saver.latest_checkpoint(self._model_dir) if not checkpoint_path: raise ValueError('Could not find trained model in model_dir: {}.'.format( self._model_dir)) |