diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-06-15 12:53:23 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-06-15 12:57:26 -0700 |
commit | fa75f26351f42e4fd3fc89b553d7919a6f147e41 (patch) | |
tree | d40a605c1e4ab22381c680897c0ead646a86ac3c /tensorflow/contrib/predictor | |
parent | fa927634cbda0b7826e516336fb8277707fb6fe3 (diff) |
Introduce Predictor, an interface for efficient, repeated inference.
PiperOrigin-RevId: 159141010
Diffstat (limited to 'tensorflow/contrib/predictor')
15 files changed, 1212 insertions, 0 deletions
diff --git a/tensorflow/contrib/predictor/BUILD b/tensorflow/contrib/predictor/BUILD new file mode 100644 index 0000000000..c4b46551c1 --- /dev/null +++ b/tensorflow/contrib/predictor/BUILD @@ -0,0 +1,163 @@ +# `Predictor` classes provide an interface for efficient, repeated inference. + +package(default_visibility = ["//third_party/tensroflow/contrib/predictor:__subpackages__"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("//tensorflow:tensorflow.bzl", "py_test") + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + ], + ), + visibility = ["//tensorflow:__subpackages__"], +) + +py_library( + name = "predictor", + srcs = [ + "__init__.py", + ], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + ":predictor_factories", + ], +) + +py_library( + name = "predictor_factories", + srcs = [ + "predictor_factories.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":contrib_estimator_predictor", + ":core_estimator_predictor", + ":saved_model_predictor", + "//tensorflow/contrib/learn", + ], +) + +py_library( + name = "base_predictor", + srcs = [ + "predictor.py", + ], + srcs_version = "PY2AND3", +) + +py_library( + name = "saved_model_predictor", + srcs = [ + "saved_model_predictor.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":base_predictor", + "//tensorflow/python/tools:saved_model_cli", + ], +) + +py_library( + name = "core_estimator_predictor", + srcs = [ + "core_estimator_predictor.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":base_predictor", + "//tensorflow/contrib/learn", + ], +) + +py_library( + name = "contrib_estimator_predictor", + srcs = [ + "contrib_estimator_predictor.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":base_predictor", + "//tensorflow/contrib/learn", + ], +) + +py_library( + name = "testing_common", + srcs = [ + "testing_common.py", + ], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ], +) + +# Transitive dependencies of this target will be included in the pip package. +py_library( + name = "predictor_pip", + visibility = ["//visibility:public"], + deps = [ + ":contrib_estimator_predictor", + ":core_estimator_predictor", + ":saved_model_predictor", + ], +) + +py_test( + name = "saved_model_predictor_test", + srcs = [ + "saved_model_predictor_test.py", + ], + data = [":test_export_dir"], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":saved_model_predictor", + ":testing_common", + "//tensorflow/python:client_testlib", + ], +) + +py_test( + name = "core_estimator_predictor_test", + srcs = [ + "core_estimator_predictor_test.py", + ], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":core_estimator_predictor", + ":testing_common", + "//tensorflow/python:client_testlib", + ], +) + +py_test( + name = "contrib_estimator_predictor_test", + srcs = [ + "contrib_estimator_predictor_test.py", + ], + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + ":contrib_estimator_predictor", + ":testing_common", + "//tensorflow/python:client_testlib", + "//third_party/py/numpy", + ], +) + +filegroup( + name = "test_export_dir", + srcs = glob(["test_export_dir/**/*"]), + tags = ["nopip"], +) diff --git a/tensorflow/contrib/predictor/README.md b/tensorflow/contrib/predictor/README.md new file mode 100644 index 0000000000..16cdcf3e70 --- /dev/null +++ b/tensorflow/contrib/predictor/README.md @@ -0,0 +1,96 @@ +# Predictors + +The `Predictor` classes provide a simple interface for performing repeated, +efficient inference. A `Predictor` can be constructed from a `SavedModel` on +disk, a `tf.Estimator` or a `tf.contrib.Estimator`. + +To facilitate the examples below, let's define a trivial `Estimator` that just +calculates a sum: + +```python +def model_fn(features, labels, mode): + z = tf.add(features['x'], features['y'], name='z') + return tf.contrib.learn.ModelFnOps( + mode, {'z': z}, loss=tf.constant(0.0), train_op=tf.no_op()) + +estimator = tf.contrib.learn.Estimator(model_fn=model_fn) +``` + +We can then construct a `Predictor` in two different ways. + +## `Predictor` from a `SavedModel` + +Given a trained `Estimator`, we first export a `SavedModel`: + +```python +def serving_input_fn(): + x = tf.placeholder(dtype=tf.float32, shape=[None], name='x') + y = tf.placeholder(dtype=tf.float32, shape=[None], name='y') + + features = {'x': x, 'y': y} + return tf.contrib.learn.utils.input_fn_utils.InputFnOps( + features, None, default_inputs=features) + +saved_model_dir = estimator.export_savedmodel(my_export_dir, serving_input_fn) +``` + +We can then construct a `Predictor` as follows: + +```python +saved_model_predictor = predictor.from_saved_model(export_dir='test_export_dir') +output_dict = saved_model_predictor({'x': [1.0], 'y': [5.2]}) +# output_dict == {'sum': [6.2]} +``` + +By specifying a signature definition, we can feed and fetch any `Tensor`s in +the `Graph`. In this example, we feed and fetch the same `Tensor`, `z`: + +```python +inputs = outputs = {'z': tf.TensorInfo( + name='z:0', + dtype=types_pb2.DT_FLOAT, + tensor_shape=tensor_shape_pb2.TensorShapeProto())} + +signature_def = tf.saved_model.signature_def_utils.build_signature_def( + inputs=inputs, + outputs=outputs, + method_name='tensorflow/serving/regress') + +trivial_predictor = predictor.from_saved_model( + export_dir=saved_model_dir, + signature_def=signature_def) + +output_dict = trivial_predictor({'z': [32.]}) +# output_dict == {'z': [32.]} +``` + +You can also specify input and output `Tensor`s by name using the `input_names` +and `output_names` keywords: + +```python +saved_model_predictor = predictor.from_saved_model( + export_dir=saved_model_dir, + input_names={'x': 'x:0', 'y': 'y:0'}, + outputs={'z': 'z:0'}) + +output_dict = saved_model_predictor({'x': [6.], 'y': [11.]}) +# output_dict == {'z': [17.]} +``` + +This functionality is particularly useful for performing encoding once, but +doing multiple decoding iterations with e.g. seq2seq models. + +## `Predictor` from an `Estimator` + +We can also construct a `Predictor` directly from an `Estimator`. Defining +`serving_input_fn` as above, + +```python +estimator_predictor = predictor.from_contrib_estimator( + estimator, serving_input_fn) +output_dict = sum_predictor({'x': [1., 2.], 'y': [3., 4.]}) +# output_dict == {'z': [4., 6.]} +``` + +Construction from a `tf.Estimator` is almost identical. + diff --git a/tensorflow/contrib/predictor/__init__.py b/tensorflow/contrib/predictor/__init__.py new file mode 100644 index 0000000000..d270c3f798 --- /dev/null +++ b/tensorflow/contrib/predictor/__init__.py @@ -0,0 +1,24 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Modules for `Predictor`s.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.predictor import from_contrib_estimator +from tensorflow.contrib.predictor import from_estimator +from tensorflow.contrib.predictor import from_saved_model diff --git a/tensorflow/contrib/predictor/contrib_estimator_predictor.py b/tensorflow/contrib/predictor/contrib_estimator_predictor.py new file mode 100644 index 0000000000..b7a98c68e2 --- /dev/null +++ b/tensorflow/contrib/predictor/contrib_estimator_predictor.py @@ -0,0 +1,74 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""A `Predictor constructed from a `tf.contrib.learn.Estimator`.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils +from tensorflow.contrib.predictor import predictor +from tensorflow.python.framework import ops +from tensorflow.python.training import monitored_session +from tensorflow.python.training import saver + + +class ContribEstimatorPredictor(predictor.Predictor): + """A `Predictor constructed from a `tf.contrib.learn.Estimator`.""" + + def __init__(self, + estimator, + prediction_input_fn, + input_alternative_key=None, + output_alternative_key=None, + graph=None): + """Initialize a `ContribEstimatorPredictor`. + + Args: + estimator: an instance of `tf.contrib.learn.Estimator`. + prediction_input_fn: a function that takes no arguments and returns an + instance of `InputFnOps`. + input_alternative_key: Optional. Specify the input alternative used for + prediction. + output_alternative_key: Specify the output alternative used for + prediction. Not needed for single-headed models but required for + multi-headed models. + graph: Optional. The Tensorflow `graph` in which prediction should be + done. + """ + self._graph = graph or ops.Graph() + with self._graph.as_default(): + input_fn_ops = prediction_input_fn() + # pylint: disable=protected-access + model_fn_ops = estimator._get_predict_ops(input_fn_ops.features) + # pylint: enable=protected-access + checkpoint_path = saver.latest_checkpoint(estimator.model_dir) + self._session = monitored_session.MonitoredSession( + session_creator=monitored_session.ChiefSessionCreator( + checkpoint_filename_with_path=checkpoint_path)) + + input_alternative_key = ( + input_alternative_key or + saved_model_export_utils.DEFAULT_INPUT_ALTERNATIVE_KEY) + input_alternatives, _ = saved_model_export_utils.get_input_alternatives( + input_fn_ops) + self._feed_tensors = input_alternatives[input_alternative_key] + + (output_alternatives, + output_alternative_key) = saved_model_export_utils.get_output_alternatives( + model_fn_ops, output_alternative_key) + _, fetch_tensors = output_alternatives[output_alternative_key] + self._fetch_tensors = fetch_tensors diff --git a/tensorflow/contrib/predictor/contrib_estimator_predictor_test.py b/tensorflow/contrib/predictor/contrib_estimator_predictor_test.py new file mode 100644 index 0000000000..4b97a52b1a --- /dev/null +++ b/tensorflow/contrib/predictor/contrib_estimator_predictor_test.py @@ -0,0 +1,70 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for predictor.contrib_estimator_predictor.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tempfile +import numpy as np + +from tensorflow.contrib.predictor import contrib_estimator_predictor +from tensorflow.contrib.predictor import testing_common +from tensorflow.python.platform import test + + +KEYS_AND_OPS = (('sum', lambda x, y: x + y), + ('product', lambda x, y: x * y,), + ('difference', lambda x, y: x - y)) + + +class ContribEstimatorPredictorTest(test.TestCase): + """Test fixture for `ContribEstimatorPredictor`.""" + + def setUp(self): + model_dir = tempfile.mkdtemp() + self._estimator = testing_common.get_arithmetic_estimator( + core=False, model_dir=model_dir) + self._prediction_input_fn = testing_common.get_arithmetic_input_fn( + core=False, train=False) + + def testSpecifiedSignatureKey(self): + """Test prediction with spedicified signatures.""" + np.random.seed(1234) + for key, op in KEYS_AND_OPS: + x = np.random.rand() + y = np.random.rand() + expected_output = op(x, y) + + predictor = contrib_estimator_predictor.ContribEstimatorPredictor( + estimator=self._estimator, + prediction_input_fn=self._prediction_input_fn, + output_alternative_key=key) + output_tensor_name = predictor.fetch_tensors[key].name + self.assertRegexpMatches( + output_tensor_name, + key, + msg='Unexpected fetch tensor.') + output = predictor({'x': x, 'y': y})[key] + self.assertAlmostEqual( + expected_output, output, places=3, + msg='Failed for output key "{}." ' + 'Got output {} for x = {} and y = {}'.format( + key, output, x, y)) + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/predictor/core_estimator_predictor.py b/tensorflow/contrib/predictor/core_estimator_predictor.py new file mode 100644 index 0000000000..5557ef5101 --- /dev/null +++ b/tensorflow/contrib/predictor/core_estimator_predictor.py @@ -0,0 +1,80 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""A `Predictor` constructed from an `learn.python.estimator.Estimator`.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.predictor import predictor +from tensorflow.python.estimator import model_fn +from tensorflow.python.framework import ops +from tensorflow.python.saved_model import signature_constants +from tensorflow.python.training import monitored_session + + +def _get_signature_def( + serving_input_receiver, estimator, output_key=None): + """Construct a `SignatureDef` proto.""" + if output_key is None: + output_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY + # pylint: disable=protected-access + estimator_spec = estimator._call_model_fn( + serving_input_receiver.features, None, model_fn.ModeKeys.PREDICT) + # pylint: enable=protected-access + export_outputs = estimator_spec.export_outputs + export_output = export_outputs.get(output_key) + if export_output is None: + raise KeyError('output_key must be one of {}; got {}'.format( + export_outputs.keys(), output_key)) + return export_output.as_signature_def(serving_input_receiver.receiver_tensors) + + +class CoreEstimatorPredictor(predictor.Predictor): + """A `Predictor` constructed from an `learn.python.estimator.Estimator`.""" + + def __init__(self, + estimator, + serving_input_receiver_fn, + output_key=None, + graph=None): + """Initialize a `CoreEstimatorPredictor`. + + Args: + estimator: an instance of `learn.python.estimator.Estimator`. + serving_input_receiver_fn: a function that takes no arguments and returns + an instance of `ServingInputReceiver` compatible with `estimator`. + output_key: Optional string specifying the export output to use. If + `None`, then `DEFAULT_SERVING_SIGNATURE_DEF_KEY` is used. + graph: Optional. The Tensorflow `graph` in which prediction should be + done. + """ + self._graph = graph or ops.Graph() + with self._graph.as_default(): + serving_input_receiver = serving_input_receiver_fn() + signature_def = _get_signature_def( + serving_input_receiver, estimator, output_key) + checkpoint_path = estimator.model_dir + self._session = monitored_session.MonitoredSession( + session_creator=monitored_session.ChiefSessionCreator( + checkpoint_filename_with_path=checkpoint_path)) + + feed_tensor_info = signature_def.inputs + self._feed_tensors = {k: self._graph.get_tensor_by_name(v.name) + for k, v in feed_tensor_info.items()} + fetch_tensor_info = signature_def.outputs + self._fetch_tensors = {k: self._graph.get_tensor_by_name(v.name) + for k, v in fetch_tensor_info.items()} diff --git a/tensorflow/contrib/predictor/core_estimator_predictor_test.py b/tensorflow/contrib/predictor/core_estimator_predictor_test.py new file mode 100644 index 0000000000..4221086794 --- /dev/null +++ b/tensorflow/contrib/predictor/core_estimator_predictor_test.py @@ -0,0 +1,81 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for predictor.core_estimator_predictor.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tempfile +import numpy as np + +from tensorflow.contrib.predictor import core_estimator_predictor +from tensorflow.contrib.predictor import testing_common +from tensorflow.python.platform import test + + +KEYS_AND_OPS = (('sum', lambda x, y: x + y), + ('product', lambda x, y: x * y,), + ('difference', lambda x, y: x - y)) + + +class CoreEstimatorPredictorTest(test.TestCase): + """Test fixture for `CoreEstimatorPredictor`.""" + + def setUp(self): + model_dir = tempfile.mkdtemp() + self._estimator = testing_common.get_arithmetic_estimator( + core=True, model_dir=model_dir) + self._serving_input_receiver_fn = testing_common.get_arithmetic_input_fn( + core=True, train=False) + + def testDefault(self): + """Test prediction with default signature.""" + np.random.seed(1111) + x = np.random.rand() + y = np.random.rand() + predictor = core_estimator_predictor.CoreEstimatorPredictor( + estimator=self._estimator, + serving_input_receiver_fn=self._serving_input_receiver_fn) + output = predictor({'x': x, 'y': y})['sum'] + self.assertAlmostEqual(output, x + y, places=3) + + def testSpecifiedSignatureKey(self): + """Test prediction with spedicified signatures.""" + np.random.seed(1234) + for output_key, op in KEYS_AND_OPS: + x = np.random.rand() + y = np.random.rand() + expected_output = op(x, y) + + predictor = core_estimator_predictor.CoreEstimatorPredictor( + estimator=self._estimator, + serving_input_receiver_fn=self._serving_input_receiver_fn, + output_key=output_key) + output_tensor_name = predictor.fetch_tensors[output_key].name + self.assertRegexpMatches( + output_tensor_name, + output_key, + msg='Unexpected fetch tensor.') + output = predictor({'x': x, 'y': y})[output_key] + self.assertAlmostEqual( + expected_output, output, places=3, + msg='Failed for output key "{}." ' + 'Got output {} for x = {} and y = {}'.format( + output_key, output, x, y)) + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/predictor/predictor.py b/tensorflow/contrib/predictor/predictor.py new file mode 100644 index 0000000000..dbc0028259 --- /dev/null +++ b/tensorflow/contrib/predictor/predictor.py @@ -0,0 +1,77 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Abstract base class for all predictors.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc +import six + + +@six.add_metaclass(abc.ABCMeta) +class Predictor(object): + """Abstract base class for all predictors.""" + + @property + def graph(self): + return self._graph + + @property + def session(self): + return self._session + + @property + def feed_tensors(self): + return self._feed_tensors + + @property + def fetch_tensors(self): + return self._fetch_tensors + + def __repr__(self): + return '{} with feed tensors {} and fetch_tensors {}'.format( + type(self).__name__, self._feed_tensors, self._fetch_tensors) + + def __call__(self, input_dict): + """Returns predictions based on `input_dict`. + + Args: + input_dict: a `dict` mapping strings to numpy arrays. These keys + must match `self._feed_tensors.keys()`. + + Returns: + A `dict` mapping strings to numpy arrays. The keys match + `self.fetch_tensors.keys()`. + + Raises: + ValueError: `input_dict` does not match `feed_tensors`. + """ + # TODO(jamieas): make validation optional? + input_keys = set(input_dict.keys()) + expected_keys = set(self.feed_tensors.keys()) + unexpected_keys = input_keys - expected_keys + if unexpected_keys: + raise ValueError('Got unexpected keys in input_dict: {}'.format( + unexpected_keys)) + + feed_dict = {} + for key in self.feed_tensors.keys(): + value = input_dict.get(key) + if value is not None: + feed_dict[self.feed_tensors[key]] = value + return self._session.run(fetches=self.fetch_tensors, feed_dict=feed_dict) diff --git a/tensorflow/contrib/predictor/predictor_factories.py b/tensorflow/contrib/predictor/predictor_factories.py new file mode 100644 index 0000000000..e3f30d917d --- /dev/null +++ b/tensorflow/contrib/predictor/predictor_factories.py @@ -0,0 +1,132 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Factory functions for `Predictor`s.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.predictor import contrib_estimator_predictor +from tensorflow.contrib.predictor import core_estimator_predictor +from tensorflow.contrib.predictor import saved_model_predictor +from tensorflow.python.estimator import estimator as core_estimator + + +def from_contrib_estimator(estimator, + prediction_input_fn, + input_alternative_key=None, + output_alternative_key=None, + graph=None): + """Constructs a `Predictor` from a `tf.contrib.learn.Estimator`. + + Args: + estimator: an instance of `tf.contrib.learn.Estimator`. + prediction_input_fn: a function that takes no arguments and returns an + instance of `InputFnOps`. + input_alternative_key: Optional. Specify the input alternative used for + prediction. + output_alternative_key: Specify the output alternative used for + prediction. Not needed for single-headed models but required for + multi-headed models. + graph: Optional. The Tensorflow `graph` in which prediction should be + done. + + Returns: + An initialized `Predictor`. + + Raises: + TypeError: if `estimator` is a core `Estimator` instead of a contrib + `Estimator`. + """ + if isinstance(estimator, core_estimator.Estimator): + raise TypeError('Espected estimator to be of type ' + 'tf.contrib.learn.Estimator, but got type ' + 'tf.python.estimator.Estimator. You likely want to call ' + 'from_estimator.') + return contrib_estimator_predictor.ContribEstimatorPredictor( + estimator, + prediction_input_fn, + input_alternative_key, + output_alternative_key, + graph) + + +def from_estimator(estimator, + serving_input_receiver_fn, + output_key=None, + graph=None): + """Constructs a `Predictor` from a `tf.python.estimator.Estimator`. + + Args: + estimator: an instance of `learn.python.estimator.Estimator`. + serving_input_receiver_fn: a function that takes no arguments and returns + an instance of `ServingInputReceiver` compatible with `estimator`. + output_key: Optional string specifying the export output to use. If + `None`, then `DEFAULT_SERVING_SIGNATURE_DEF_KEY` is used. + graph: Optional. The Tensorflow `graph` in which prediction should be + done. + + Returns: + An initialized `Predictor`. + + Raises: + TypeError: if `estimator` is a contrib `Estimator` instead of a core + `Estimator`. + """ + if isinstance(estimator, estimator.Estimator): + raise TypeError('Espected estimator to be of type ' + 'tf.python.estimator.Estimator, but got type ' + 'tf.contrib.learn.Estimator. You likely want to call ' + 'from_contrib_estimator.') + return core_estimator_predictor.CoreEstimatorPredictor( + estimator, + serving_input_receiver_fn, + output_key, + graph) + + +def from_saved_model(export_dir, + signature_def_key=None, + signature_def=None, + tags=None, + graph=None): + """Constructs a `Predictor` from a `SavedModel` on disk. + + Args: + export_dir: a path to a directory containing a `SavedModel`. + signature_def_key: Optional string specifying the signature to use. If + `None`, then `DEFAULT_SERVING_SIGNATURE_DEF_KEY` is used. Only one of + `signature_def_key` and `signature_def` + signature_def: A `SignatureDef` proto specifying the inputs and outputs + for prediction. Only one of `signature_def_key` and `signature_def` + should be specified. + tags: Optional. Tags that will be used to retrieve the correct + `SignatureDef`. Defaults to `DEFAULT_TAGS`. + graph: Optional. The Tensorflow `graph` in which prediction should be + done. + + Returns: + An initialized `Predictor`. + + Raises: + ValueError: More than one of `signature_def_key` and `signature_def` is + specified. + """ + return saved_model_predictor.SavedModelPredictor(export_dir, + signature_def_key, + signature_def, + tags, + graph) diff --git a/tensorflow/contrib/predictor/saved_model_predictor.py b/tensorflow/contrib/predictor/saved_model_predictor.py new file mode 100644 index 0000000000..ab2bafa0c8 --- /dev/null +++ b/tensorflow/contrib/predictor/saved_model_predictor.py @@ -0,0 +1,143 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""A `Predictor` constructed from a `SavedModel`.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import logging + +from tensorflow.contrib.predictor import predictor +from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils +from tensorflow.python.client import session +from tensorflow.python.framework import ops +from tensorflow.python.saved_model import loader +from tensorflow.python.saved_model import signature_constants +from tensorflow.python.tools import saved_model_cli + + +DEFAULT_TAGS = 'serve' + +_DEFAULT_INPUT_ALTERNATIVE_FORMAT = 'default_input_alternative:{}' + + +def _get_signature_def(signature_def_key, export_dir, tags): + """Construct a `SignatureDef` proto.""" + signature_def_key = ( + signature_def_key or + signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY) + + metagraph_def = saved_model_cli.get_meta_graph_def(export_dir, tags) + + try: + signature_def = signature_def_utils.get_signature_def_by_key( + metagraph_def, + signature_def_key) + except ValueError as e: + try: + formatted_key = _DEFAULT_INPUT_ALTERNATIVE_FORMAT.format( + signature_def_key) + signature_def = signature_def_utils.get_signature_def_by_key( + metagraph_def, formatted_key) + + logging.warning('Could not find signature def "%s". ' + 'Using "%s" instead', signature_def_key, formatted_key) + except ValueError: + raise ValueError( + 'Got signature_def_key "{}". Available signatures are {}. ' + 'Original error:\n{}'.format( + signature_def_key, list(metagraph_def.signature_def), e)) + return signature_def + + +def _check_signature_arguments(signature_def_key, + signature_def, + input_names, + output_names): + """Validates signature arguments for `SavedModelPredictor`.""" + signature_def_key_specified = signature_def_key is not None + signature_def_specified = signature_def is not None + input_names_specified = input_names is not None + output_names_specified = output_names is not None + if input_names_specified != output_names_specified: + raise ValueError( + 'input_names and output_names must both be specified or both be ' + 'unspecified.' + ) + + if (signature_def_key_specified + signature_def_specified + + input_names_specified > 1): + raise ValueError( + 'You must specify at most one of signature_def_key OR signature_def OR' + '(input_names AND output_names).' + ) + + +class SavedModelPredictor(predictor.Predictor): + """A `Predictor` constructed from a `SavedModel`.""" + + def __init__(self, + export_dir, + signature_def_key=None, + signature_def=None, + input_names=None, + output_names=None, + tags=None, + graph=None): + """Initialize a `CoreEstimatorPredictor`. + + Args: + export_dir: a path to a directory containing a `SavedModel`. + signature_def_key: Optional string specifying the signature to use. If + `None`, then `DEFAULT_SERVING_SIGNATURE_DEF_KEY` is used. Only one of + `signature_def_key` and `signature_def` should be specified. + signature_def: A `SignatureDef` proto specifying the inputs and outputs + for prediction. Only one of `signature_def_key` and `signature_def` + should be specified. + input_names: A dictionary mapping strings to `Tensor`s in the `SavedModel` + that represent the input. The keys can be any string of the user's + choosing. + output_names: A dictionary mapping strings to `Tensor`s in the + `SavedModel` that represent the output. The keys can be any string of + the user's choosing. + tags: Optional. Tags that will be used to retrieve the correct + `SignatureDef`. Defaults to `DEFAULT_TAGS`. + graph: Optional. The Tensorflow `graph` in which prediction should be + done. + Raises: + ValueError: If more than one of signature_def_key OR signature_def OR + (input_names AND output_names) is specified. + """ + _check_signature_arguments( + signature_def_key, signature_def, input_names, output_names) + tags = tags or DEFAULT_TAGS + self._graph = graph or ops.Graph() + + with self._graph.as_default(): + self._session = session.Session() + loader.load(self._session, tags.split(','), export_dir) + + if input_names is None: + if signature_def is None: + signature_def = _get_signature_def(signature_def_key, export_dir, tags) + input_names = {k: v.name for k, v in signature_def.inputs.items()} + output_names = {k: v.name for k, v in signature_def.outputs.items()} + + self._feed_tensors = {k: self._graph.get_tensor_by_name(v) + for k, v in input_names.items()} + self._fetch_tensors = {k: self._graph.get_tensor_by_name(v) + for k, v in output_names.items()} diff --git a/tensorflow/contrib/predictor/saved_model_predictor_test.py b/tensorflow/contrib/predictor/saved_model_predictor_test.py new file mode 100644 index 0000000000..f40e2e73d9 --- /dev/null +++ b/tensorflow/contrib/predictor/saved_model_predictor_test.py @@ -0,0 +1,170 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for predictor.saved_model_predictor.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.predictor import saved_model_predictor +from tensorflow.core.framework import tensor_shape_pb2 +from tensorflow.core.framework import types_pb2 +from tensorflow.core.protobuf import meta_graph_pb2 +from tensorflow.python.framework import ops +from tensorflow.python.platform import test +from tensorflow.python.saved_model import signature_def_utils + + +KEYS_AND_OPS = (('sum', lambda x, y: x + y), + ('product', lambda x, y: x * y,), + ('difference', lambda x, y: x - y)) + +MODEL_DIR_NAME = 'contrib/predictor/test_export_dir' + + +class SavedModelPredictorTest(test.TestCase): + + @classmethod + def setUpClass(cls): + # Load a saved model exported from the arithmetic `Estimator`. + # See `testing_common.py`. + cls._export_dir = test.test_src_dir_path(MODEL_DIR_NAME) + + def testDefault(self): + """Test prediction with default signature.""" + np.random.seed(1111) + x = np.random.rand() + y = np.random.rand() + predictor = saved_model_predictor.SavedModelPredictor( + export_dir=self._export_dir) + output = predictor({'x': x, 'y': y})['outputs'] + self.assertAlmostEqual(output, x + y, places=3) + + def testSpecifiedSignatureKey(self): + """Test prediction with spedicified signature key.""" + np.random.seed(1234) + for signature_def_key, op in KEYS_AND_OPS: + x = np.random.rand() + y = np.random.rand() + expected_output = op(x, y) + + predictor = saved_model_predictor.SavedModelPredictor( + export_dir=self._export_dir, + signature_def_key=signature_def_key) + + output_tensor_name = predictor.fetch_tensors['outputs'].name + self.assertRegexpMatches( + output_tensor_name, + signature_def_key, + msg='Unexpected fetch tensor.') + + output = predictor({'x': x, 'y': y})['outputs'] + self.assertAlmostEqual( + expected_output, output, places=3, + msg='Failed for signature "{}." ' + 'Got output {} for x = {} and y = {}'.format( + signature_def_key, output, x, y)) + + def testSpecifiedSignature(self): + """Test prediction with spedicified signature definition.""" + np.random.seed(4444) + for key, op in KEYS_AND_OPS: + x = np.random.rand() + y = np.random.rand() + expected_output = op(x, y) + + inputs = { + 'x': meta_graph_pb2.TensorInfo( + name='inputs/x:0', + dtype=types_pb2.DT_FLOAT, + tensor_shape=tensor_shape_pb2.TensorShapeProto()), + 'y': meta_graph_pb2.TensorInfo( + name='inputs/y:0', + dtype=types_pb2.DT_FLOAT, + tensor_shape=tensor_shape_pb2.TensorShapeProto())} + outputs = { + key: meta_graph_pb2.TensorInfo( + name='outputs/{}:0'.format(key), + dtype=types_pb2.DT_FLOAT, + tensor_shape=tensor_shape_pb2.TensorShapeProto())} + signature_def = signature_def_utils.build_signature_def( + inputs=inputs, + outputs=outputs, + method_name='tensorflow/serving/regress') + predictor = saved_model_predictor.SavedModelPredictor( + export_dir=self._export_dir, + signature_def=signature_def) + + output_tensor_name = predictor.fetch_tensors[key].name + self.assertRegexpMatches( + output_tensor_name, + key, + msg='Unexpected fetch tensor.') + + output = predictor({'x': x, 'y': y})[key] + self.assertAlmostEqual( + expected_output, output, places=3, + msg='Failed for signature "{}". ' + 'Got output {} for x = {} and y = {}'.format(key, output, x, y)) + + def testSpecifiedTensors(self): + """Test prediction with spedicified `Tensor`s.""" + np.random.seed(987) + for key, op in KEYS_AND_OPS: + x = np.random.rand() + y = np.random.rand() + expected_output = op(x, y) + input_names = {'x': 'inputs/x:0', + 'y': 'inputs/y:0'} + output_names = {key: 'outputs/{}:0'.format(key)} + predictor = saved_model_predictor.SavedModelPredictor( + export_dir=self._export_dir, + input_names=input_names, + output_names=output_names) + + output_tensor_name = predictor.fetch_tensors[key].name + self.assertRegexpMatches( + output_tensor_name, + key, + msg='Unexpected fetch tensor.') + + output = predictor({'x': x, 'y': y})[key] + self.assertAlmostEqual( + expected_output, output, places=3, + msg='Failed for signature "{}". ' + 'Got output {} for x = {} and y = {}'.format(key, output, x, y)) + + def testBadTagsFail(self): + """Test that predictor construction fails for bad tags.""" + bad_tags_regex = ('.* could not be found in SavedModel') + with self.assertRaisesRegexp(RuntimeError, bad_tags_regex): + _ = saved_model_predictor.SavedModelPredictor( + export_dir=self._export_dir, + tags=('zomg, bad, tags')) + + def testSpecifiedGraph(self): + """Test that the predictor remembers a specified `Graph`.""" + g = ops.Graph() + predictor = saved_model_predictor.SavedModelPredictor( + export_dir=self._export_dir, + graph=g) + self.assertEqual(predictor.graph, g) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/predictor/test_export_dir/saved_model.pb b/tensorflow/contrib/predictor/test_export_dir/saved_model.pb Binary files differnew file mode 100644 index 0000000000..9100fefb72 --- /dev/null +++ b/tensorflow/contrib/predictor/test_export_dir/saved_model.pb diff --git a/tensorflow/contrib/predictor/test_export_dir/variables/variables.data-00000-of-00001 b/tensorflow/contrib/predictor/test_export_dir/variables/variables.data-00000-of-00001 Binary files differnew file mode 100644 index 0000000000..1b1cb4d44c --- /dev/null +++ b/tensorflow/contrib/predictor/test_export_dir/variables/variables.data-00000-of-00001 diff --git a/tensorflow/contrib/predictor/test_export_dir/variables/variables.index b/tensorflow/contrib/predictor/test_export_dir/variables/variables.index Binary files differnew file mode 100644 index 0000000000..dd32e9b71b --- /dev/null +++ b/tensorflow/contrib/predictor/test_export_dir/variables/variables.index diff --git a/tensorflow/contrib/predictor/testing_common.py b/tensorflow/contrib/predictor/testing_common.py new file mode 100644 index 0000000000..1767704b99 --- /dev/null +++ b/tensorflow/contrib/predictor/testing_common.py @@ -0,0 +1,102 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Common code used for testing `Predictor`s.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.learn.python.learn.estimators import constants +from tensorflow.contrib.learn.python.learn.estimators import estimator as contrib_estimator +from tensorflow.contrib.learn.python.learn.estimators import model_fn as contrib_model_fn +from tensorflow.contrib.learn.python.learn.utils import input_fn_utils +from tensorflow.python.estimator import estimator as core_estimator +from tensorflow.python.estimator import model_fn +from tensorflow.python.estimator.export import export_lib +from tensorflow.python.estimator.export import export_output +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.saved_model import signature_constants + + +def get_arithmetic_estimator(core=True, model_dir=None): + """Returns an `Estimator` that performs basic arithmetic. + + Args: + core: if `True`, returns a `tensorflow.python.estimator.Estimator`. + Otherwise, returns a `tensorflow.contrib.learn.Estimator`. + model_dir: directory in which to export checkpoints and saved models. + Returns: + An `Estimator` that performs arithmetic operations on its inputs. + """ + def _model_fn(features, labels, mode): + _ = labels + x = features['x'] + y = features['y'] + with ops.name_scope('outputs'): + predictions = {'sum': math_ops.add(x, y, name='sum'), + 'product': math_ops.multiply(x, y, name='product'), + 'difference': math_ops.subtract(x, y, name='difference')} + if core: + export_outputs = {k: export_output.PredictOutput({k: v}) + for k, v in predictions.items()} + export_outputs[signature_constants. + DEFAULT_SERVING_SIGNATURE_DEF_KEY] = export_outputs['sum'] + return model_fn.EstimatorSpec(mode=mode, + predictions=predictions, + export_outputs=export_outputs, + loss=constant_op.constant(0), + train_op=control_flow_ops.no_op()) + else: + output_alternatives = {k: (constants.ProblemType.UNSPECIFIED, {k: v}) + for k, v in predictions.items()} + return contrib_model_fn.ModelFnOps( + mode=mode, + predictions=predictions, + output_alternatives=output_alternatives, + loss=constant_op.constant(0), + train_op=control_flow_ops.no_op()) + if core: + return core_estimator.Estimator(_model_fn) + else: + return contrib_estimator.Estimator(_model_fn, model_dir=model_dir) + + +def get_arithmetic_input_fn(core=True, train=False): + """Returns a input functions or serving input receiver function.""" + def _input_fn(): + with ops.name_scope('inputs'): + x = array_ops.placeholder_with_default(0.0, shape=[], name='x') + y = array_ops.placeholder_with_default(0.0, shape=[], name='y') + label = constant_op.constant(0.0) + features = {'x': x, 'y': y} + if core: + if train: + return features, label + return export_lib.ServingInputReceiver( + features=features, + receiver_tensors=features) + else: + if train: + return features, label + return input_fn_utils.InputFnOps( + features=features, + labels={}, + default_inputs=features) + return _input_fn |