diff options
author | 2018-06-04 07:00:10 +0200 | |
---|---|---|
committer | 2018-06-03 22:00:10 -0700 | |
commit | 44c191906d1e4041b490512facc028a23585717b (patch) | |
tree | 9ccb1def33c0a5d463452de1bf57dbc7d4853ea0 /tensorflow/contrib/predictor/core_estimator_predictor.py | |
parent | 96788111224e05de619ac2049fb696ae39f1c257 (diff) |
Support session config in tf.contrib.predictor (#19542)
* Support session config in tf.contrib.predictor
This PR allows users to supply a custom session config uses by the predictor.
This can be essential for some GPU setups in order to play nicely with other processes running on the same GPU.
* Test passing session config to tf.contrib.predictor
Diffstat (limited to 'tensorflow/contrib/predictor/core_estimator_predictor.py')
-rw-r--r-- | tensorflow/contrib/predictor/core_estimator_predictor.py | 5 |
1 files changed, 4 insertions, 1 deletions
diff --git a/tensorflow/contrib/predictor/core_estimator_predictor.py b/tensorflow/contrib/predictor/core_estimator_predictor.py index d78d94c269..a725072e72 100644 --- a/tensorflow/contrib/predictor/core_estimator_predictor.py +++ b/tensorflow/contrib/predictor/core_estimator_predictor.py @@ -51,7 +51,8 @@ class CoreEstimatorPredictor(predictor.Predictor): estimator, serving_input_receiver_fn, output_key=None, - graph=None): + graph=None, + config=None): """Initialize a `CoreEstimatorPredictor`. Args: @@ -62,6 +63,7 @@ class CoreEstimatorPredictor(predictor.Predictor): `None`, then `DEFAULT_SERVING_SIGNATURE_DEF_KEY` is used. graph: Optional. The Tensorflow `graph` in which prediction should be done. + config: `ConfigProto` proto used to configure the session. """ self._graph = graph or ops.Graph() with self._graph.as_default(): @@ -71,6 +73,7 @@ class CoreEstimatorPredictor(predictor.Predictor): checkpoint_dir = estimator.model_dir self._session = monitored_session.MonitoredSession( session_creator=monitored_session.ChiefSessionCreator( + config=config, checkpoint_dir=checkpoint_dir)) feed_tensor_info = signature_def.inputs |