diff options
author | 2017-12-13 09:54:52 -0800 | |
---|---|---|
committer | 2017-12-13 09:57:54 -0800 | |
commit | e31f38913d4018c2cee094e05a04833ac96f8b68 (patch) | |
tree | 41a75221c1709de8aebdd6e2299d533f39aa23fd /tensorflow/contrib/predictor | |
parent | 185c593cb71cb6d8116ba05c97e9385642648f1b (diff) |
Fix 'tags' parameter in predictor_factories.load_from_model.
tags was incorrectly being mapped to inputs.
Added basic unit tests.
PiperOrigin-RevId: 178916192
Diffstat (limited to 'tensorflow/contrib/predictor')
-rw-r--r-- | tensorflow/contrib/predictor/BUILD | 11 | ||||
-rw-r--r-- | tensorflow/contrib/predictor/predictor_factories.py | 23 | ||||
-rw-r--r-- | tensorflow/contrib/predictor/predictor_factories_test.py | 51 |
3 files changed, 72 insertions, 13 deletions
diff --git a/tensorflow/contrib/predictor/BUILD b/tensorflow/contrib/predictor/BUILD index 82cd7b4c8a..d7c3d6c3be 100644 --- a/tensorflow/contrib/predictor/BUILD +++ b/tensorflow/contrib/predictor/BUILD @@ -137,6 +137,17 @@ py_test( ) py_test( + name = "predictor_factories_test", + srcs = ["predictor_factories_test.py"], + data = [":test_export_dir"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":predictor_factories", + ], +) + +py_test( name = "core_estimator_predictor_test", srcs = ["core_estimator_predictor_test.py"], srcs_version = "PY2AND3", diff --git a/tensorflow/contrib/predictor/predictor_factories.py b/tensorflow/contrib/predictor/predictor_factories.py index e3f30d917d..9485187c5d 100644 --- a/tensorflow/contrib/predictor/predictor_factories.py +++ b/tensorflow/contrib/predictor/predictor_factories.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== - """Factory functions for `Predictor`s.""" from __future__ import absolute_import @@ -59,9 +58,9 @@ def from_contrib_estimator(estimator, return contrib_estimator_predictor.ContribEstimatorPredictor( estimator, prediction_input_fn, - input_alternative_key, - output_alternative_key, - graph) + input_alternative_key=input_alternative_key, + output_alternative_key=output_alternative_key, + graph=graph) def from_estimator(estimator, @@ -92,10 +91,7 @@ def from_estimator(estimator, 'tf.contrib.learn.Estimator. You likely want to call ' 'from_contrib_estimator.') return core_estimator_predictor.CoreEstimatorPredictor( - estimator, - serving_input_receiver_fn, - output_key, - graph) + estimator, serving_input_receiver_fn, output_key=output_key, graph=graph) def from_saved_model(export_dir, @@ -125,8 +121,9 @@ def from_saved_model(export_dir, ValueError: More than one of `signature_def_key` and `signature_def` is specified. """ - return saved_model_predictor.SavedModelPredictor(export_dir, - signature_def_key, - signature_def, - tags, - graph) + return saved_model_predictor.SavedModelPredictor( + export_dir, + signature_def_key=signature_def_key, + signature_def=signature_def, + tags=tags, + graph=graph) diff --git a/tensorflow/contrib/predictor/predictor_factories_test.py b/tensorflow/contrib/predictor/predictor_factories_test.py new file mode 100644 index 0000000000..60ffeec653 --- /dev/null +++ b/tensorflow/contrib/predictor/predictor_factories_test.py @@ -0,0 +1,51 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for predictor.predictor_factories.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.predictor import predictor_factories +from tensorflow.python.platform import test + +MODEL_DIR_NAME = 'contrib/predictor/test_export_dir' + + +class PredictorFactoriesTest(test.TestCase): + + @classmethod + def setUpClass(cls): + # Load a saved model exported from the arithmetic `Estimator`. + # See `testing_common.py`. + cls._export_dir = test.test_src_dir_path(MODEL_DIR_NAME) + + def testFromSavedModel(self): + """Test loading from_saved_model.""" + predictor_factories.from_saved_model(self._export_dir) + + def testFromSavedModelWithTags(self): + """Test loading from_saved_model with tags.""" + predictor_factories.from_saved_model(self._export_dir, tags='serve') + + def testFromSavedModelWithBadTags(self): + """Test that loading fails for bad tags.""" + bad_tags_regex = ('.*? could not be found in SavedModel') + with self.assertRaisesRegexp(RuntimeError, bad_tags_regex): + predictor_factories.from_saved_model(self._export_dir, tags='bad_tag') + + +if __name__ == '__main__': + test.main() |