aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/predictor
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-12-13 09:54:52 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-13 09:57:54 -0800
commite31f38913d4018c2cee094e05a04833ac96f8b68 (patch)
tree41a75221c1709de8aebdd6e2299d533f39aa23fd /tensorflow/contrib/predictor
parent185c593cb71cb6d8116ba05c97e9385642648f1b (diff)
Fix 'tags' parameter in predictor_factories.load_from_model.
tags was incorrectly being mapped to inputs. Added basic unit tests. PiperOrigin-RevId: 178916192
Diffstat (limited to 'tensorflow/contrib/predictor')
-rw-r--r--tensorflow/contrib/predictor/BUILD11
-rw-r--r--tensorflow/contrib/predictor/predictor_factories.py23
-rw-r--r--tensorflow/contrib/predictor/predictor_factories_test.py51
3 files changed, 72 insertions, 13 deletions
diff --git a/tensorflow/contrib/predictor/BUILD b/tensorflow/contrib/predictor/BUILD
index 82cd7b4c8a..d7c3d6c3be 100644
--- a/tensorflow/contrib/predictor/BUILD
+++ b/tensorflow/contrib/predictor/BUILD
@@ -137,6 +137,17 @@ py_test(
)
py_test(
+ name = "predictor_factories_test",
+ srcs = ["predictor_factories_test.py"],
+ data = [":test_export_dir"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":predictor_factories",
+ ],
+)
+
+py_test(
name = "core_estimator_predictor_test",
srcs = ["core_estimator_predictor_test.py"],
srcs_version = "PY2AND3",
diff --git a/tensorflow/contrib/predictor/predictor_factories.py b/tensorflow/contrib/predictor/predictor_factories.py
index e3f30d917d..9485187c5d 100644
--- a/tensorflow/contrib/predictor/predictor_factories.py
+++ b/tensorflow/contrib/predictor/predictor_factories.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""Factory functions for `Predictor`s."""
from __future__ import absolute_import
@@ -59,9 +58,9 @@ def from_contrib_estimator(estimator,
return contrib_estimator_predictor.ContribEstimatorPredictor(
estimator,
prediction_input_fn,
- input_alternative_key,
- output_alternative_key,
- graph)
+ input_alternative_key=input_alternative_key,
+ output_alternative_key=output_alternative_key,
+ graph=graph)
def from_estimator(estimator,
@@ -92,10 +91,7 @@ def from_estimator(estimator,
'tf.contrib.learn.Estimator. You likely want to call '
'from_contrib_estimator.')
return core_estimator_predictor.CoreEstimatorPredictor(
- estimator,
- serving_input_receiver_fn,
- output_key,
- graph)
+ estimator, serving_input_receiver_fn, output_key=output_key, graph=graph)
def from_saved_model(export_dir,
@@ -125,8 +121,9 @@ def from_saved_model(export_dir,
ValueError: More than one of `signature_def_key` and `signature_def` is
specified.
"""
- return saved_model_predictor.SavedModelPredictor(export_dir,
- signature_def_key,
- signature_def,
- tags,
- graph)
+ return saved_model_predictor.SavedModelPredictor(
+ export_dir,
+ signature_def_key=signature_def_key,
+ signature_def=signature_def,
+ tags=tags,
+ graph=graph)
diff --git a/tensorflow/contrib/predictor/predictor_factories_test.py b/tensorflow/contrib/predictor/predictor_factories_test.py
new file mode 100644
index 0000000000..60ffeec653
--- /dev/null
+++ b/tensorflow/contrib/predictor/predictor_factories_test.py
@@ -0,0 +1,51 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for predictor.predictor_factories."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.predictor import predictor_factories
+from tensorflow.python.platform import test
+
+MODEL_DIR_NAME = 'contrib/predictor/test_export_dir'
+
+
+class PredictorFactoriesTest(test.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ # Load a saved model exported from the arithmetic `Estimator`.
+ # See `testing_common.py`.
+ cls._export_dir = test.test_src_dir_path(MODEL_DIR_NAME)
+
+ def testFromSavedModel(self):
+ """Test loading from_saved_model."""
+ predictor_factories.from_saved_model(self._export_dir)
+
+ def testFromSavedModelWithTags(self):
+ """Test loading from_saved_model with tags."""
+ predictor_factories.from_saved_model(self._export_dir, tags='serve')
+
+ def testFromSavedModelWithBadTags(self):
+ """Test that loading fails for bad tags."""
+ bad_tags_regex = ('.*? could not be found in SavedModel')
+ with self.assertRaisesRegexp(RuntimeError, bad_tags_regex):
+ predictor_factories.from_saved_model(self._export_dir, tags='bad_tag')
+
+
+if __name__ == '__main__':
+ test.main()