aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/predictor/predictor_factories_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/predictor/predictor_factories_test.py')
-rw-r--r--tensorflow/contrib/predictor/predictor_factories_test.py19
1 files changed, 19 insertions, 0 deletions
diff --git a/tensorflow/contrib/predictor/predictor_factories_test.py b/tensorflow/contrib/predictor/predictor_factories_test.py
index 578d9424b2..a2ef1dc3af 100644
--- a/tensorflow/contrib/predictor/predictor_factories_test.py
+++ b/tensorflow/contrib/predictor/predictor_factories_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
from tensorflow.contrib.predictor import predictor_factories
from tensorflow.contrib.predictor import testing_common
+from tensorflow.core.protobuf import config_pb2
from tensorflow.python.platform import test
MODEL_DIR_NAME = 'contrib/predictor/test_export_dir'
@@ -41,6 +42,11 @@ class PredictorFactoriesTest(test.TestCase):
"""Test loading from_saved_model with tags."""
predictor_factories.from_saved_model(self._export_dir, tags='serve')
+ def testFromSavedModelWithSessionConfig(self):
+ """Test loading from_saved_model with session config."""
+ predictor_factories.from_saved_model(
+ self._export_dir, config=config_pb2.ConfigProto())
+
def testFromSavedModelWithBadTags(self):
"""Test that loading fails for bad tags."""
bad_tags_regex = ('.*? could not be found in SavedModel')
@@ -53,6 +59,13 @@ class PredictorFactoriesTest(test.TestCase):
predictor_factories.from_contrib_estimator(
estimator, input_fn, output_alternative_key='sum')
+ def testFromContribEstimatorWithSessionConfig(self):
+ estimator = testing_common.get_arithmetic_estimator(core=False)
+ input_fn = testing_common.get_arithmetic_input_fn(core=False)
+ predictor_factories.from_contrib_estimator(
+ estimator, input_fn, output_alternative_key='sum',
+ config=config_pb2.ConfigProto())
+
def testFromContribEstimatorWithCoreEstimatorRaises(self):
estimator = testing_common.get_arithmetic_estimator(core=True)
input_fn = testing_common.get_arithmetic_input_fn(core=True)
@@ -64,6 +77,12 @@ class PredictorFactoriesTest(test.TestCase):
input_fn = testing_common.get_arithmetic_input_fn(core=True)
predictor_factories.from_estimator(estimator, input_fn)
+ def testFromCoreEstimatorWithSessionConfig(self):
+ estimator = testing_common.get_arithmetic_estimator(core=True)
+ input_fn = testing_common.get_arithmetic_input_fn(core=True)
+ predictor_factories.from_estimator(
+ estimator, input_fn, config=config_pb2.ConfigProto())
+
def testFromCoreEstimatorWithContribEstimatorRaises(self):
estimator = testing_common.get_arithmetic_estimator(core=False)
input_fn = testing_common.get_arithmetic_input_fn(core=False)