diff options
Diffstat (limited to 'tensorflow/contrib/predictor/predictor_factories_test.py')
-rw-r--r-- | tensorflow/contrib/predictor/predictor_factories_test.py | 19 |
1 files changed, 19 insertions, 0 deletions
diff --git a/tensorflow/contrib/predictor/predictor_factories_test.py b/tensorflow/contrib/predictor/predictor_factories_test.py index 578d9424b2..a2ef1dc3af 100644 --- a/tensorflow/contrib/predictor/predictor_factories_test.py +++ b/tensorflow/contrib/predictor/predictor_factories_test.py @@ -20,6 +20,7 @@ from __future__ import print_function from tensorflow.contrib.predictor import predictor_factories from tensorflow.contrib.predictor import testing_common +from tensorflow.core.protobuf import config_pb2 from tensorflow.python.platform import test MODEL_DIR_NAME = 'contrib/predictor/test_export_dir' @@ -41,6 +42,11 @@ class PredictorFactoriesTest(test.TestCase): """Test loading from_saved_model with tags.""" predictor_factories.from_saved_model(self._export_dir, tags='serve') + def testFromSavedModelWithSessionConfig(self): + """Test loading from_saved_model with session config.""" + predictor_factories.from_saved_model( + self._export_dir, config=config_pb2.ConfigProto()) + def testFromSavedModelWithBadTags(self): """Test that loading fails for bad tags.""" bad_tags_regex = ('.*? could not be found in SavedModel') @@ -53,6 +59,13 @@ class PredictorFactoriesTest(test.TestCase): predictor_factories.from_contrib_estimator( estimator, input_fn, output_alternative_key='sum') + def testFromContribEstimatorWithSessionConfig(self): + estimator = testing_common.get_arithmetic_estimator(core=False) + input_fn = testing_common.get_arithmetic_input_fn(core=False) + predictor_factories.from_contrib_estimator( + estimator, input_fn, output_alternative_key='sum', + config=config_pb2.ConfigProto()) + def testFromContribEstimatorWithCoreEstimatorRaises(self): estimator = testing_common.get_arithmetic_estimator(core=True) input_fn = testing_common.get_arithmetic_input_fn(core=True) @@ -64,6 +77,12 @@ class PredictorFactoriesTest(test.TestCase): input_fn = testing_common.get_arithmetic_input_fn(core=True) predictor_factories.from_estimator(estimator, input_fn) + def testFromCoreEstimatorWithSessionConfig(self): + estimator = testing_common.get_arithmetic_estimator(core=True) + input_fn = testing_common.get_arithmetic_input_fn(core=True) + predictor_factories.from_estimator( + estimator, input_fn, config=config_pb2.ConfigProto()) + def testFromCoreEstimatorWithContribEstimatorRaises(self): estimator = testing_common.get_arithmetic_estimator(core=False) input_fn = testing_common.get_arithmetic_input_fn(core=False) |