aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/predictor/contrib_estimator_predictor.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/predictor/contrib_estimator_predictor.py')
-rw-r--r--tensorflow/contrib/predictor/contrib_estimator_predictor.py5
1 files changed, 4 insertions, 1 deletions
diff --git a/tensorflow/contrib/predictor/contrib_estimator_predictor.py b/tensorflow/contrib/predictor/contrib_estimator_predictor.py
index b7a98c68e2..af3b2ad1b5 100644
--- a/tensorflow/contrib/predictor/contrib_estimator_predictor.py
+++ b/tensorflow/contrib/predictor/contrib_estimator_predictor.py
@@ -34,7 +34,8 @@ class ContribEstimatorPredictor(predictor.Predictor):
prediction_input_fn,
input_alternative_key=None,
output_alternative_key=None,
- graph=None):
+ graph=None,
+ config=None):
"""Initialize a `ContribEstimatorPredictor`.
Args:
@@ -48,6 +49,7 @@ class ContribEstimatorPredictor(predictor.Predictor):
multi-headed models.
graph: Optional. The Tensorflow `graph` in which prediction should be
done.
+ config: `ConfigProto` proto used to configure the session.
"""
self._graph = graph or ops.Graph()
with self._graph.as_default():
@@ -58,6 +60,7 @@ class ContribEstimatorPredictor(predictor.Predictor):
checkpoint_path = saver.latest_checkpoint(estimator.model_dir)
self._session = monitored_session.MonitoredSession(
session_creator=monitored_session.ChiefSessionCreator(
+ config=config,
checkpoint_filename_with_path=checkpoint_path))
input_alternative_key = (