aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/predictor
diff options
context:
space:
mode:
authorGravatar Lukas Geiger <lgeiger@users.noreply.github.com>2018-06-04 07:00:10 +0200
committerGravatar Gunhan Gulsoy <gunan@google.com>2018-06-03 22:00:10 -0700
commit44c191906d1e4041b490512facc028a23585717b (patch)
tree9ccb1def33c0a5d463452de1bf57dbc7d4853ea0 /tensorflow/contrib/predictor
parent96788111224e05de619ac2049fb696ae39f1c257 (diff)
Support session config in tf.contrib.predictor (#19542)
* Support session config in tf.contrib.predictor This PR allows users to supply a custom session config uses by the predictor. This can be essential for some GPU setups in order to play nicely with other processes running on the same GPU. * Test passing session config to tf.contrib.predictor
Diffstat (limited to 'tensorflow/contrib/predictor')
-rw-r--r--tensorflow/contrib/predictor/contrib_estimator_predictor.py5
-rw-r--r--tensorflow/contrib/predictor/core_estimator_predictor.py5
-rw-r--r--tensorflow/contrib/predictor/predictor_factories.py24
-rw-r--r--tensorflow/contrib/predictor/predictor_factories_test.py19
-rw-r--r--tensorflow/contrib/predictor/saved_model_predictor.py6
5 files changed, 49 insertions, 10 deletions
diff --git a/tensorflow/contrib/predictor/contrib_estimator_predictor.py b/tensorflow/contrib/predictor/contrib_estimator_predictor.py
index b7a98c68e2..af3b2ad1b5 100644
--- a/tensorflow/contrib/predictor/contrib_estimator_predictor.py
+++ b/tensorflow/contrib/predictor/contrib_estimator_predictor.py
@@ -34,7 +34,8 @@ class ContribEstimatorPredictor(predictor.Predictor):
prediction_input_fn,
input_alternative_key=None,
output_alternative_key=None,
- graph=None):
+ graph=None,
+ config=None):
"""Initialize a `ContribEstimatorPredictor`.
Args:
@@ -48,6 +49,7 @@ class ContribEstimatorPredictor(predictor.Predictor):
multi-headed models.
graph: Optional. The Tensorflow `graph` in which prediction should be
done.
+ config: `ConfigProto` proto used to configure the session.
"""
self._graph = graph or ops.Graph()
with self._graph.as_default():
@@ -58,6 +60,7 @@ class ContribEstimatorPredictor(predictor.Predictor):
checkpoint_path = saver.latest_checkpoint(estimator.model_dir)
self._session = monitored_session.MonitoredSession(
session_creator=monitored_session.ChiefSessionCreator(
+ config=config,
checkpoint_filename_with_path=checkpoint_path))
input_alternative_key = (
diff --git a/tensorflow/contrib/predictor/core_estimator_predictor.py b/tensorflow/contrib/predictor/core_estimator_predictor.py
index d78d94c269..a725072e72 100644
--- a/tensorflow/contrib/predictor/core_estimator_predictor.py
+++ b/tensorflow/contrib/predictor/core_estimator_predictor.py
@@ -51,7 +51,8 @@ class CoreEstimatorPredictor(predictor.Predictor):
estimator,
serving_input_receiver_fn,
output_key=None,
- graph=None):
+ graph=None,
+ config=None):
"""Initialize a `CoreEstimatorPredictor`.
Args:
@@ -62,6 +63,7 @@ class CoreEstimatorPredictor(predictor.Predictor):
`None`, then `DEFAULT_SERVING_SIGNATURE_DEF_KEY` is used.
graph: Optional. The Tensorflow `graph` in which prediction should be
done.
+ config: `ConfigProto` proto used to configure the session.
"""
self._graph = graph or ops.Graph()
with self._graph.as_default():
@@ -71,6 +73,7 @@ class CoreEstimatorPredictor(predictor.Predictor):
checkpoint_dir = estimator.model_dir
self._session = monitored_session.MonitoredSession(
session_creator=monitored_session.ChiefSessionCreator(
+ config=config,
checkpoint_dir=checkpoint_dir))
feed_tensor_info = signature_def.inputs
diff --git a/tensorflow/contrib/predictor/predictor_factories.py b/tensorflow/contrib/predictor/predictor_factories.py
index 6e77e934fe..f275bc15ad 100644
--- a/tensorflow/contrib/predictor/predictor_factories.py
+++ b/tensorflow/contrib/predictor/predictor_factories.py
@@ -30,7 +30,8 @@ def from_contrib_estimator(estimator,
prediction_input_fn,
input_alternative_key=None,
output_alternative_key=None,
- graph=None):
+ graph=None,
+ config=None):
"""Constructs a `Predictor` from a `tf.contrib.learn.Estimator`.
Args:
@@ -44,6 +45,7 @@ def from_contrib_estimator(estimator,
multi-headed models.
graph: Optional. The Tensorflow `graph` in which prediction should be
done.
+ config: `ConfigProto` proto used to configure the session.
Returns:
An initialized `Predictor`.
@@ -62,13 +64,15 @@ def from_contrib_estimator(estimator,
prediction_input_fn,
input_alternative_key=input_alternative_key,
output_alternative_key=output_alternative_key,
- graph=graph)
+ graph=graph,
+ config=config)
def from_estimator(estimator,
serving_input_receiver_fn,
output_key=None,
- graph=None):
+ graph=None,
+ config=None):
"""Constructs a `Predictor` from a `tf.python.estimator.Estimator`.
Args:
@@ -79,6 +83,7 @@ def from_estimator(estimator,
`None`, then `DEFAULT_SERVING_SIGNATURE_DEF_KEY` is used.
graph: Optional. The Tensorflow `graph` in which prediction should be
done.
+ config: `ConfigProto` proto used to configure the session.
Returns:
An initialized `Predictor`.
@@ -93,14 +98,19 @@ def from_estimator(estimator,
'tf.contrib.learn.Estimator. You likely want to call '
'from_contrib_estimator.')
return core_estimator_predictor.CoreEstimatorPredictor(
- estimator, serving_input_receiver_fn, output_key=output_key, graph=graph)
+ estimator,
+ serving_input_receiver_fn,
+ output_key=output_key,
+ graph=graph,
+ config=config)
def from_saved_model(export_dir,
signature_def_key=None,
signature_def=None,
tags=None,
- graph=None):
+ graph=None,
+ config=None):
"""Constructs a `Predictor` from a `SavedModel` on disk.
Args:
@@ -115,6 +125,7 @@ def from_saved_model(export_dir,
`SignatureDef`. Defaults to `DEFAULT_TAGS`.
graph: Optional. The Tensorflow `graph` in which prediction should be
done.
+ config: `ConfigProto` proto used to configure the session.
Returns:
An initialized `Predictor`.
@@ -128,4 +139,5 @@ def from_saved_model(export_dir,
signature_def_key=signature_def_key,
signature_def=signature_def,
tags=tags,
- graph=graph)
+ graph=graph,
+ config=config)
diff --git a/tensorflow/contrib/predictor/predictor_factories_test.py b/tensorflow/contrib/predictor/predictor_factories_test.py
index 578d9424b2..a2ef1dc3af 100644
--- a/tensorflow/contrib/predictor/predictor_factories_test.py
+++ b/tensorflow/contrib/predictor/predictor_factories_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
from tensorflow.contrib.predictor import predictor_factories
from tensorflow.contrib.predictor import testing_common
+from tensorflow.core.protobuf import config_pb2
from tensorflow.python.platform import test
MODEL_DIR_NAME = 'contrib/predictor/test_export_dir'
@@ -41,6 +42,11 @@ class PredictorFactoriesTest(test.TestCase):
"""Test loading from_saved_model with tags."""
predictor_factories.from_saved_model(self._export_dir, tags='serve')
+ def testFromSavedModelWithSessionConfig(self):
+ """Test loading from_saved_model with session config."""
+ predictor_factories.from_saved_model(
+ self._export_dir, config=config_pb2.ConfigProto())
+
def testFromSavedModelWithBadTags(self):
"""Test that loading fails for bad tags."""
bad_tags_regex = ('.*? could not be found in SavedModel')
@@ -53,6 +59,13 @@ class PredictorFactoriesTest(test.TestCase):
predictor_factories.from_contrib_estimator(
estimator, input_fn, output_alternative_key='sum')
+ def testFromContribEstimatorWithSessionConfig(self):
+ estimator = testing_common.get_arithmetic_estimator(core=False)
+ input_fn = testing_common.get_arithmetic_input_fn(core=False)
+ predictor_factories.from_contrib_estimator(
+ estimator, input_fn, output_alternative_key='sum',
+ config=config_pb2.ConfigProto())
+
def testFromContribEstimatorWithCoreEstimatorRaises(self):
estimator = testing_common.get_arithmetic_estimator(core=True)
input_fn = testing_common.get_arithmetic_input_fn(core=True)
@@ -64,6 +77,12 @@ class PredictorFactoriesTest(test.TestCase):
input_fn = testing_common.get_arithmetic_input_fn(core=True)
predictor_factories.from_estimator(estimator, input_fn)
+ def testFromCoreEstimatorWithSessionConfig(self):
+ estimator = testing_common.get_arithmetic_estimator(core=True)
+ input_fn = testing_common.get_arithmetic_input_fn(core=True)
+ predictor_factories.from_estimator(
+ estimator, input_fn, config=config_pb2.ConfigProto())
+
def testFromCoreEstimatorWithContribEstimatorRaises(self):
estimator = testing_common.get_arithmetic_estimator(core=False)
input_fn = testing_common.get_arithmetic_input_fn(core=False)
diff --git a/tensorflow/contrib/predictor/saved_model_predictor.py b/tensorflow/contrib/predictor/saved_model_predictor.py
index 0dbca0f813..95da6d04ed 100644
--- a/tensorflow/contrib/predictor/saved_model_predictor.py
+++ b/tensorflow/contrib/predictor/saved_model_predictor.py
@@ -121,7 +121,8 @@ class SavedModelPredictor(predictor.Predictor):
input_names=None,
output_names=None,
tags=None,
- graph=None):
+ graph=None,
+ config=None):
"""Initialize a `CoreEstimatorPredictor`.
Args:
@@ -142,6 +143,7 @@ class SavedModelPredictor(predictor.Predictor):
the correct `SignatureDef`. Defaults to `DEFAULT_TAGS`.
graph: Optional. The Tensorflow `graph` in which prediction should be
done.
+ config: `ConfigProto` proto used to configure the session.
Raises:
ValueError: If more than one of signature_def_key OR signature_def OR
(input_names AND output_names) is specified.
@@ -152,7 +154,7 @@ class SavedModelPredictor(predictor.Predictor):
self._graph = graph or ops.Graph()
with self._graph.as_default():
- self._session = session.Session()
+ self._session = session.Session(config=config)
loader.load(self._session, tags.split(','), export_dir)
if input_names is None: