aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/predictor/core_estimator_predictor.py
diff options
context:
space:
mode:
authorGravatar Lukas Geiger <lgeiger@users.noreply.github.com>2018-06-04 07:00:10 +0200
committerGravatar Gunhan Gulsoy <gunan@google.com>2018-06-03 22:00:10 -0700
commit44c191906d1e4041b490512facc028a23585717b (patch)
tree9ccb1def33c0a5d463452de1bf57dbc7d4853ea0 /tensorflow/contrib/predictor/core_estimator_predictor.py
parent96788111224e05de619ac2049fb696ae39f1c257 (diff)
Support session config in tf.contrib.predictor (#19542)
* Support session config in tf.contrib.predictor This PR allows users to supply a custom session config uses by the predictor. This can be essential for some GPU setups in order to play nicely with other processes running on the same GPU. * Test passing session config to tf.contrib.predictor
Diffstat (limited to 'tensorflow/contrib/predictor/core_estimator_predictor.py')
-rw-r--r--tensorflow/contrib/predictor/core_estimator_predictor.py5
1 files changed, 4 insertions, 1 deletions
diff --git a/tensorflow/contrib/predictor/core_estimator_predictor.py b/tensorflow/contrib/predictor/core_estimator_predictor.py
index d78d94c269..a725072e72 100644
--- a/tensorflow/contrib/predictor/core_estimator_predictor.py
+++ b/tensorflow/contrib/predictor/core_estimator_predictor.py
@@ -51,7 +51,8 @@ class CoreEstimatorPredictor(predictor.Predictor):
estimator,
serving_input_receiver_fn,
output_key=None,
- graph=None):
+ graph=None,
+ config=None):
"""Initialize a `CoreEstimatorPredictor`.
Args:
@@ -62,6 +63,7 @@ class CoreEstimatorPredictor(predictor.Predictor):
`None`, then `DEFAULT_SERVING_SIGNATURE_DEF_KEY` is used.
graph: Optional. The Tensorflow `graph` in which prediction should be
done.
+ config: `ConfigProto` proto used to configure the session.
"""
self._graph = graph or ops.Graph()
with self._graph.as_default():
@@ -71,6 +73,7 @@ class CoreEstimatorPredictor(predictor.Predictor):
checkpoint_dir = estimator.model_dir
self._session = monitored_session.MonitoredSession(
session_creator=monitored_session.ChiefSessionCreator(
+ config=config,
checkpoint_dir=checkpoint_dir))
feed_tensor_info = signature_def.inputs