diff options
Diffstat (limited to 'tensorflow/contrib/predictor/contrib_estimator_predictor.py')
-rw-r--r-- | tensorflow/contrib/predictor/contrib_estimator_predictor.py | 5 |
1 files changed, 4 insertions, 1 deletions
diff --git a/tensorflow/contrib/predictor/contrib_estimator_predictor.py b/tensorflow/contrib/predictor/contrib_estimator_predictor.py index b7a98c68e2..af3b2ad1b5 100644 --- a/tensorflow/contrib/predictor/contrib_estimator_predictor.py +++ b/tensorflow/contrib/predictor/contrib_estimator_predictor.py @@ -34,7 +34,8 @@ class ContribEstimatorPredictor(predictor.Predictor): prediction_input_fn, input_alternative_key=None, output_alternative_key=None, - graph=None): + graph=None, + config=None): """Initialize a `ContribEstimatorPredictor`. Args: @@ -48,6 +49,7 @@ class ContribEstimatorPredictor(predictor.Predictor): multi-headed models. graph: Optional. The Tensorflow `graph` in which prediction should be done. + config: `ConfigProto` proto used to configure the session. """ self._graph = graph or ops.Graph() with self._graph.as_default(): @@ -58,6 +60,7 @@ class ContribEstimatorPredictor(predictor.Predictor): checkpoint_path = saver.latest_checkpoint(estimator.model_dir) self._session = monitored_session.MonitoredSession( session_creator=monitored_session.ChiefSessionCreator( + config=config, checkpoint_filename_with_path=checkpoint_path)) input_alternative_key = ( |