aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/predictor
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-06-15 12:53:23 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-15 12:57:26 -0700
commitfa75f26351f42e4fd3fc89b553d7919a6f147e41 (patch)
treed40a605c1e4ab22381c680897c0ead646a86ac3c /tensorflow/contrib/predictor
parentfa927634cbda0b7826e516336fb8277707fb6fe3 (diff)
Introduce Predictor, an interface for efficient, repeated inference.
PiperOrigin-RevId: 159141010
Diffstat (limited to 'tensorflow/contrib/predictor')
-rw-r--r--tensorflow/contrib/predictor/BUILD163
-rw-r--r--tensorflow/contrib/predictor/README.md96
-rw-r--r--tensorflow/contrib/predictor/__init__.py24
-rw-r--r--tensorflow/contrib/predictor/contrib_estimator_predictor.py74
-rw-r--r--tensorflow/contrib/predictor/contrib_estimator_predictor_test.py70
-rw-r--r--tensorflow/contrib/predictor/core_estimator_predictor.py80
-rw-r--r--tensorflow/contrib/predictor/core_estimator_predictor_test.py81
-rw-r--r--tensorflow/contrib/predictor/predictor.py77
-rw-r--r--tensorflow/contrib/predictor/predictor_factories.py132
-rw-r--r--tensorflow/contrib/predictor/saved_model_predictor.py143
-rw-r--r--tensorflow/contrib/predictor/saved_model_predictor_test.py170
-rw-r--r--tensorflow/contrib/predictor/test_export_dir/saved_model.pbbin0 -> 7736 bytes
-rw-r--r--tensorflow/contrib/predictor/test_export_dir/variables/variables.data-00000-of-00001bin0 -> 8 bytes
-rw-r--r--tensorflow/contrib/predictor/test_export_dir/variables/variables.indexbin0 -> 127 bytes
-rw-r--r--tensorflow/contrib/predictor/testing_common.py102
15 files changed, 1212 insertions, 0 deletions
diff --git a/tensorflow/contrib/predictor/BUILD b/tensorflow/contrib/predictor/BUILD
new file mode 100644
index 0000000000..c4b46551c1
--- /dev/null
+++ b/tensorflow/contrib/predictor/BUILD
@@ -0,0 +1,163 @@
+# `Predictor` classes provide an interface for efficient, repeated inference.
+
+package(default_visibility = ["//third_party/tensroflow/contrib/predictor:__subpackages__"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
+
+py_library(
+ name = "predictor",
+ srcs = [
+ "__init__.py",
+ ],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":predictor_factories",
+ ],
+)
+
+py_library(
+ name = "predictor_factories",
+ srcs = [
+ "predictor_factories.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":contrib_estimator_predictor",
+ ":core_estimator_predictor",
+ ":saved_model_predictor",
+ "//tensorflow/contrib/learn",
+ ],
+)
+
+py_library(
+ name = "base_predictor",
+ srcs = [
+ "predictor.py",
+ ],
+ srcs_version = "PY2AND3",
+)
+
+py_library(
+ name = "saved_model_predictor",
+ srcs = [
+ "saved_model_predictor.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":base_predictor",
+ "//tensorflow/python/tools:saved_model_cli",
+ ],
+)
+
+py_library(
+ name = "core_estimator_predictor",
+ srcs = [
+ "core_estimator_predictor.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":base_predictor",
+ "//tensorflow/contrib/learn",
+ ],
+)
+
+py_library(
+ name = "contrib_estimator_predictor",
+ srcs = [
+ "contrib_estimator_predictor.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":base_predictor",
+ "//tensorflow/contrib/learn",
+ ],
+)
+
+py_library(
+ name = "testing_common",
+ srcs = [
+ "testing_common.py",
+ ],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ],
+)
+
+# Transitive dependencies of this target will be included in the pip package.
+py_library(
+ name = "predictor_pip",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":contrib_estimator_predictor",
+ ":core_estimator_predictor",
+ ":saved_model_predictor",
+ ],
+)
+
+py_test(
+ name = "saved_model_predictor_test",
+ srcs = [
+ "saved_model_predictor_test.py",
+ ],
+ data = [":test_export_dir"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":saved_model_predictor",
+ ":testing_common",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
+py_test(
+ name = "core_estimator_predictor_test",
+ srcs = [
+ "core_estimator_predictor_test.py",
+ ],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":core_estimator_predictor",
+ ":testing_common",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
+py_test(
+ name = "contrib_estimator_predictor_test",
+ srcs = [
+ "contrib_estimator_predictor_test.py",
+ ],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":contrib_estimator_predictor",
+ ":testing_common",
+ "//tensorflow/python:client_testlib",
+ "//third_party/py/numpy",
+ ],
+)
+
+filegroup(
+ name = "test_export_dir",
+ srcs = glob(["test_export_dir/**/*"]),
+ tags = ["nopip"],
+)
diff --git a/tensorflow/contrib/predictor/README.md b/tensorflow/contrib/predictor/README.md
new file mode 100644
index 0000000000..16cdcf3e70
--- /dev/null
+++ b/tensorflow/contrib/predictor/README.md
@@ -0,0 +1,96 @@
+# Predictors
+
+The `Predictor` classes provide a simple interface for performing repeated,
+efficient inference. A `Predictor` can be constructed from a `SavedModel` on
+disk, a `tf.Estimator` or a `tf.contrib.Estimator`.
+
+To facilitate the examples below, let's define a trivial `Estimator` that just
+calculates a sum:
+
+```python
+def model_fn(features, labels, mode):
+ z = tf.add(features['x'], features['y'], name='z')
+ return tf.contrib.learn.ModelFnOps(
+ mode, {'z': z}, loss=tf.constant(0.0), train_op=tf.no_op())
+
+estimator = tf.contrib.learn.Estimator(model_fn=model_fn)
+```
+
+We can then construct a `Predictor` in two different ways.
+
+## `Predictor` from a `SavedModel`
+
+Given a trained `Estimator`, we first export a `SavedModel`:
+
+```python
+def serving_input_fn():
+ x = tf.placeholder(dtype=tf.float32, shape=[None], name='x')
+ y = tf.placeholder(dtype=tf.float32, shape=[None], name='y')
+
+ features = {'x': x, 'y': y}
+ return tf.contrib.learn.utils.input_fn_utils.InputFnOps(
+ features, None, default_inputs=features)
+
+saved_model_dir = estimator.export_savedmodel(my_export_dir, serving_input_fn)
+```
+
+We can then construct a `Predictor` as follows:
+
+```python
+saved_model_predictor = predictor.from_saved_model(export_dir='test_export_dir')
+output_dict = saved_model_predictor({'x': [1.0], 'y': [5.2]})
+# output_dict == {'sum': [6.2]}
+```
+
+By specifying a signature definition, we can feed and fetch any `Tensor`s in
+the `Graph`. In this example, we feed and fetch the same `Tensor`, `z`:
+
+```python
+inputs = outputs = {'z': tf.TensorInfo(
+ name='z:0',
+ dtype=types_pb2.DT_FLOAT,
+ tensor_shape=tensor_shape_pb2.TensorShapeProto())}
+
+signature_def = tf.saved_model.signature_def_utils.build_signature_def(
+ inputs=inputs,
+ outputs=outputs,
+ method_name='tensorflow/serving/regress')
+
+trivial_predictor = predictor.from_saved_model(
+ export_dir=saved_model_dir,
+ signature_def=signature_def)
+
+output_dict = trivial_predictor({'z': [32.]})
+# output_dict == {'z': [32.]}
+```
+
+You can also specify input and output `Tensor`s by name using the `input_names`
+and `output_names` keywords:
+
+```python
+saved_model_predictor = predictor.from_saved_model(
+ export_dir=saved_model_dir,
+ input_names={'x': 'x:0', 'y': 'y:0'},
+ outputs={'z': 'z:0'})
+
+output_dict = saved_model_predictor({'x': [6.], 'y': [11.]})
+# output_dict == {'z': [17.]}
+```
+
+This functionality is particularly useful for performing encoding once, but
+doing multiple decoding iterations with e.g. seq2seq models.
+
+## `Predictor` from an `Estimator`
+
+We can also construct a `Predictor` directly from an `Estimator`. Defining
+`serving_input_fn` as above,
+
+```python
+estimator_predictor = predictor.from_contrib_estimator(
+ estimator, serving_input_fn)
+output_dict = sum_predictor({'x': [1., 2.], 'y': [3., 4.]})
+# output_dict == {'z': [4., 6.]}
+```
+
+Construction from a `tf.Estimator` is almost identical.
+
diff --git a/tensorflow/contrib/predictor/__init__.py b/tensorflow/contrib/predictor/__init__.py
new file mode 100644
index 0000000000..d270c3f798
--- /dev/null
+++ b/tensorflow/contrib/predictor/__init__.py
@@ -0,0 +1,24 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Modules for `Predictor`s."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.predictor import from_contrib_estimator
+from tensorflow.contrib.predictor import from_estimator
+from tensorflow.contrib.predictor import from_saved_model
diff --git a/tensorflow/contrib/predictor/contrib_estimator_predictor.py b/tensorflow/contrib/predictor/contrib_estimator_predictor.py
new file mode 100644
index 0000000000..b7a98c68e2
--- /dev/null
+++ b/tensorflow/contrib/predictor/contrib_estimator_predictor.py
@@ -0,0 +1,74 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""A `Predictor constructed from a `tf.contrib.learn.Estimator`."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils
+from tensorflow.contrib.predictor import predictor
+from tensorflow.python.framework import ops
+from tensorflow.python.training import monitored_session
+from tensorflow.python.training import saver
+
+
+class ContribEstimatorPredictor(predictor.Predictor):
+ """A `Predictor constructed from a `tf.contrib.learn.Estimator`."""
+
+ def __init__(self,
+ estimator,
+ prediction_input_fn,
+ input_alternative_key=None,
+ output_alternative_key=None,
+ graph=None):
+ """Initialize a `ContribEstimatorPredictor`.
+
+ Args:
+ estimator: an instance of `tf.contrib.learn.Estimator`.
+ prediction_input_fn: a function that takes no arguments and returns an
+ instance of `InputFnOps`.
+ input_alternative_key: Optional. Specify the input alternative used for
+ prediction.
+ output_alternative_key: Specify the output alternative used for
+ prediction. Not needed for single-headed models but required for
+ multi-headed models.
+ graph: Optional. The Tensorflow `graph` in which prediction should be
+ done.
+ """
+ self._graph = graph or ops.Graph()
+ with self._graph.as_default():
+ input_fn_ops = prediction_input_fn()
+ # pylint: disable=protected-access
+ model_fn_ops = estimator._get_predict_ops(input_fn_ops.features)
+ # pylint: enable=protected-access
+ checkpoint_path = saver.latest_checkpoint(estimator.model_dir)
+ self._session = monitored_session.MonitoredSession(
+ session_creator=monitored_session.ChiefSessionCreator(
+ checkpoint_filename_with_path=checkpoint_path))
+
+ input_alternative_key = (
+ input_alternative_key or
+ saved_model_export_utils.DEFAULT_INPUT_ALTERNATIVE_KEY)
+ input_alternatives, _ = saved_model_export_utils.get_input_alternatives(
+ input_fn_ops)
+ self._feed_tensors = input_alternatives[input_alternative_key]
+
+ (output_alternatives,
+ output_alternative_key) = saved_model_export_utils.get_output_alternatives(
+ model_fn_ops, output_alternative_key)
+ _, fetch_tensors = output_alternatives[output_alternative_key]
+ self._fetch_tensors = fetch_tensors
diff --git a/tensorflow/contrib/predictor/contrib_estimator_predictor_test.py b/tensorflow/contrib/predictor/contrib_estimator_predictor_test.py
new file mode 100644
index 0000000000..4b97a52b1a
--- /dev/null
+++ b/tensorflow/contrib/predictor/contrib_estimator_predictor_test.py
@@ -0,0 +1,70 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for predictor.contrib_estimator_predictor."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tempfile
+import numpy as np
+
+from tensorflow.contrib.predictor import contrib_estimator_predictor
+from tensorflow.contrib.predictor import testing_common
+from tensorflow.python.platform import test
+
+
+KEYS_AND_OPS = (('sum', lambda x, y: x + y),
+ ('product', lambda x, y: x * y,),
+ ('difference', lambda x, y: x - y))
+
+
+class ContribEstimatorPredictorTest(test.TestCase):
+ """Test fixture for `ContribEstimatorPredictor`."""
+
+ def setUp(self):
+ model_dir = tempfile.mkdtemp()
+ self._estimator = testing_common.get_arithmetic_estimator(
+ core=False, model_dir=model_dir)
+ self._prediction_input_fn = testing_common.get_arithmetic_input_fn(
+ core=False, train=False)
+
+ def testSpecifiedSignatureKey(self):
+ """Test prediction with spedicified signatures."""
+ np.random.seed(1234)
+ for key, op in KEYS_AND_OPS:
+ x = np.random.rand()
+ y = np.random.rand()
+ expected_output = op(x, y)
+
+ predictor = contrib_estimator_predictor.ContribEstimatorPredictor(
+ estimator=self._estimator,
+ prediction_input_fn=self._prediction_input_fn,
+ output_alternative_key=key)
+ output_tensor_name = predictor.fetch_tensors[key].name
+ self.assertRegexpMatches(
+ output_tensor_name,
+ key,
+ msg='Unexpected fetch tensor.')
+ output = predictor({'x': x, 'y': y})[key]
+ self.assertAlmostEqual(
+ expected_output, output, places=3,
+ msg='Failed for output key "{}." '
+ 'Got output {} for x = {} and y = {}'.format(
+ key, output, x, y))
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/predictor/core_estimator_predictor.py b/tensorflow/contrib/predictor/core_estimator_predictor.py
new file mode 100644
index 0000000000..5557ef5101
--- /dev/null
+++ b/tensorflow/contrib/predictor/core_estimator_predictor.py
@@ -0,0 +1,80 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""A `Predictor` constructed from an `learn.python.estimator.Estimator`."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.predictor import predictor
+from tensorflow.python.estimator import model_fn
+from tensorflow.python.framework import ops
+from tensorflow.python.saved_model import signature_constants
+from tensorflow.python.training import monitored_session
+
+
+def _get_signature_def(
+ serving_input_receiver, estimator, output_key=None):
+ """Construct a `SignatureDef` proto."""
+ if output_key is None:
+ output_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
+ # pylint: disable=protected-access
+ estimator_spec = estimator._call_model_fn(
+ serving_input_receiver.features, None, model_fn.ModeKeys.PREDICT)
+ # pylint: enable=protected-access
+ export_outputs = estimator_spec.export_outputs
+ export_output = export_outputs.get(output_key)
+ if export_output is None:
+ raise KeyError('output_key must be one of {}; got {}'.format(
+ export_outputs.keys(), output_key))
+ return export_output.as_signature_def(serving_input_receiver.receiver_tensors)
+
+
+class CoreEstimatorPredictor(predictor.Predictor):
+ """A `Predictor` constructed from an `learn.python.estimator.Estimator`."""
+
+ def __init__(self,
+ estimator,
+ serving_input_receiver_fn,
+ output_key=None,
+ graph=None):
+ """Initialize a `CoreEstimatorPredictor`.
+
+ Args:
+ estimator: an instance of `learn.python.estimator.Estimator`.
+ serving_input_receiver_fn: a function that takes no arguments and returns
+ an instance of `ServingInputReceiver` compatible with `estimator`.
+ output_key: Optional string specifying the export output to use. If
+ `None`, then `DEFAULT_SERVING_SIGNATURE_DEF_KEY` is used.
+ graph: Optional. The Tensorflow `graph` in which prediction should be
+ done.
+ """
+ self._graph = graph or ops.Graph()
+ with self._graph.as_default():
+ serving_input_receiver = serving_input_receiver_fn()
+ signature_def = _get_signature_def(
+ serving_input_receiver, estimator, output_key)
+ checkpoint_path = estimator.model_dir
+ self._session = monitored_session.MonitoredSession(
+ session_creator=monitored_session.ChiefSessionCreator(
+ checkpoint_filename_with_path=checkpoint_path))
+
+ feed_tensor_info = signature_def.inputs
+ self._feed_tensors = {k: self._graph.get_tensor_by_name(v.name)
+ for k, v in feed_tensor_info.items()}
+ fetch_tensor_info = signature_def.outputs
+ self._fetch_tensors = {k: self._graph.get_tensor_by_name(v.name)
+ for k, v in fetch_tensor_info.items()}
diff --git a/tensorflow/contrib/predictor/core_estimator_predictor_test.py b/tensorflow/contrib/predictor/core_estimator_predictor_test.py
new file mode 100644
index 0000000000..4221086794
--- /dev/null
+++ b/tensorflow/contrib/predictor/core_estimator_predictor_test.py
@@ -0,0 +1,81 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for predictor.core_estimator_predictor."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tempfile
+import numpy as np
+
+from tensorflow.contrib.predictor import core_estimator_predictor
+from tensorflow.contrib.predictor import testing_common
+from tensorflow.python.platform import test
+
+
+KEYS_AND_OPS = (('sum', lambda x, y: x + y),
+ ('product', lambda x, y: x * y,),
+ ('difference', lambda x, y: x - y))
+
+
+class CoreEstimatorPredictorTest(test.TestCase):
+ """Test fixture for `CoreEstimatorPredictor`."""
+
+ def setUp(self):
+ model_dir = tempfile.mkdtemp()
+ self._estimator = testing_common.get_arithmetic_estimator(
+ core=True, model_dir=model_dir)
+ self._serving_input_receiver_fn = testing_common.get_arithmetic_input_fn(
+ core=True, train=False)
+
+ def testDefault(self):
+ """Test prediction with default signature."""
+ np.random.seed(1111)
+ x = np.random.rand()
+ y = np.random.rand()
+ predictor = core_estimator_predictor.CoreEstimatorPredictor(
+ estimator=self._estimator,
+ serving_input_receiver_fn=self._serving_input_receiver_fn)
+ output = predictor({'x': x, 'y': y})['sum']
+ self.assertAlmostEqual(output, x + y, places=3)
+
+ def testSpecifiedSignatureKey(self):
+ """Test prediction with spedicified signatures."""
+ np.random.seed(1234)
+ for output_key, op in KEYS_AND_OPS:
+ x = np.random.rand()
+ y = np.random.rand()
+ expected_output = op(x, y)
+
+ predictor = core_estimator_predictor.CoreEstimatorPredictor(
+ estimator=self._estimator,
+ serving_input_receiver_fn=self._serving_input_receiver_fn,
+ output_key=output_key)
+ output_tensor_name = predictor.fetch_tensors[output_key].name
+ self.assertRegexpMatches(
+ output_tensor_name,
+ output_key,
+ msg='Unexpected fetch tensor.')
+ output = predictor({'x': x, 'y': y})[output_key]
+ self.assertAlmostEqual(
+ expected_output, output, places=3,
+ msg='Failed for output key "{}." '
+ 'Got output {} for x = {} and y = {}'.format(
+ output_key, output, x, y))
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/predictor/predictor.py b/tensorflow/contrib/predictor/predictor.py
new file mode 100644
index 0000000000..dbc0028259
--- /dev/null
+++ b/tensorflow/contrib/predictor/predictor.py
@@ -0,0 +1,77 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Abstract base class for all predictors."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+import six
+
+
+@six.add_metaclass(abc.ABCMeta)
+class Predictor(object):
+ """Abstract base class for all predictors."""
+
+ @property
+ def graph(self):
+ return self._graph
+
+ @property
+ def session(self):
+ return self._session
+
+ @property
+ def feed_tensors(self):
+ return self._feed_tensors
+
+ @property
+ def fetch_tensors(self):
+ return self._fetch_tensors
+
+ def __repr__(self):
+ return '{} with feed tensors {} and fetch_tensors {}'.format(
+ type(self).__name__, self._feed_tensors, self._fetch_tensors)
+
+ def __call__(self, input_dict):
+ """Returns predictions based on `input_dict`.
+
+ Args:
+ input_dict: a `dict` mapping strings to numpy arrays. These keys
+ must match `self._feed_tensors.keys()`.
+
+ Returns:
+ A `dict` mapping strings to numpy arrays. The keys match
+ `self.fetch_tensors.keys()`.
+
+ Raises:
+ ValueError: `input_dict` does not match `feed_tensors`.
+ """
+ # TODO(jamieas): make validation optional?
+ input_keys = set(input_dict.keys())
+ expected_keys = set(self.feed_tensors.keys())
+ unexpected_keys = input_keys - expected_keys
+ if unexpected_keys:
+ raise ValueError('Got unexpected keys in input_dict: {}'.format(
+ unexpected_keys))
+
+ feed_dict = {}
+ for key in self.feed_tensors.keys():
+ value = input_dict.get(key)
+ if value is not None:
+ feed_dict[self.feed_tensors[key]] = value
+ return self._session.run(fetches=self.fetch_tensors, feed_dict=feed_dict)
diff --git a/tensorflow/contrib/predictor/predictor_factories.py b/tensorflow/contrib/predictor/predictor_factories.py
new file mode 100644
index 0000000000..e3f30d917d
--- /dev/null
+++ b/tensorflow/contrib/predictor/predictor_factories.py
@@ -0,0 +1,132 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Factory functions for `Predictor`s."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.predictor import contrib_estimator_predictor
+from tensorflow.contrib.predictor import core_estimator_predictor
+from tensorflow.contrib.predictor import saved_model_predictor
+from tensorflow.python.estimator import estimator as core_estimator
+
+
+def from_contrib_estimator(estimator,
+ prediction_input_fn,
+ input_alternative_key=None,
+ output_alternative_key=None,
+ graph=None):
+ """Constructs a `Predictor` from a `tf.contrib.learn.Estimator`.
+
+ Args:
+ estimator: an instance of `tf.contrib.learn.Estimator`.
+ prediction_input_fn: a function that takes no arguments and returns an
+ instance of `InputFnOps`.
+ input_alternative_key: Optional. Specify the input alternative used for
+ prediction.
+ output_alternative_key: Specify the output alternative used for
+ prediction. Not needed for single-headed models but required for
+ multi-headed models.
+ graph: Optional. The Tensorflow `graph` in which prediction should be
+ done.
+
+ Returns:
+ An initialized `Predictor`.
+
+ Raises:
+ TypeError: if `estimator` is a core `Estimator` instead of a contrib
+ `Estimator`.
+ """
+ if isinstance(estimator, core_estimator.Estimator):
+ raise TypeError('Espected estimator to be of type '
+ 'tf.contrib.learn.Estimator, but got type '
+ 'tf.python.estimator.Estimator. You likely want to call '
+ 'from_estimator.')
+ return contrib_estimator_predictor.ContribEstimatorPredictor(
+ estimator,
+ prediction_input_fn,
+ input_alternative_key,
+ output_alternative_key,
+ graph)
+
+
+def from_estimator(estimator,
+ serving_input_receiver_fn,
+ output_key=None,
+ graph=None):
+ """Constructs a `Predictor` from a `tf.python.estimator.Estimator`.
+
+ Args:
+ estimator: an instance of `learn.python.estimator.Estimator`.
+ serving_input_receiver_fn: a function that takes no arguments and returns
+ an instance of `ServingInputReceiver` compatible with `estimator`.
+ output_key: Optional string specifying the export output to use. If
+ `None`, then `DEFAULT_SERVING_SIGNATURE_DEF_KEY` is used.
+ graph: Optional. The Tensorflow `graph` in which prediction should be
+ done.
+
+ Returns:
+ An initialized `Predictor`.
+
+ Raises:
+ TypeError: if `estimator` is a contrib `Estimator` instead of a core
+ `Estimator`.
+ """
+ if isinstance(estimator, estimator.Estimator):
+ raise TypeError('Espected estimator to be of type '
+ 'tf.python.estimator.Estimator, but got type '
+ 'tf.contrib.learn.Estimator. You likely want to call '
+ 'from_contrib_estimator.')
+ return core_estimator_predictor.CoreEstimatorPredictor(
+ estimator,
+ serving_input_receiver_fn,
+ output_key,
+ graph)
+
+
+def from_saved_model(export_dir,
+ signature_def_key=None,
+ signature_def=None,
+ tags=None,
+ graph=None):
+ """Constructs a `Predictor` from a `SavedModel` on disk.
+
+ Args:
+ export_dir: a path to a directory containing a `SavedModel`.
+ signature_def_key: Optional string specifying the signature to use. If
+ `None`, then `DEFAULT_SERVING_SIGNATURE_DEF_KEY` is used. Only one of
+ `signature_def_key` and `signature_def`
+ signature_def: A `SignatureDef` proto specifying the inputs and outputs
+ for prediction. Only one of `signature_def_key` and `signature_def`
+ should be specified.
+ tags: Optional. Tags that will be used to retrieve the correct
+ `SignatureDef`. Defaults to `DEFAULT_TAGS`.
+ graph: Optional. The Tensorflow `graph` in which prediction should be
+ done.
+
+ Returns:
+ An initialized `Predictor`.
+
+ Raises:
+ ValueError: More than one of `signature_def_key` and `signature_def` is
+ specified.
+ """
+ return saved_model_predictor.SavedModelPredictor(export_dir,
+ signature_def_key,
+ signature_def,
+ tags,
+ graph)
diff --git a/tensorflow/contrib/predictor/saved_model_predictor.py b/tensorflow/contrib/predictor/saved_model_predictor.py
new file mode 100644
index 0000000000..ab2bafa0c8
--- /dev/null
+++ b/tensorflow/contrib/predictor/saved_model_predictor.py
@@ -0,0 +1,143 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""A `Predictor` constructed from a `SavedModel`."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import logging
+
+from tensorflow.contrib.predictor import predictor
+from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils
+from tensorflow.python.client import session
+from tensorflow.python.framework import ops
+from tensorflow.python.saved_model import loader
+from tensorflow.python.saved_model import signature_constants
+from tensorflow.python.tools import saved_model_cli
+
+
+DEFAULT_TAGS = 'serve'
+
+_DEFAULT_INPUT_ALTERNATIVE_FORMAT = 'default_input_alternative:{}'
+
+
+def _get_signature_def(signature_def_key, export_dir, tags):
+ """Construct a `SignatureDef` proto."""
+ signature_def_key = (
+ signature_def_key or
+ signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY)
+
+ metagraph_def = saved_model_cli.get_meta_graph_def(export_dir, tags)
+
+ try:
+ signature_def = signature_def_utils.get_signature_def_by_key(
+ metagraph_def,
+ signature_def_key)
+ except ValueError as e:
+ try:
+ formatted_key = _DEFAULT_INPUT_ALTERNATIVE_FORMAT.format(
+ signature_def_key)
+ signature_def = signature_def_utils.get_signature_def_by_key(
+ metagraph_def, formatted_key)
+
+ logging.warning('Could not find signature def "%s". '
+ 'Using "%s" instead', signature_def_key, formatted_key)
+ except ValueError:
+ raise ValueError(
+ 'Got signature_def_key "{}". Available signatures are {}. '
+ 'Original error:\n{}'.format(
+ signature_def_key, list(metagraph_def.signature_def), e))
+ return signature_def
+
+
+def _check_signature_arguments(signature_def_key,
+ signature_def,
+ input_names,
+ output_names):
+ """Validates signature arguments for `SavedModelPredictor`."""
+ signature_def_key_specified = signature_def_key is not None
+ signature_def_specified = signature_def is not None
+ input_names_specified = input_names is not None
+ output_names_specified = output_names is not None
+ if input_names_specified != output_names_specified:
+ raise ValueError(
+ 'input_names and output_names must both be specified or both be '
+ 'unspecified.'
+ )
+
+ if (signature_def_key_specified + signature_def_specified +
+ input_names_specified > 1):
+ raise ValueError(
+ 'You must specify at most one of signature_def_key OR signature_def OR'
+ '(input_names AND output_names).'
+ )
+
+
+class SavedModelPredictor(predictor.Predictor):
+ """A `Predictor` constructed from a `SavedModel`."""
+
+ def __init__(self,
+ export_dir,
+ signature_def_key=None,
+ signature_def=None,
+ input_names=None,
+ output_names=None,
+ tags=None,
+ graph=None):
+ """Initialize a `CoreEstimatorPredictor`.
+
+ Args:
+ export_dir: a path to a directory containing a `SavedModel`.
+ signature_def_key: Optional string specifying the signature to use. If
+ `None`, then `DEFAULT_SERVING_SIGNATURE_DEF_KEY` is used. Only one of
+ `signature_def_key` and `signature_def` should be specified.
+ signature_def: A `SignatureDef` proto specifying the inputs and outputs
+ for prediction. Only one of `signature_def_key` and `signature_def`
+ should be specified.
+ input_names: A dictionary mapping strings to `Tensor`s in the `SavedModel`
+ that represent the input. The keys can be any string of the user's
+ choosing.
+ output_names: A dictionary mapping strings to `Tensor`s in the
+ `SavedModel` that represent the output. The keys can be any string of
+ the user's choosing.
+ tags: Optional. Tags that will be used to retrieve the correct
+ `SignatureDef`. Defaults to `DEFAULT_TAGS`.
+ graph: Optional. The Tensorflow `graph` in which prediction should be
+ done.
+ Raises:
+ ValueError: If more than one of signature_def_key OR signature_def OR
+ (input_names AND output_names) is specified.
+ """
+ _check_signature_arguments(
+ signature_def_key, signature_def, input_names, output_names)
+ tags = tags or DEFAULT_TAGS
+ self._graph = graph or ops.Graph()
+
+ with self._graph.as_default():
+ self._session = session.Session()
+ loader.load(self._session, tags.split(','), export_dir)
+
+ if input_names is None:
+ if signature_def is None:
+ signature_def = _get_signature_def(signature_def_key, export_dir, tags)
+ input_names = {k: v.name for k, v in signature_def.inputs.items()}
+ output_names = {k: v.name for k, v in signature_def.outputs.items()}
+
+ self._feed_tensors = {k: self._graph.get_tensor_by_name(v)
+ for k, v in input_names.items()}
+ self._fetch_tensors = {k: self._graph.get_tensor_by_name(v)
+ for k, v in output_names.items()}
diff --git a/tensorflow/contrib/predictor/saved_model_predictor_test.py b/tensorflow/contrib/predictor/saved_model_predictor_test.py
new file mode 100644
index 0000000000..f40e2e73d9
--- /dev/null
+++ b/tensorflow/contrib/predictor/saved_model_predictor_test.py
@@ -0,0 +1,170 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for predictor.saved_model_predictor."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.predictor import saved_model_predictor
+from tensorflow.core.framework import tensor_shape_pb2
+from tensorflow.core.framework import types_pb2
+from tensorflow.core.protobuf import meta_graph_pb2
+from tensorflow.python.framework import ops
+from tensorflow.python.platform import test
+from tensorflow.python.saved_model import signature_def_utils
+
+
+KEYS_AND_OPS = (('sum', lambda x, y: x + y),
+ ('product', lambda x, y: x * y,),
+ ('difference', lambda x, y: x - y))
+
+MODEL_DIR_NAME = 'contrib/predictor/test_export_dir'
+
+
+class SavedModelPredictorTest(test.TestCase):
+
+ @classmethod
+ def setUpClass(cls):
+ # Load a saved model exported from the arithmetic `Estimator`.
+ # See `testing_common.py`.
+ cls._export_dir = test.test_src_dir_path(MODEL_DIR_NAME)
+
+ def testDefault(self):
+ """Test prediction with default signature."""
+ np.random.seed(1111)
+ x = np.random.rand()
+ y = np.random.rand()
+ predictor = saved_model_predictor.SavedModelPredictor(
+ export_dir=self._export_dir)
+ output = predictor({'x': x, 'y': y})['outputs']
+ self.assertAlmostEqual(output, x + y, places=3)
+
+ def testSpecifiedSignatureKey(self):
+ """Test prediction with spedicified signature key."""
+ np.random.seed(1234)
+ for signature_def_key, op in KEYS_AND_OPS:
+ x = np.random.rand()
+ y = np.random.rand()
+ expected_output = op(x, y)
+
+ predictor = saved_model_predictor.SavedModelPredictor(
+ export_dir=self._export_dir,
+ signature_def_key=signature_def_key)
+
+ output_tensor_name = predictor.fetch_tensors['outputs'].name
+ self.assertRegexpMatches(
+ output_tensor_name,
+ signature_def_key,
+ msg='Unexpected fetch tensor.')
+
+ output = predictor({'x': x, 'y': y})['outputs']
+ self.assertAlmostEqual(
+ expected_output, output, places=3,
+ msg='Failed for signature "{}." '
+ 'Got output {} for x = {} and y = {}'.format(
+ signature_def_key, output, x, y))
+
+ def testSpecifiedSignature(self):
+ """Test prediction with spedicified signature definition."""
+ np.random.seed(4444)
+ for key, op in KEYS_AND_OPS:
+ x = np.random.rand()
+ y = np.random.rand()
+ expected_output = op(x, y)
+
+ inputs = {
+ 'x': meta_graph_pb2.TensorInfo(
+ name='inputs/x:0',
+ dtype=types_pb2.DT_FLOAT,
+ tensor_shape=tensor_shape_pb2.TensorShapeProto()),
+ 'y': meta_graph_pb2.TensorInfo(
+ name='inputs/y:0',
+ dtype=types_pb2.DT_FLOAT,
+ tensor_shape=tensor_shape_pb2.TensorShapeProto())}
+ outputs = {
+ key: meta_graph_pb2.TensorInfo(
+ name='outputs/{}:0'.format(key),
+ dtype=types_pb2.DT_FLOAT,
+ tensor_shape=tensor_shape_pb2.TensorShapeProto())}
+ signature_def = signature_def_utils.build_signature_def(
+ inputs=inputs,
+ outputs=outputs,
+ method_name='tensorflow/serving/regress')
+ predictor = saved_model_predictor.SavedModelPredictor(
+ export_dir=self._export_dir,
+ signature_def=signature_def)
+
+ output_tensor_name = predictor.fetch_tensors[key].name
+ self.assertRegexpMatches(
+ output_tensor_name,
+ key,
+ msg='Unexpected fetch tensor.')
+
+ output = predictor({'x': x, 'y': y})[key]
+ self.assertAlmostEqual(
+ expected_output, output, places=3,
+ msg='Failed for signature "{}". '
+ 'Got output {} for x = {} and y = {}'.format(key, output, x, y))
+
+ def testSpecifiedTensors(self):
+ """Test prediction with spedicified `Tensor`s."""
+ np.random.seed(987)
+ for key, op in KEYS_AND_OPS:
+ x = np.random.rand()
+ y = np.random.rand()
+ expected_output = op(x, y)
+ input_names = {'x': 'inputs/x:0',
+ 'y': 'inputs/y:0'}
+ output_names = {key: 'outputs/{}:0'.format(key)}
+ predictor = saved_model_predictor.SavedModelPredictor(
+ export_dir=self._export_dir,
+ input_names=input_names,
+ output_names=output_names)
+
+ output_tensor_name = predictor.fetch_tensors[key].name
+ self.assertRegexpMatches(
+ output_tensor_name,
+ key,
+ msg='Unexpected fetch tensor.')
+
+ output = predictor({'x': x, 'y': y})[key]
+ self.assertAlmostEqual(
+ expected_output, output, places=3,
+ msg='Failed for signature "{}". '
+ 'Got output {} for x = {} and y = {}'.format(key, output, x, y))
+
+ def testBadTagsFail(self):
+ """Test that predictor construction fails for bad tags."""
+ bad_tags_regex = ('.* could not be found in SavedModel')
+ with self.assertRaisesRegexp(RuntimeError, bad_tags_regex):
+ _ = saved_model_predictor.SavedModelPredictor(
+ export_dir=self._export_dir,
+ tags=('zomg, bad, tags'))
+
+ def testSpecifiedGraph(self):
+ """Test that the predictor remembers a specified `Graph`."""
+ g = ops.Graph()
+ predictor = saved_model_predictor.SavedModelPredictor(
+ export_dir=self._export_dir,
+ graph=g)
+ self.assertEqual(predictor.graph, g)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/predictor/test_export_dir/saved_model.pb b/tensorflow/contrib/predictor/test_export_dir/saved_model.pb
new file mode 100644
index 0000000000..9100fefb72
--- /dev/null
+++ b/tensorflow/contrib/predictor/test_export_dir/saved_model.pb
Binary files differ
diff --git a/tensorflow/contrib/predictor/test_export_dir/variables/variables.data-00000-of-00001 b/tensorflow/contrib/predictor/test_export_dir/variables/variables.data-00000-of-00001
new file mode 100644
index 0000000000..1b1cb4d44c
--- /dev/null
+++ b/tensorflow/contrib/predictor/test_export_dir/variables/variables.data-00000-of-00001
Binary files differ
diff --git a/tensorflow/contrib/predictor/test_export_dir/variables/variables.index b/tensorflow/contrib/predictor/test_export_dir/variables/variables.index
new file mode 100644
index 0000000000..dd32e9b71b
--- /dev/null
+++ b/tensorflow/contrib/predictor/test_export_dir/variables/variables.index
Binary files differ
diff --git a/tensorflow/contrib/predictor/testing_common.py b/tensorflow/contrib/predictor/testing_common.py
new file mode 100644
index 0000000000..1767704b99
--- /dev/null
+++ b/tensorflow/contrib/predictor/testing_common.py
@@ -0,0 +1,102 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Common code used for testing `Predictor`s."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.learn.python.learn.estimators import constants
+from tensorflow.contrib.learn.python.learn.estimators import estimator as contrib_estimator
+from tensorflow.contrib.learn.python.learn.estimators import model_fn as contrib_model_fn
+from tensorflow.contrib.learn.python.learn.utils import input_fn_utils
+from tensorflow.python.estimator import estimator as core_estimator
+from tensorflow.python.estimator import model_fn
+from tensorflow.python.estimator.export import export_lib
+from tensorflow.python.estimator.export import export_output
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.saved_model import signature_constants
+
+
+def get_arithmetic_estimator(core=True, model_dir=None):
+ """Returns an `Estimator` that performs basic arithmetic.
+
+ Args:
+ core: if `True`, returns a `tensorflow.python.estimator.Estimator`.
+ Otherwise, returns a `tensorflow.contrib.learn.Estimator`.
+ model_dir: directory in which to export checkpoints and saved models.
+ Returns:
+ An `Estimator` that performs arithmetic operations on its inputs.
+ """
+ def _model_fn(features, labels, mode):
+ _ = labels
+ x = features['x']
+ y = features['y']
+ with ops.name_scope('outputs'):
+ predictions = {'sum': math_ops.add(x, y, name='sum'),
+ 'product': math_ops.multiply(x, y, name='product'),
+ 'difference': math_ops.subtract(x, y, name='difference')}
+ if core:
+ export_outputs = {k: export_output.PredictOutput({k: v})
+ for k, v in predictions.items()}
+ export_outputs[signature_constants.
+ DEFAULT_SERVING_SIGNATURE_DEF_KEY] = export_outputs['sum']
+ return model_fn.EstimatorSpec(mode=mode,
+ predictions=predictions,
+ export_outputs=export_outputs,
+ loss=constant_op.constant(0),
+ train_op=control_flow_ops.no_op())
+ else:
+ output_alternatives = {k: (constants.ProblemType.UNSPECIFIED, {k: v})
+ for k, v in predictions.items()}
+ return contrib_model_fn.ModelFnOps(
+ mode=mode,
+ predictions=predictions,
+ output_alternatives=output_alternatives,
+ loss=constant_op.constant(0),
+ train_op=control_flow_ops.no_op())
+ if core:
+ return core_estimator.Estimator(_model_fn)
+ else:
+ return contrib_estimator.Estimator(_model_fn, model_dir=model_dir)
+
+
+def get_arithmetic_input_fn(core=True, train=False):
+ """Returns a input functions or serving input receiver function."""
+ def _input_fn():
+ with ops.name_scope('inputs'):
+ x = array_ops.placeholder_with_default(0.0, shape=[], name='x')
+ y = array_ops.placeholder_with_default(0.0, shape=[], name='y')
+ label = constant_op.constant(0.0)
+ features = {'x': x, 'y': y}
+ if core:
+ if train:
+ return features, label
+ return export_lib.ServingInputReceiver(
+ features=features,
+ receiver_tensors=features)
+ else:
+ if train:
+ return features, label
+ return input_fn_utils.InputFnOps(
+ features=features,
+ labels={},
+ default_inputs=features)
+ return _input_fn