aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar David Soergel <soergel@google.com>2016-11-29 11:16:34 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-11-29 11:26:20 -0800
commite4d9c5f20d28df94cc8d1245c685f03098b811c0 (patch)
tree1ed707a9485fe9bfe03b9dcba4750cf596c667a9
parentd753ed4eda5d291e0d88f0b5babf7df4b691a066 (diff)
Export SavedModel from tf.Learn.
Change: 140502426
-rw-r--r--tensorflow/contrib/learn/BUILD30
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/constants.py26
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/dnn.py17
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py17
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/estimator.py114
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/estimator_test.py117
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/head.py39
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/linear.py33
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/model_fn.py35
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/random_forest.py17
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/svm.py18
-rw-r--r--tensorflow/contrib/learn/python/learn/utils/gc.py205
-rw-r--r--tensorflow/contrib/learn/python/learn/utils/gc_test.py120
-rw-r--r--tensorflow/contrib/learn/python/learn/utils/input_fn_utils.py97
-rw-r--r--tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py248
-rw-r--r--tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py228
-rw-r--r--tensorflow/python/saved_model/BUILD36
-rw-r--r--tensorflow/python/saved_model/example/saved_model_half_plus_two.py4
-rw-r--r--tensorflow/python/saved_model/saved_model_test.py11
-rw-r--r--tensorflow/python/saved_model/signature_def_utils.py158
-rw-r--r--tensorflow/python/saved_model/signature_def_utils_test.py156
-rw-r--r--tensorflow/python/saved_model/utils.py28
-rw-r--r--tensorflow/python/saved_model/utils_test.py30
23 files changed, 1707 insertions, 77 deletions
diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD
index 5b7a2d76d8..764971935f 100644
--- a/tensorflow/contrib/learn/BUILD
+++ b/tensorflow/contrib/learn/BUILD
@@ -21,6 +21,10 @@ py_library(
"//tensorflow/contrib/session_bundle:exporter",
"//tensorflow/contrib/tensor_forest:client_lib",
"//tensorflow/python:framework",
+ "//tensorflow/python/saved_model:builder",
+ "//tensorflow/python/saved_model:loader",
+ "//tensorflow/python/saved_model:signature_def_utils",
+ "//tensorflow/python/saved_model:tag_constants",
],
)
@@ -663,6 +667,32 @@ py_test(
)
py_test(
+ name = "gc_test",
+ size = "small",
+ srcs = ["python/learn/utils/gc_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":learn",
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_test_lib",
+ ],
+)
+
+py_test(
+ name = "saved_model_export_utils_test",
+ size = "small",
+ srcs = ["python/learn/utils/saved_model_export_utils_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":learn",
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_test_lib",
+ ],
+)
+
+py_test(
name = "stability_test",
size = "small",
srcs = ["python/learn/estimators/stability_test.py"],
diff --git a/tensorflow/contrib/learn/python/learn/estimators/constants.py b/tensorflow/contrib/learn/python/learn/estimators/constants.py
new file mode 100644
index 0000000000..aee4541627
--- /dev/null
+++ b/tensorflow/contrib/learn/python/learn/estimators/constants.py
@@ -0,0 +1,26 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Constants regarding Estimators."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+class ProblemType(object):
+ UNSPECIFIED = 0
+ CLASSIFICATION = 1
+ LINEAR_REGRESSION = 2
+ LOGISTIC_REGRESSION = 3
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn.py b/tensorflow/contrib/learn/python/learn/estimators/dnn.py
index b821d842cd..98947cc6d4 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dnn.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dnn.py
@@ -24,6 +24,7 @@ import six
from tensorflow.contrib import layers
from tensorflow.contrib.framework import deprecated
from tensorflow.contrib.framework import deprecated_arg_values
+from tensorflow.contrib.framework.python.framework import experimental
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
from tensorflow.contrib.layers.python.layers import optimizers
from tensorflow.contrib.learn.python.learn import evaluable
@@ -452,6 +453,22 @@ class DNNClassifier(evaluable.Evaluable, trainable.Trainable):
default_batch_size=default_batch_size,
exports_to_keep=exports_to_keep)
+ @experimental
+ def export_savedmodel(self,
+ export_dir_base,
+ input_fn,
+ default_output_alternative_key=None,
+ assets_extra=None,
+ as_text=False,
+ exports_to_keep=None):
+ return self._estimator.export_savedmodel(
+ export_dir_base,
+ input_fn,
+ default_output_alternative_key=default_output_alternative_key,
+ assets_extra=assets_extra,
+ as_text=as_text,
+ exports_to_keep=exports_to_keep)
+
@property
def model_dir(self):
return self._estimator.model_dir
diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py
index 5a1bc931f8..256e074079 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py
@@ -26,6 +26,7 @@ import six
from tensorflow.contrib import layers
from tensorflow.contrib.framework import deprecated
from tensorflow.contrib.framework import deprecated_arg_values
+from tensorflow.contrib.framework.python.framework import experimental
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
from tensorflow.contrib.layers.python.layers import feature_column as feature_column_lib
from tensorflow.contrib.layers.python.layers import feature_column_ops
@@ -868,6 +869,22 @@ class DNNLinearCombinedClassifier(evaluable.Evaluable, trainable.Trainable):
default_batch_size=default_batch_size,
exports_to_keep=exports_to_keep)
+ @experimental
+ def export_savedmodel(self,
+ export_dir_base,
+ input_fn,
+ default_output_alternative_key=None,
+ assets_extra=None,
+ as_text=False,
+ exports_to_keep=None):
+ return self._estimator.export_savedmodel(
+ export_dir_base,
+ input_fn,
+ default_output_alternative_key=default_output_alternative_key,
+ assets_extra=assets_extra,
+ as_text=as_text,
+ exports_to_keep=exports_to_keep)
+
@property
def model_dir(self):
return self._estimator.model_dir
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
index 6b5b241936..ea43fba6b3 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
@@ -38,6 +38,8 @@ from tensorflow.contrib.framework import deprecated_arg_values
from tensorflow.contrib.framework import deprecated_args
from tensorflow.contrib.framework import list_variables
from tensorflow.contrib.framework import load_variable
+from tensorflow.contrib.framework.python.framework import experimental
+from tensorflow.contrib.framework.python.ops import variables as contrib_variables
from tensorflow.contrib.learn.python.learn import evaluable
from tensorflow.contrib.learn.python.learn import graph_actions
from tensorflow.contrib.learn.python.learn import metric_spec
@@ -51,14 +53,21 @@ from tensorflow.contrib.learn.python.learn.estimators import tensor_signature
from tensorflow.contrib.learn.python.learn.estimators._sklearn import NotFittedError
from tensorflow.contrib.learn.python.learn.learn_io import data_feeder
from tensorflow.contrib.learn.python.learn.utils import export
-
+from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils
+from tensorflow.python.client import session as tf_session
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.saved_model import builder as saved_model_builder
+from tensorflow.python.saved_model import tag_constants
from tensorflow.python.training import device_setter
from tensorflow.python.training import saver
+from tensorflow.python.util import compat
AS_ITERABLE_DATE = '2016-09-15'
@@ -553,13 +562,12 @@ class BaseEstimator(
use_deprecated_input_fn=use_deprecated_input_fn,
default_batch_size=default_batch_size,
exports_to_keep=exports_to_keep)
- # pylint: enable=protected-access
@abc.abstractproperty
def _get_train_ops(self, features, labels):
"""Method that builds model graph and returns trainer ops.
- Expected to be overriden by sub-classes that require custom support.
+ Expected to be overridden by sub-classes that require custom support.
Args:
features: `Tensor` or `dict` of `Tensor` objects.
@@ -1106,6 +1114,106 @@ class Estimator(BaseEstimator):
self._labels_info)
return self._call_model_fn(features, labels, model_fn_lib.ModeKeys.INFER)
+ @experimental
+ def export_savedmodel(
+ self, export_dir_base, input_fn,
+ default_output_alternative_key=None,
+ assets_extra=None,
+ as_text=False,
+ exports_to_keep=None):
+ """Exports inference graph as a SavedModel into given dir.
+
+ Args:
+ export_dir_base: A string containing a directory to write the exported
+ graph and checkpoints.
+ input_fn: A function that takes no argument and
+ returns an `InputFnOps`.
+ default_output_alternative_key: the name of the head to serve when none is
+ specified.
+ assets_extra: A dict specifying how to populate the assets.extra directory
+ within the exported SavedModel. Each key should give the destination
+ path (including the filename) relative to the assets.extra directory.
+ The corresponding value gives the full path of the source file to be
+ copied. For example, the simple case of copying a single file without
+ renaming it is specified as
+ `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`.
+ as_text: whether to write the SavedModel proto in text format.
+ exports_to_keep: Number of exports to keep.
+
+ Returns:
+ The string path to the exported directory.
+
+ Raises:
+ ValueError: if an unrecognized export_type is requested.
+ """
+ if input_fn is None:
+ raise ValueError('input_fn must be defined.')
+
+ with ops.Graph().as_default() as g:
+ contrib_variables.create_global_step(g)
+
+ # Call the input_fn and collect the input alternatives.
+ input_ops = input_fn()
+ input_alternatives, features = (
+ saved_model_export_utils.get_input_alternatives(input_ops))
+
+ # Call the model_fn and collect the output alternatives.
+ model_fn_ops = self._call_model_fn(features, None,
+ model_fn_lib.ModeKeys.INFER)
+ output_alternatives, actual_default_output_alternative_key = (
+ saved_model_export_utils.get_output_alternatives(
+ model_fn_ops, default_output_alternative_key))
+
+ # Build the SignatureDefs from all pairs of input and output signatures
+ signature_def_map = saved_model_export_utils.build_all_signature_defs(
+ input_alternatives, output_alternatives,
+ actual_default_output_alternative_key)
+
+ # Locate the latest checkpoint
+ # TODO(soergel): does it help that we know we have one from this step?
+ checkpoint_path = saver.latest_checkpoint(self._model_dir)
+ if not checkpoint_path:
+ raise NotFittedError("Couldn't find trained model at %s."
+ % self._model_dir)
+
+ export_dir = saved_model_export_utils.get_timestamped_export_dir(
+ export_dir_base)
+
+ with tf_session.Session('') as session:
+ variables.initialize_local_variables()
+ data_flow_ops.initialize_all_tables()
+ saver_for_restore = saver.Saver(
+ variables.global_variables(),
+ sharded=True)
+ saver_for_restore.restore(session, checkpoint_path)
+
+ init_op = control_flow_ops.group(
+ variables.local_variables_initializer(),
+ data_flow_ops.initialize_all_tables())
+
+ # Perform the export
+ builder = saved_model_builder.SavedModelBuilder(export_dir)
+ builder.add_meta_graph_and_variables(
+ session, [tag_constants.SERVING],
+ signature_def_map=signature_def_map,
+ assets_collection=ops.get_collection(
+ ops.GraphKeys.ASSET_FILEPATHS),
+ legacy_init_op=init_op)
+ builder.save(as_text)
+
+ # Add the extra assets
+ if assets_extra:
+ assets_extra_path = os.path.join(compat.as_bytes(export_dir),
+ compat.as_bytes('assets.extra'))
+ for dest_relative, source in assets_extra.items():
+ dest_absolute = os.path.join(compat.as_bytes(assets_extra_path),
+ compat.as_bytes(dest_relative))
+ dest_path = os.path.dirname(dest_absolute)
+ gfile.MakeDirs(dest_path)
+ gfile.Copy(source, dest_absolute)
+
+ return export_dir
+
# For time of deprecation x,y from Estimator allow direct access.
# pylint: disable=protected-access
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
index 7d500bce19..49c917a3de 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
@@ -22,6 +22,7 @@ from __future__ import print_function
import functools
import itertools
import json
+import os
import tempfile
import numpy as np
@@ -33,6 +34,11 @@ from tensorflow.contrib.learn.python.learn import metric_spec
from tensorflow.contrib.learn.python.learn.estimators import _sklearn
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators import model_fn
+from tensorflow.contrib.learn.python.learn.utils import input_fn_utils
+from tensorflow.python.framework import ops
+from tensorflow.python.saved_model import loader
+from tensorflow.python.saved_model import tag_constants
+from tensorflow.python.util import compat
_BOSTON_INPUT_DIM = 13
@@ -92,6 +98,8 @@ def linear_model_fn(features, labels, mode):
tf.contrib.learn.ModeKeys.TRAIN,
tf.contrib.learn.ModeKeys.EVAL,
tf.contrib.learn.ModeKeys.INFER)
+ if isinstance(features, dict):
+ (_, features), = features.items()
prediction, loss = (
tf.contrib.learn.models.linear_regression_zero_init(features, labels)
)
@@ -131,6 +139,45 @@ def logistic_model_no_mode_fn(features, labels):
learning_rate=0.1)
return {'class': tf.argmax(prediction, 1), 'prob': prediction}, loss, train_op
+VOCAB_FILE_CONTENT = 'emerson\nlake\npalmer\n'
+EXTRA_FILE_CONTENT = 'kermit\npiggy\nralph\n'
+
+
+def _build_estimator_for_export_tests(tmpdir):
+ def _input_fn():
+ iris = tf.contrib.learn.datasets.load_iris()
+ return {
+ 'feature': tf.constant(iris.data, dtype=tf.float32)
+ }, tf.constant(iris.target, shape=[150], dtype=tf.int32)
+
+ feature_columns = [tf.contrib.layers.real_valued_column('feature',
+ dimension=4)]
+
+ est = tf.contrib.learn.LinearRegressor(feature_columns)
+ est.fit(input_fn=_input_fn, steps=20)
+
+ feature_spec = tf.contrib.layers.create_feature_spec_for_parsing(
+ feature_columns)
+ export_input_fn = input_fn_utils.build_parsing_serving_input_fn(feature_spec)
+
+ # hack in an op that uses an asset, in order to test asset export.
+ # this is not actually valid, of course.
+ def export_input_fn_with_asset():
+ features, labels, inputs = export_input_fn()
+
+ vocab_file_name = os.path.join(tmpdir, 'my_vocab_file')
+ vocab_file = tf.gfile.GFile(vocab_file_name, mode='w')
+ vocab_file.write(VOCAB_FILE_CONTENT)
+ vocab_file.close()
+ hashtable = tf.contrib.lookup.HashTable(
+ tf.contrib.lookup.TextFileStringTableInitializer(vocab_file_name), 'x')
+ features['bogus_lookup'] = hashtable.lookup(
+ tf.to_int64(features['feature']))
+
+ return input_fn_utils.InputFnOps(features, labels, inputs)
+
+ return est, export_input_fn_with_asset
+
class CheckCallsMonitor(tf.contrib.learn.monitors.BaseMonitor):
@@ -503,6 +550,76 @@ class EstimatorTest(tf.test.TestCase):
self.assertEquals(expected, actual)
+ def test_export_savedmodel(self):
+ tmpdir = tempfile.mkdtemp()
+ est, export_input_fn = _build_estimator_for_export_tests(tmpdir)
+
+ extra_file_name = os.path.join(compat.as_bytes(tmpdir),
+ compat.as_bytes('my_extra_file'))
+ extra_file = tf.gfile.GFile(extra_file_name, mode='w')
+ extra_file.write(EXTRA_FILE_CONTENT)
+ extra_file.close()
+ assets_extra = {'some/sub/directory/my_extra_file': extra_file_name}
+
+ export_dir_base = os.path.join(compat.as_bytes(tmpdir),
+ compat.as_bytes('export'))
+ export_dir = est.export_savedmodel(export_dir_base, export_input_fn,
+ assets_extra=assets_extra)
+
+ self.assertTrue(tf.gfile.Exists(export_dir_base))
+ self.assertTrue(tf.gfile.Exists(export_dir))
+ self.assertTrue(tf.gfile.Exists(
+ os.path.join(compat.as_bytes(export_dir),
+ compat.as_bytes('saved_model.pb'))))
+ self.assertTrue(tf.gfile.Exists(
+ os.path.join(compat.as_bytes(export_dir),
+ compat.as_bytes('variables'))))
+ self.assertTrue(tf.gfile.Exists(
+ os.path.join(compat.as_bytes(export_dir),
+ compat.as_bytes('variables/variables.index'))))
+ self.assertTrue(tf.gfile.Exists(os.path.join(
+ compat.as_bytes(export_dir),
+ compat.as_bytes('variables/variables.data-00000-of-00001'))))
+
+ self.assertTrue(tf.gfile.Exists(
+ os.path.join(compat.as_bytes(export_dir), compat.as_bytes('assets'))))
+ self.assertTrue(tf.gfile.Exists(
+ os.path.join(compat.as_bytes(export_dir),
+ compat.as_bytes('assets/my_vocab_file'))))
+ self.assertEqual(
+ compat.as_bytes(VOCAB_FILE_CONTENT),
+ compat.as_bytes(tf.gfile.GFile(
+ os.path.join(compat.as_bytes(export_dir),
+ compat.as_bytes('assets/my_vocab_file'))).read()))
+
+ expected_extra_path = os.path.join(
+ compat.as_bytes(export_dir),
+ compat.as_bytes('assets.extra/some/sub/directory/my_extra_file'))
+ self.assertTrue(tf.gfile.Exists(
+ os.path.join(compat.as_bytes(export_dir),
+ compat.as_bytes('assets.extra'))))
+ self.assertTrue(tf.gfile.Exists(expected_extra_path))
+ self.assertEqual(
+ compat.as_bytes(EXTRA_FILE_CONTENT),
+ compat.as_bytes(tf.gfile.GFile(expected_extra_path).read()))
+
+ expected_vocab_file = os.path.join(compat.as_bytes(tmpdir),
+ compat.as_bytes('my_vocab_file'))
+ # Restore, to validate that the export was well-formed.
+ with tf.Graph().as_default() as graph:
+ with tf.Session(graph=graph) as sess:
+ loader.load(sess, [tag_constants.SERVING], export_dir)
+ assets = [x.eval()
+ for x in graph.get_collection(ops.GraphKeys.ASSET_FILEPATHS)]
+ self.assertItemsEqual([expected_vocab_file], assets)
+ graph_ops = [x.name for x in graph.get_operations()]
+ self.assertTrue('input_example_tensor' in graph_ops)
+ self.assertTrue('ParseExample/ParseExample' in graph_ops)
+ self.assertTrue('linear/linear/feature/matmul' in graph_ops)
+
+ # cleanup
+ tf.gfile.DeleteRecursively(tmpdir)
+
class InferRealValuedColumnsTest(tf.test.TestCase):
diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py
index 2b713efb85..8e0a4de19c 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/head.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/head.py
@@ -24,6 +24,7 @@ import six
from tensorflow.contrib import losses
from tensorflow.contrib import metrics as metrics_lib
from tensorflow.contrib.learn.python.learn import metric_spec
+from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import estimator
from tensorflow.contrib.learn.python.learn.estimators import metric_key
from tensorflow.contrib.learn.python.learn.estimators import model_fn
@@ -200,6 +201,9 @@ class _Head(object):
"""
__metaclass__ = abc.ABCMeta
+ def __init__(self, head_name):
+ self._head_name = head_name
+
@abc.abstractproperty
def logits_dimension(self):
raise NotImplementedError("Calling an abstract method.")
@@ -227,6 +231,25 @@ class _Head(object):
"""
raise NotImplementedError("Calling an abstract method.")
+ def _create_output_alternatives(self, predictions):
+ """Creates output alternative for the Head.
+
+ Args:
+ predictions: a dict of {tensor_name: Tensor}, where 'tensor_name' is a
+ symbolic name for an output Tensor possibly but not necessarily taken from
+ `PredictionKey`, and 'Tensor' is the corresponding output Tensor itself.
+
+ Returns:
+ a dict of {submodel_name: (problem_type, {tensor_name: Tensor})}, where
+ 'submodel_name' is a submodel identifier that should be consistent across
+ the pipeline (here likely taken from the head_name),
+ 'problem_type' is a `ProblemType`,
+ 'tensor_name' is a symbolic name for an output Tensor possibly but not
+ necessarily taken from `PredictionKey`, and
+ 'Tensor' is the corresponding output Tensor itself.
+ """
+ return {self._head_name: (self._problem_type, predictions)}
+
class _RegressionHead(_Head):
"""_Head for regression."""
@@ -249,14 +272,16 @@ class _RegressionHead(_Head):
head_name: name of the head. If provided, predictions, summary and metrics
keys will be prefixed by the head_name and an underscore.
"""
+ super(_RegressionHead, self).__init__(head_name=head_name)
+
self._loss_fn = loss_fn
self._logits_dimension = label_dimension
self._label_name = label_name
self._weight_column_name = weight_column_name
- self._head_name = head_name
self._enable_centered_bias = enable_centered_bias
self._centered_bias_weight_collection = _head_prefixed(head_name,
"centered_bias")
+ self._problem_type = constants.ProblemType.LINEAR_REGRESSION
@property
def logits_dimension(self):
@@ -289,7 +314,8 @@ class _RegressionHead(_Head):
loss=loss,
train_op=train_op,
eval_metric_ops=eval_metric_ops,
- signature_fn=self._signature_fn())
+ signature_fn=self._signature_fn(),
+ output_alternatives=self._create_output_alternatives(predictions))
def _training_loss(self, features, labels, logits, name=None):
"""Returns training loss tensor for this head.
@@ -450,6 +476,8 @@ class _BinaryLogisticHead(_Head):
Raises:
ValueError: if n_classes is invalid.
"""
+ super(_BinaryLogisticHead, self).__init__(head_name=head_name)
+
self._thresholds = thresholds if thresholds else [.5]
self._label_name = label_name
self._weight_column_name = weight_column_name
@@ -705,6 +733,8 @@ class _MultiClassHead(_Head):
Raises:
ValueError: if n_classes is invalid.
"""
+ super(_MultiClassHead, self).__init__(head_name=head_name)
+
if (n_classes is None) or (n_classes <= 2):
raise ValueError("n_classes must be > 2: %s." % n_classes)
self._thresholds = thresholds if thresholds else [.5]
@@ -712,11 +742,11 @@ class _MultiClassHead(_Head):
self._logits_dimension = n_classes
self._label_name = label_name
self._weight_column_name = weight_column_name
- self._head_name = head_name
self._loss_fn = loss_fn
self._enable_centered_bias = enable_centered_bias
self._centered_bias_weight_collection = _head_prefixed(head_name,
"centered_bias")
+ self._problem_type = constants.ProblemType.CLASSIFICATION
@property
def logits_dimension(self):
@@ -749,7 +779,8 @@ class _MultiClassHead(_Head):
loss=loss,
train_op=train_op,
eval_metric_ops=eval_metric_ops,
- signature_fn=self._signature_fn())
+ signature_fn=self._signature_fn(),
+ output_alternatives=self._create_output_alternatives(predictions))
def _training_loss(self, features, labels, logits=None, name=None):
"""Returns training loss tensor for this head.
diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear.py b/tensorflow/contrib/learn/python/learn/estimators/linear.py
index 763d70ddaf..7e0b9c36f6 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/linear.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/linear.py
@@ -27,6 +27,7 @@ import six
from tensorflow.contrib import layers
from tensorflow.contrib.framework import deprecated
from tensorflow.contrib.framework import deprecated_arg_values
+from tensorflow.contrib.framework.python.framework import experimental
from tensorflow.contrib.framework.python.ops import variables as contrib_variables
from tensorflow.contrib.learn.python.learn import evaluable
from tensorflow.contrib.learn.python.learn import monitors as monitor_lib
@@ -519,6 +520,22 @@ class LinearClassifier(evaluable.Evaluable, trainable.Trainable):
default_batch_size=default_batch_size,
exports_to_keep=exports_to_keep)
+ @experimental
+ def export_savedmodel(self,
+ export_dir_base,
+ input_fn,
+ default_output_alternative_key=None,
+ assets_extra=None,
+ as_text=False,
+ exports_to_keep=None):
+ return self._estimator.export_savedmodel(
+ export_dir_base,
+ input_fn,
+ default_output_alternative_key=default_output_alternative_key,
+ assets_extra=assets_extra,
+ as_text=as_text,
+ exports_to_keep=exports_to_keep)
+
@property
@deprecated("2016-10-30",
"This method will be removed after the deprecation date. "
@@ -760,6 +777,22 @@ class LinearRegressor(evaluable.Evaluable, trainable.Trainable):
default_batch_size=default_batch_size,
exports_to_keep=exports_to_keep)
+ @experimental
+ def export_savedmodel(self,
+ export_dir_base,
+ input_fn,
+ default_output_alternative_key=None,
+ assets_extra=None,
+ as_text=False,
+ exports_to_keep=None):
+ return self._estimator.export_savedmodel(
+ export_dir_base,
+ input_fn,
+ default_output_alternative_key=default_output_alternative_key,
+ assets_extra=assets_extra,
+ as_text=as_text,
+ exports_to_keep=exports_to_keep)
+
@property
@deprecated("2016-10-30",
"This method will be removed after the deprecation date. "
diff --git a/tensorflow/contrib/learn/python/learn/estimators/model_fn.py b/tensorflow/contrib/learn/python/learn/estimators/model_fn.py
index 3f9351ce22..42f21bd196 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/model_fn.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/model_fn.py
@@ -49,13 +49,33 @@ class ModeKeys(object):
# TODO(roumposg): Pass output_signature_fn instead of signature_fn.
class ModelFnOps(collections.namedtuple(
'ModelFnOps',
- ['predictions', 'loss', 'train_op', 'eval_metric_ops', 'signature_fn'])):
+ ['predictions', 'loss', 'train_op', 'eval_metric_ops', 'signature_fn',
+ 'output_alternatives'])):
"""Ops returned from a model_fn."""
+ # TODO(soergel): remove signature_fn once sessionbundle export is deprecated.
+
def __new__(cls, mode, predictions=None, loss=None, train_op=None,
- eval_metric_ops=None, signature_fn=None):
+ eval_metric_ops=None, signature_fn=None,
+ output_alternatives=None):
"""Creates a validated `ModelFnOps` instance.
+ For a multi-headed model, the predictions dict here will contain the outputs
+ of all of the heads. However: at serving time, requests will be made
+ specifically for one or more heads, and the RPCs used for these requests may
+ differ by problem type (i.e., regression, classification, other). The
+ purpose of the output_alternatives dict is to aid in exporting a SavedModel
+ from which such head-specific queries can be served. These
+ output_alternatives will be combined with input_alternatives (see
+ `saved_model_export_utils`) to produce a set of `SignatureDef`s specifying
+ the valid requests that can be served from this model.
+
+ For a single-headed model, it is still adviseable to provide
+ output_alternatives with a single entry, because this is how the problem
+ type is communicated for export and serving. If output_alternatives is not
+ given, the resulting SavedModel will support only one head of unspecified
+ type.
+
Args:
mode: One of `ModeKeys`. Specifies if this training, evaluation or
prediction.
@@ -65,6 +85,14 @@ class ModelFnOps(collections.namedtuple(
eval_metric_ops: Dict of metric results keyed by name. The values of the
dict are the results of calling a metric function, such as `Tensor`.
signature_fn: The signature_fn used for exporting.
+ output_alternatives: a dict of
+ `{submodel_name: (problem_type, {tensor_name: Tensor})}`, where
+ `submodel_name` is a submodel identifier that should be consistent
+ across the pipeline (here likely taken from the name of each `Head`,
+ for models that use them), `problem_type` is a `ProblemType`,
+ `tensor_name` is a symbolic name for an output Tensor possibly but not
+ necessarily taken from `PredictionKey`, and `Tensor` is the
+ corresponding output Tensor itself.
Returns:
A validated `ModelFnOps` object.
@@ -122,4 +150,5 @@ class ModelFnOps(collections.namedtuple(
raise ValueError('signature_fn is not callable.')
return super(ModelFnOps, cls).__new__(cls, predictions, loss, train_op,
- eval_metric_ops, signature_fn)
+ eval_metric_ops, signature_fn,
+ output_alternatives)
diff --git a/tensorflow/contrib/learn/python/learn/estimators/random_forest.py b/tensorflow/contrib/learn/python/learn/estimators/random_forest.py
index c2c41255c9..deb55efc9f 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/random_forest.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/random_forest.py
@@ -19,6 +19,7 @@ from __future__ import print_function
from tensorflow.contrib import framework as contrib_framework
from tensorflow.contrib.framework import deprecated_arg_values
+from tensorflow.contrib.framework.python.framework import experimental
from tensorflow.contrib.learn.python.learn import evaluable
from tensorflow.contrib.learn.python.learn import trainable
@@ -352,3 +353,19 @@ class TensorForestEstimator(evaluable.Evaluable, trainable.Trainable):
self._estimator._model_fn = orig_model_fn
# pylint: enable=protected-access
return result
+
+ @experimental
+ def export_savedmodel(self,
+ export_dir_base,
+ input_fn,
+ default_output_alternative_key=None,
+ assets_extra=None,
+ as_text=False,
+ exports_to_keep=None):
+ return self._estimator.export_savedmodel(
+ export_dir_base,
+ input_fn,
+ default_output_alternative_key=default_output_alternative_key,
+ assets_extra=assets_extra,
+ as_text=as_text,
+ exports_to_keep=exports_to_keep)
diff --git a/tensorflow/contrib/learn/python/learn/estimators/svm.py b/tensorflow/contrib/learn/python/learn/estimators/svm.py
index eeee673c5a..a6e4e7b6a3 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/svm.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/svm.py
@@ -19,12 +19,14 @@ from __future__ import division
from __future__ import print_function
import inspect
+import re
import tempfile
from tensorflow.contrib import layers
from tensorflow.contrib.framework import deprecated_arg_values
from tensorflow.contrib.framework import list_variables
from tensorflow.contrib.framework import load_variable
+from tensorflow.contrib.framework.python.framework import experimental
from tensorflow.contrib.learn.python.learn import evaluable
from tensorflow.contrib.learn.python.learn import trainable
from tensorflow.contrib.learn.python.learn.estimators import estimator
@@ -235,6 +237,22 @@ class SVM(trainable.Trainable, evaluable.Evaluable):
default_batch_size=default_batch_size,
exports_to_keep=exports_to_keep)
+ @experimental
+ def export_savedmodel(self,
+ export_dir_base,
+ input_fn,
+ default_output_alternative_key=None,
+ assets_extra=None,
+ as_text=False,
+ exports_to_keep=None):
+ return self._estimator.export_savedmodel(
+ export_dir_base,
+ input_fn,
+ default_output_alternative_key=default_output_alternative_key,
+ assets_extra=assets_extra,
+ as_text=as_text,
+ exports_to_keep=exports_to_keep)
+
@property
def weights_(self):
values = {}
diff --git a/tensorflow/contrib/learn/python/learn/utils/gc.py b/tensorflow/contrib/learn/python/learn/utils/gc.py
new file mode 100644
index 0000000000..dd4376f051
--- /dev/null
+++ b/tensorflow/contrib/learn/python/learn/utils/gc.py
@@ -0,0 +1,205 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+r"""System for specifying garbage collection (GC) of path based data.
+
+This framework allows for GC of data specified by path names, for example files
+on disk. gc.Path objects each represent a single item stored at a path and may
+be a base directory,
+ /tmp/exports/0/...
+ /tmp/exports/1/...
+ ...
+or a fully qualified file,
+ /tmp/train-1.ckpt
+ /tmp/train-2.ckpt
+ ...
+
+A gc filter function takes and returns a list of gc.Path items. Filter
+functions are responsible for selecting Path items for preservation or deletion.
+Note that functions should always return a sorted list.
+
+For example,
+ base_dir = "/tmp"
+ # create the directories
+ for e in xrange(10):
+ os.mkdir("%s/%d" % (base_dir, e), 0o755)
+
+ # create a simple parser that pulls the export_version from the directory
+ def parser(path):
+ match = re.match("^" + base_dir + "/(\\d+)$", path.path)
+ if not match:
+ return None
+ return path._replace(export_version=int(match.group(1)))
+
+ path_list = gc.get_paths("/tmp", parser) # contains all ten Paths
+
+ every_fifth = gc.mod_export_version(5)
+ print every_fifth(path_list) # shows ["/tmp/0", "/tmp/5"]
+
+ largest_three = gc.largest_export_versions(3)
+ print largest_three(all_paths) # shows ["/tmp/7", "/tmp/8", "/tmp/9"]
+
+ both = gc.union(every_fifth, largest_three)
+ print both(all_paths) # shows ["/tmp/0", "/tmp/5",
+ # "/tmp/7", "/tmp/8", "/tmp/9"]
+ # delete everything not in 'both'
+ to_delete = gc.negation(both)
+ for p in to_delete(all_paths):
+ gfile.DeleteRecursively(p.path) # deletes: "/tmp/1", "/tmp/2",
+ # "/tmp/3", "/tmp/4", "/tmp/6",
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import heapq
+import math
+import os
+
+from tensorflow.python.platform import gfile
+
+Path = collections.namedtuple('Path', 'path export_version')
+
+
+def largest_export_versions(n):
+ """Creates a filter that keeps the largest n export versions.
+
+ Args:
+ n: number of versions to keep.
+
+ Returns:
+ A filter function that keeps the n largest paths.
+ """
+ def keep(paths):
+ heap = []
+ for idx, path in enumerate(paths):
+ if path.export_version is not None:
+ heapq.heappush(heap, (path.export_version, idx))
+ keepers = [paths[i] for _, i in heapq.nlargest(n, heap)]
+ return sorted(keepers)
+
+ return keep
+
+
+def one_of_every_n_export_versions(n):
+ """Creates a filter that keeps one of every n export versions.
+
+ Args:
+ n: interval size.
+
+ Returns:
+ A filter function that keeps exactly one path from each interval
+ [0, n], (n, 2n], (2n, 3n], etc... If more than one path exists in an
+ interval the largest is kept.
+ """
+ def keep(paths):
+ """A filter function that keeps exactly one out of every n paths."""
+
+ keeper_map = {} # map from interval to largest path seen in that interval
+ for p in paths:
+ if p.export_version is None:
+ # Skip missing export_versions.
+ continue
+ # Find the interval (with a special case to map export_version = 0 to
+ # interval 0.
+ interval = math.floor(
+ (p.export_version - 1) / n) if p.export_version else 0
+ existing = keeper_map.get(interval, None)
+ if (not existing) or (existing.export_version < p.export_version):
+ keeper_map[interval] = p
+ return sorted(keeper_map.values())
+
+ return keep
+
+
+def mod_export_version(n):
+ """Creates a filter that keeps every export that is a multiple of n.
+
+ Args:
+ n: step size.
+
+ Returns:
+ A filter function that keeps paths where export_version % n == 0.
+ """
+ def keep(paths):
+ keepers = []
+ for p in paths:
+ if p.export_version % n == 0:
+ keepers.append(p)
+ return sorted(keepers)
+ return keep
+
+
+def union(lf, rf):
+ """Creates a filter that keeps the union of two filters.
+
+ Args:
+ lf: first filter
+ rf: second filter
+
+ Returns:
+ A filter function that keeps the n largest paths.
+ """
+ def keep(paths):
+ l = set(lf(paths))
+ r = set(rf(paths))
+ return sorted(list(l|r))
+ return keep
+
+
+def negation(f):
+ """Negate a filter.
+
+ Args:
+ f: filter function to invert
+
+ Returns:
+ A filter function that returns the negation of f.
+ """
+ def keep(paths):
+ l = set(paths)
+ r = set(f(paths))
+ return sorted(list(l-r))
+ return keep
+
+
+def get_paths(base_dir, parser):
+ """Gets a list of Paths in a given directory.
+
+ Args:
+ base_dir: directory.
+ parser: a function which gets the raw Path and can augment it with
+ information such as the export_version, or ignore the path by returning
+ None. An example parser may extract the export version from a path
+ such as "/tmp/exports/100" an another may extract from a full file
+ name such as "/tmp/checkpoint-99.out".
+
+ Returns:
+ A list of Paths contained in the base directory with the parsing function
+ applied.
+ By default the following fields are populated,
+ - Path.path
+ The parsing function is responsible for populating,
+ - Path.export_version
+ """
+ raw_paths = gfile.ListDirectory(base_dir)
+ paths = []
+ for r in raw_paths:
+ p = parser(Path(os.path.join(base_dir, r), None))
+ if p:
+ paths.append(p)
+ return sorted(paths)
diff --git a/tensorflow/contrib/learn/python/learn/utils/gc_test.py b/tensorflow/contrib/learn/python/learn/utils/gc_test.py
new file mode 100644
index 0000000000..dbe3304f21
--- /dev/null
+++ b/tensorflow/contrib/learn/python/learn/utils/gc_test.py
@@ -0,0 +1,120 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for learn.utils.gc."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import re
+
+from six.moves import xrange # pylint: disable=redefined-builtin
+
+import tensorflow as tf
+
+from tensorflow.contrib.learn.python.learn.utils import gc
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import gfile
+
+
+def tearDownModule():
+ gfile.DeleteRecursively(tf.test.get_temp_dir())
+
+
+class GcTest(test_util.TensorFlowTestCase):
+
+ def testLargestExportVersions(self):
+ paths = [gc.Path("/foo", 8), gc.Path("/foo", 9), gc.Path("/foo", 10)]
+ newest = gc.largest_export_versions(2)
+ n = newest(paths)
+ self.assertEquals(n, [gc.Path("/foo", 9), gc.Path("/foo", 10)])
+
+ def testLargestExportVersionsDoesNotDeleteZeroFolder(self):
+ paths = [gc.Path("/foo", 0), gc.Path("/foo", 3)]
+ newest = gc.largest_export_versions(2)
+ n = newest(paths)
+ self.assertEquals(n, [gc.Path("/foo", 0), gc.Path("/foo", 3)])
+
+ def testModExportVersion(self):
+ paths = [gc.Path("/foo", 4), gc.Path("/foo", 5), gc.Path("/foo", 6),
+ gc.Path("/foo", 9)]
+ mod = gc.mod_export_version(2)
+ self.assertEquals(mod(paths), [gc.Path("/foo", 4), gc.Path("/foo", 6)])
+ mod = gc.mod_export_version(3)
+ self.assertEquals(mod(paths), [gc.Path("/foo", 6), gc.Path("/foo", 9)])
+
+ def testOneOfEveryNExportVersions(self):
+ paths = [gc.Path("/foo", 0), gc.Path("/foo", 1), gc.Path("/foo", 3),
+ gc.Path("/foo", 5), gc.Path("/foo", 6), gc.Path("/foo", 7),
+ gc.Path("/foo", 8), gc.Path("/foo", 33)]
+ one_of = gc.one_of_every_n_export_versions(3)
+ self.assertEquals(one_of(paths),
+ [gc.Path("/foo", 3), gc.Path("/foo", 6),
+ gc.Path("/foo", 8), gc.Path("/foo", 33)])
+
+ def testOneOfEveryNExportVersionsZero(self):
+ # Zero is a special case since it gets rolled into the first interval.
+ # Test that here.
+ paths = [gc.Path("/foo", 0), gc.Path("/foo", 4), gc.Path("/foo", 5)]
+ one_of = gc.one_of_every_n_export_versions(3)
+ self.assertEquals(one_of(paths),
+ [gc.Path("/foo", 0), gc.Path("/foo", 5)])
+
+ def testUnion(self):
+ paths = []
+ for i in xrange(10):
+ paths.append(gc.Path("/foo", i))
+ f = gc.union(gc.largest_export_versions(3), gc.mod_export_version(3))
+ self.assertEquals(
+ f(paths), [gc.Path("/foo", 0), gc.Path("/foo", 3),
+ gc.Path("/foo", 6), gc.Path("/foo", 7),
+ gc.Path("/foo", 8), gc.Path("/foo", 9)])
+
+ def testNegation(self):
+ paths = [gc.Path("/foo", 4), gc.Path("/foo", 5), gc.Path("/foo", 6),
+ gc.Path("/foo", 9)]
+ mod = gc.negation(gc.mod_export_version(2))
+ self.assertEquals(
+ mod(paths), [gc.Path("/foo", 5), gc.Path("/foo", 9)])
+ mod = gc.negation(gc.mod_export_version(3))
+ self.assertEquals(
+ mod(paths), [gc.Path("/foo", 4), gc.Path("/foo", 5)])
+
+ def testPathsWithParse(self):
+ base_dir = os.path.join(tf.test.get_temp_dir(), "paths_parse")
+ self.assertFalse(gfile.Exists(base_dir))
+ for p in xrange(3):
+ gfile.MakeDirs(os.path.join(base_dir, "%d" % p))
+ # add a base_directory to ignore
+ gfile.MakeDirs(os.path.join(base_dir, "ignore"))
+
+ # create a simple parser that pulls the export_version from the directory.
+ def parser(path):
+ match = re.match("^" + base_dir + "/(\\d+)$", path.path)
+ if not match:
+ return None
+ return path._replace(export_version=int(match.group(1)))
+
+ self.assertEquals(
+ gc.get_paths(base_dir, parser=parser),
+ [gc.Path(os.path.join(base_dir, "0"), 0),
+ gc.Path(os.path.join(base_dir, "1"), 1),
+ gc.Path(os.path.join(base_dir, "2"), 2)])
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/contrib/learn/python/learn/utils/input_fn_utils.py b/tensorflow/contrib/learn/python/learn/utils/input_fn_utils.py
new file mode 100644
index 0000000000..2cb7173d5a
--- /dev/null
+++ b/tensorflow/contrib/learn/python/learn/utils/input_fn_utils.py
@@ -0,0 +1,97 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Utilities for creating input_fns."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import parsing_ops
+
+
+# A return type allowing input_fns to return multiple values in a well-
+# defined way (analogous to ModelFnOps).
+# The expected return values are:
+# features: a dict of string to Tensor, giving the features to be passed to
+# the model.
+# labels: a dict of string to Tensor, giving labels (aka targets) for training.
+# default_inputs: a dict of string to Tensor, giving the input Tensors (if
+# any) that this input_fn expects to be fed.
+InputFnOps = collections.namedtuple('InputFnOps',
+ ['features',
+ 'labels',
+ 'default_inputs'])
+
+
+def build_parsing_serving_input_fn(feature_spec, default_batch_size=1):
+ """Build an input_fn appropriate for serving, expecting fed tf.Examples.
+
+ Creates an input_fn that expects a serialized tf.Example fed into a string
+ placeholder. The function parses the tf.Example according to the provided
+ feature_spec, and returns all parsed Tensors as features. This input_fn is
+ for use at serving time, so the labels return value is always None.
+
+ Args:
+ feature_spec: a dict of string to `VarLenFeature`/`FixedLenFeature`.
+ default_batch_size: the number of query examples expected per batch.
+
+ Returns:
+ An input_fn suitable for use in serving.
+ """
+ def input_fn():
+ """An input_fn that expects a serialized tf.Example."""
+ serialized_tf_example = array_ops.placeholder(dtype=dtypes.string,
+ shape=[default_batch_size],
+ name='input_example_tensor')
+ inputs = {'examples': serialized_tf_example}
+ features = parsing_ops.parse_example(serialized_tf_example, feature_spec)
+ labels = None # these are not known in serving!
+ return InputFnOps(features, labels, inputs)
+ return input_fn
+
+
+def build_default_serving_input_fn(features, default_batch_size=1):
+ """Build an input_fn appropriate for serving, expecting feature Tensors.
+
+ Creates an input_fn that expects all features to be fed directly.
+ This input_fn is for use at serving time, so the labels return value is always
+ None.
+
+ Args:
+ features: a dict of string to `Tensor`.
+ default_batch_size: the number of query examples expected per batch.
+
+ Returns:
+ An input_fn suitable for use in serving.
+ """
+ def input_fn():
+ """an input_fn that expects all features to be fed directly."""
+ features_placeholders = {}
+ for name, t in features.items():
+ shape_list = t.get_shape().as_list()
+ shape_list[0] = default_batch_size
+ shape = tensor_shape.TensorShape(shape_list)
+
+ features_placeholders[name] = array_ops.placeholder(dtype=t.dtype,
+ shape=shape,
+ name=t.name)
+ labels = None # these are not known in serving!
+ return InputFnOps(features_placeholders, labels, features_placeholders)
+ return input_fn
diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py
new file mode 100644
index 0000000000..54bb0fb3d7
--- /dev/null
+++ b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py
@@ -0,0 +1,248 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Utilities supporting export to SavedModel."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import copy
+import os
+import re
+import time
+
+from tensorflow.contrib.learn.python.learn.estimators import constants
+from tensorflow.contrib.learn.python.learn.estimators import prediction_key
+from tensorflow.contrib.learn.python.learn.utils import gc
+from tensorflow.contrib.learn.python.learn.utils import input_fn_utils
+from tensorflow.python.platform import gfile
+from tensorflow.python.saved_model import signature_constants
+from tensorflow.python.saved_model import signature_def_utils
+
+from tensorflow.python.util import compat
+
+# A key for use in the input_alternatives dict indicating the default input.
+# This is the input that will be expected when a serving request does not
+# specify a specific signature.
+# The default input alternative specifies placeholders that the input_fn
+# requires to be fed (in the typical case, a single placeholder for a
+# serialized tf.Example).
+DEFAULT_INPUT_ALTERNATIVE_KEY = 'default_input_alternative'
+
+# A key for use in the input_alternatives dict indicating the features input.
+# The features inputs alternative specifies the feature Tensors provided as
+# input to the model_fn, i.e. the outputs of the input_fn.
+FEATURES_INPUT_ALTERNATIVE_KEY = 'features_input_alternative'
+
+# A key for use in the output_alternatives dict indicating the default output.
+# This is the output that will be provided when a serving request does not
+# specify a specific signature.
+# In a single-headed model, the single output is automatically the default.
+# In a multi-headed model, the name of the desired default head should be
+# provided to get_output_alternatives.
+DEFAULT_OUTPUT_ALTERNATIVE_KEY = 'default_output_alternative'
+
+
+def build_standardized_signature_def(
+ input_tensors, output_tensors, problem_type):
+ """Build a SignatureDef using problem type and input and output Tensors.
+
+ Note that this delegates the actual creation of the signatures to methods in
+ //third_party/tensorflow/python/saved_model/signature_def_utils.py, which may
+ assign names to the input and output tensors (depending on the problem type)
+ that are standardized in the context of SavedModel.
+
+ Args:
+ input_tensors: a dict of string key to `Tensor`
+ output_tensors: a dict of string key to `Tensor`
+ problem_type: an instance of constants.ProblemType, specifying
+ classification, regression, etc.
+
+ Returns:
+ A SignatureDef using SavedModel standard keys where possible.
+
+ Raises:
+ ValueError: if input_tensors or output_tensors is None or empty.
+ """
+
+ if not input_tensors:
+ raise ValueError('input_tensors must be provided.')
+ if not output_tensors:
+ raise ValueError('output_tensors must be provided.')
+
+ # Per-method signature_def functions will standardize the keys if possible
+ if _is_classification_problem(problem_type, input_tensors, output_tensors):
+ (_, examples), = input_tensors.items()
+ classes = output_tensors.get(prediction_key.PredictionKey.CLASSES)
+ scores = output_tensors.get(prediction_key.PredictionKey.SCORES)
+ if not (classes or scores):
+ (_, classes), = output_tensors.items()
+ return signature_def_utils.classification_signature_def(
+ examples, classes, scores)
+ elif _is_regression_problem(problem_type, input_tensors, output_tensors):
+ (_, examples), = input_tensors.items()
+ (_, predictions), = output_tensors.items()
+ return signature_def_utils.regression_signature_def(examples, predictions)
+ else:
+ return signature_def_utils.predict_signature_def(
+ input_tensors, output_tensors)
+
+
+def _is_classification_problem(problem_type, input_tensors, output_tensors):
+ classes = output_tensors.get(prediction_key.PredictionKey.CLASSES)
+ scores = output_tensors.get(prediction_key.PredictionKey.SCORES)
+ return ((problem_type == constants.ProblemType.CLASSIFICATION or
+ problem_type == constants.ProblemType.LOGISTIC_REGRESSION)
+ and len(input_tensors) == 1
+ and (classes or scores or len(output_tensors) == 1))
+
+
+def _is_regression_problem(problem_type, input_tensors, output_tensors):
+ return (problem_type == constants.ProblemType.LINEAR_REGRESSION
+ and len(input_tensors) == 1
+ and len(output_tensors) == 1)
+
+
+def get_input_alternatives(input_ops):
+ """Obtain all input alternatives using the input_fn output and heuristics."""
+ input_alternatives = {}
+ if isinstance(input_ops, input_fn_utils.InputFnOps):
+ features, unused_labels, default_inputs = input_ops
+ input_alternatives[DEFAULT_INPUT_ALTERNATIVE_KEY] = default_inputs
+ else:
+ features, unused_labels = input_ops
+
+ if not features:
+ raise ValueError('Features must be defined.')
+
+ # Add the "features" input_signature in any case.
+ # Note defensive copy because model_fns alter the features dict.
+ input_alternatives[FEATURES_INPUT_ALTERNATIVE_KEY] = (
+ copy.copy(features))
+
+ return input_alternatives, features
+
+
+def get_output_alternatives(
+ model_fn_ops,
+ default_output_alternative_key=DEFAULT_OUTPUT_ALTERNATIVE_KEY):
+ """Obtain all output alternatives using the model_fn output and heuristics."""
+ output_alternatives = model_fn_ops.output_alternatives
+
+ # Identify the default outputs, creating them if needed.
+ if (output_alternatives
+ and default_output_alternative_key not in output_alternatives):
+ raise ValueError('default_output_alternative_key not in '
+ 'output_alternatives: %s' % default_output_alternative_key)
+
+ if (output_alternatives
+ and default_output_alternative_key in output_alternatives):
+ # If a default head is provided, use it.
+ actual_default_output_alternative_key = default_output_alternative_key
+ return output_alternatives, actual_default_output_alternative_key
+
+ if output_alternatives and len(output_alternatives) == 1:
+ # If there is only one head, use it as the default.
+ (actual_default_output_alternative_key, _), = output_alternatives.items()
+ return output_alternatives, actual_default_output_alternative_key
+
+ # Lacking provided output alternatives, the best we can do is to
+ # interpret the model as single-headed of unknown type.
+ default_problem_type = constants.ProblemType.UNSPECIFIED
+ default_outputs = model_fn_ops.predictions
+ actual_default_output_alternative_key = DEFAULT_OUTPUT_ALTERNATIVE_KEY
+ output_alternatives = {actual_default_output_alternative_key:
+ (default_problem_type, default_outputs)}
+ return output_alternatives, actual_default_output_alternative_key
+
+
+def build_all_signature_defs(input_alternatives, output_alternatives,
+ actual_default_output_alternative_key):
+ """Build `SignatureDef`s from all pairs of input and output alternatives."""
+
+ signature_def_map = {
+ ('%s:%s' % (input_key, output_key or 'None')):
+ build_standardized_signature_def(
+ inputs, outputs, problem_type)
+ for input_key, inputs in input_alternatives.items()
+ for output_key, (problem_type, outputs)
+ in output_alternatives.items()}
+
+ # Add the default SignatureDef
+ default_inputs = input_alternatives[DEFAULT_INPUT_ALTERNATIVE_KEY]
+ if not default_inputs:
+ default_inputs = input_alternatives[FEATURES_INPUT_ALTERNATIVE_KEY]
+ # default outputs are guaranteed to exist above
+ (default_problem_type, default_outputs) = (
+ output_alternatives[actual_default_output_alternative_key])
+ signature_def_map[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = (
+ build_standardized_signature_def(
+ default_inputs, default_outputs, default_problem_type))
+
+ return signature_def_map
+
+
+def get_timestamped_export_dir(export_dir_base):
+ """Builds a path to a new subdirectory within the base directory.
+
+ Each export is written into a new subdirectory named using the
+ current time. This guarantees monotonically increasing version
+ numbers even across multiple runs of the pipeline.
+ The timestamp used is the number of milliseconds since epoch UTC.
+
+ Args:
+ export_dir_base: A string containing a directory to write the exported
+ graph and checkpoints.
+ Returns:
+ The full path of the new subdirectory (which is not actually created yet).
+ """
+ export_timestamp = int(time.time() * 1e3)
+
+ export_dir = os.path.join(
+ compat.as_bytes(export_dir_base),
+ compat.as_bytes(str(export_timestamp)))
+ return export_dir
+
+
+def garbage_collect_exports(export_dir_base, exports_to_keep):
+ """Deletes older exports, retaining only a given number of the most recent.
+
+ Export subdirectories are assumed to be named with monotonically increasing
+ integers; the most recent are taken to be those with the largest values.
+
+ Args:
+ export_dir_base: the base directory under which each export is in a
+ versioned subdirectory.
+ exports_to_keep: the number of recent exports to retain.
+ """
+ if exports_to_keep is None:
+ return
+
+ keep_filter = gc.largest_export_versions(exports_to_keep)
+ delete_filter = gc.negation(keep_filter)
+
+ # Export dir must not end with / or it will break the re match below.
+ if export_dir_base.endswith('/'):
+ export_dir_base = export_dir_base[:-1]
+
+ # create a simple parser that pulls the export_version from the directory.
+ def parser(path):
+ match = re.match('^' + export_dir_base + '/(\\d{13})$', path.path)
+ if not match:
+ return None
+ return path._replace(export_version=int(match.group(1)))
+
+ for p in delete_filter(gc.get_paths(export_dir_base, parser=parser)):
+ gfile.DeleteRecursively(p.path)
diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py
new file mode 100644
index 0000000000..538e0ab104
--- /dev/null
+++ b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py
@@ -0,0 +1,228 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests of utilities supporting export to SavedModel."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import tempfile
+import time
+
+import tensorflow as tf
+
+from tensorflow.contrib.learn.python.learn.estimators import constants
+from tensorflow.contrib.learn.python.learn.estimators import model_fn
+from tensorflow.contrib.learn.python.learn.utils import input_fn_utils
+from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils
+from tensorflow.core.framework import tensor_shape_pb2
+from tensorflow.core.framework import types_pb2
+from tensorflow.core.protobuf import meta_graph_pb2
+from tensorflow.python.saved_model import signature_constants
+from tensorflow.python.saved_model import signature_def_utils
+
+
+class SavedModelExportUtilsTest(tf.test.TestCase):
+
+ def test_build_standardized_signature_def(self):
+ input_tensors = {
+ "input-1": tf.placeholder(tf.float32, 1, name="input-tensor-1")}
+ output_tensors = {
+ "output-1": tf.placeholder(tf.float32, 1, name="output-tensor-1")}
+ problem_type = constants.ProblemType.LINEAR_REGRESSION
+ regression_signature_def = (
+ saved_model_export_utils.build_standardized_signature_def(
+ input_tensors, output_tensors, problem_type))
+ expected_regression_signature_def = meta_graph_pb2.SignatureDef()
+ shape = tensor_shape_pb2.TensorShapeProto(
+ dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)])
+ dtype = types_pb2.DataType.Value("DT_FLOAT")
+ expected_regression_signature_def.inputs[
+ signature_constants.REGRESS_INPUTS].CopyFrom(
+ meta_graph_pb2.TensorInfo(name="input-tensor-1:0",
+ dtype=dtype,
+ tensor_shape=shape))
+ expected_regression_signature_def.outputs[
+ signature_constants.REGRESS_OUTPUTS].CopyFrom(
+ meta_graph_pb2.TensorInfo(name="output-tensor-1:0",
+ dtype=dtype,
+ tensor_shape=shape))
+
+ expected_regression_signature_def.method_name = (
+ signature_constants.REGRESS_METHOD_NAME)
+ self.assertEqual(regression_signature_def,
+ expected_regression_signature_def)
+
+ def test_get_input_alternatives(self):
+ input_ops = input_fn_utils.InputFnOps("bogus features dict", None,
+ "bogus default input dict")
+
+ input_alternatives, _ = saved_model_export_utils.get_input_alternatives(
+ input_ops)
+ self.assertEqual(
+ input_alternatives[
+ saved_model_export_utils.DEFAULT_INPUT_ALTERNATIVE_KEY],
+ "bogus default input dict")
+ self.assertEqual(
+ input_alternatives[
+ saved_model_export_utils.FEATURES_INPUT_ALTERNATIVE_KEY],
+ "bogus features dict")
+
+ def test_get_output_alternatives_explicit(self):
+ provided_output_alternatives = {
+ "head-1": (constants.ProblemType.LINEAR_REGRESSION,
+ "bogus output dict"),
+ "head-2": (constants.ProblemType.CLASSIFICATION,
+ "bogus output dict 2"),
+ "head-3": (constants.ProblemType.UNSPECIFIED,
+ "bogus output dict 3"),
+ }
+ model_fn_ops = model_fn.ModelFnOps(
+ model_fn.ModeKeys.INFER,
+ predictions={"some_output": "bogus_tensor"},
+ output_alternatives=provided_output_alternatives)
+ output_alternatives, _ = saved_model_export_utils.get_output_alternatives(
+ model_fn_ops, "head-1")
+
+ self.assertEqual(provided_output_alternatives, output_alternatives)
+
+ def test_get_output_alternatives_implicit(self):
+ prediction_tensor = tf.constant(["bogus"])
+ model_fn_ops = model_fn.ModelFnOps(
+ model_fn.ModeKeys.INFER,
+ predictions={"some_output": prediction_tensor},
+ output_alternatives=None)
+
+ output_alternatives, _ = saved_model_export_utils.get_output_alternatives(
+ model_fn_ops, "some_output")
+ self.assertEqual(
+ {"default_output_alternative": (constants.ProblemType.UNSPECIFIED,
+ {"some_output": prediction_tensor})},
+ output_alternatives)
+
+ def test_build_all_signature_defs(self):
+ input_features = tf.constant(["10"])
+ input_example = tf.constant(["11"])
+ input_ops = input_fn_utils.InputFnOps(
+ {"features": input_features},
+ None,
+ {"default input": input_example})
+ input_alternatives, _ = (
+ saved_model_export_utils.get_input_alternatives(input_ops))
+ output_1 = tf.constant(["1"])
+ output_2 = tf.constant(["2"])
+ output_3 = tf.constant(["3"])
+ provided_output_alternatives = {
+ "head-1": (constants.ProblemType.LINEAR_REGRESSION,
+ {"some_output_1": output_1}),
+ "head-2": (constants.ProblemType.CLASSIFICATION,
+ {"some_output_2": output_2}),
+ "head-3": (constants.ProblemType.UNSPECIFIED,
+ {"some_output_3": output_3}),
+ }
+ model_fn_ops = model_fn.ModelFnOps(
+ model_fn.ModeKeys.INFER,
+ predictions={"some_output": tf.constant(["4"])},
+ output_alternatives=provided_output_alternatives)
+ output_alternatives, _ = (
+ saved_model_export_utils.get_output_alternatives(model_fn_ops,
+ "head-1"))
+
+ signature_defs = saved_model_export_utils.build_all_signature_defs(
+ input_alternatives, output_alternatives, "head-1")
+
+ expected_signature_defs = {
+ "serving_default":
+ signature_def_utils.regression_signature_def(
+ input_example, output_1),
+ "default_input_alternative:head-1":
+ signature_def_utils.regression_signature_def(
+ input_example, output_1),
+ "default_input_alternative:head-2":
+ signature_def_utils.classification_signature_def(
+ input_example, output_2, None),
+ "default_input_alternative:head-3":
+ signature_def_utils.predict_signature_def(
+ {"input": input_example}, {"output": output_3}),
+ "features_input_alternative:head-1":
+ signature_def_utils.regression_signature_def(
+ input_features, output_1),
+ "features_input_alternative:head-2":
+ signature_def_utils.classification_signature_def(
+ input_features, output_2, None),
+ "features_input_alternative:head-3":
+ signature_def_utils.predict_signature_def(
+ {"input": input_features}, {"output": output_3}),
+ }
+
+ self.assertDictEqual(expected_signature_defs, signature_defs)
+
+ def test_get_timestamped_export_dir(self):
+ export_dir_base = tempfile.mkdtemp() + "export/"
+ export_dir_1 = saved_model_export_utils.get_timestamped_export_dir(
+ export_dir_base)
+ time.sleep(0.001)
+ export_dir_2 = saved_model_export_utils.get_timestamped_export_dir(
+ export_dir_base)
+ time.sleep(0.001)
+ export_dir_3 = saved_model_export_utils.get_timestamped_export_dir(
+ export_dir_base)
+
+ # Export directories should be named using a timestamp that is milliseconds
+ # since epoch. Such a timestamp is 13 digits long.
+ time_1 = os.path.basename(export_dir_1)
+ self.assertEqual(13, len(time_1))
+ time_2 = os.path.basename(export_dir_2)
+ self.assertEqual(13, len(time_2))
+ time_3 = os.path.basename(export_dir_3)
+ self.assertEqual(13, len(time_3))
+
+ self.assertTrue(int(time_1) < int(time_2))
+ self.assertTrue(int(time_2) < int(time_3))
+
+ def test_garbage_collect_exports(self):
+ export_dir_base = tempfile.mkdtemp() + "export/"
+ tf.gfile.MkDir(export_dir_base)
+ export_dir_1 = _create_test_export_dir(export_dir_base)
+ export_dir_2 = _create_test_export_dir(export_dir_base)
+ export_dir_3 = _create_test_export_dir(export_dir_base)
+ export_dir_4 = _create_test_export_dir(export_dir_base)
+
+ self.assertTrue(tf.gfile.Exists(export_dir_1))
+ self.assertTrue(tf.gfile.Exists(export_dir_2))
+ self.assertTrue(tf.gfile.Exists(export_dir_3))
+ self.assertTrue(tf.gfile.Exists(export_dir_4))
+
+ # Garbage collect all but the most recent 2 exports,
+ # where recency is determined based on the timestamp directory names.
+ saved_model_export_utils.garbage_collect_exports(export_dir_base, 2)
+
+ self.assertFalse(tf.gfile.Exists(export_dir_1))
+ self.assertFalse(tf.gfile.Exists(export_dir_2))
+ self.assertTrue(tf.gfile.Exists(export_dir_3))
+ self.assertTrue(tf.gfile.Exists(export_dir_4))
+
+
+def _create_test_export_dir(export_dir_base):
+ export_dir = saved_model_export_utils.get_timestamped_export_dir(
+ export_dir_base)
+ tf.gfile.MkDir(export_dir)
+ time.sleep(0.001)
+ return export_dir
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD
index db5768acb8..1663a1f251 100644
--- a/tensorflow/python/saved_model/BUILD
+++ b/tensorflow/python/saved_model/BUILD
@@ -2,7 +2,10 @@
# TensorFlow SavedModel.
package(
- default_visibility = ["//tensorflow/python/saved_model:__subpackages__"],
+ default_visibility = [
+ "//tensorflow/contrib/learn:__subpackages__",
+ "//tensorflow/python/saved_model:__subpackages__",
+ ],
)
licenses(["notice"]) # Apache 2.0
@@ -33,7 +36,6 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":constants",
- "//tensorflow:tensorflow_py",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:platform",
@@ -48,7 +50,6 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":constants",
- "//tensorflow:tensorflow_py",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:platform",
"//tensorflow/python:training",
@@ -94,7 +95,9 @@ py_library(
name = "utils",
srcs = ["utils.py"],
srcs_version = "PY2AND3",
- deps = ["//tensorflow/core:protos_all_py"],
+ deps = [
+ "//tensorflow/core:protos_all_py",
+ ],
)
py_test(
@@ -111,6 +114,31 @@ py_test(
],
)
+py_library(
+ name = "signature_def_utils",
+ srcs = ["signature_def_utils.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":signature_constants",
+ ":utils",
+ "//tensorflow/core:protos_all_py",
+ ],
+)
+
+py_test(
+ name = "signature_def_utils_test",
+ size = "small",
+ srcs = [
+ "signature_def_utils_test.py",
+ ],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:private"],
+ deps = [
+ ":signature_def_utils",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
# -----------------------------------------------------------------------------
# Google-internal targets. These must be at the end for syncrepo.
diff --git a/tensorflow/python/saved_model/example/saved_model_half_plus_two.py b/tensorflow/python/saved_model/example/saved_model_half_plus_two.py
index 0a05ff09ca..7927dfe632 100644
--- a/tensorflow/python/saved_model/example/saved_model_half_plus_two.py
+++ b/tensorflow/python/saved_model/example/saved_model_half_plus_two.py
@@ -36,8 +36,8 @@ from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.lib.io import file_io
from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import signature_constants
+from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import tag_constants
-from tensorflow.python.saved_model import utils
from tensorflow.python.util import compat
tf.app.flags.DEFINE_string("output_dir", "/tmp/saved_model_half_plus_two",
@@ -121,7 +121,7 @@ def _generate_saved_model_for_half_plus_two(export_dir, as_text=False):
output_tensor = meta_graph_pb2.TensorInfo()
output_tensor.name = tf.identity(y).name
signature_outputs = {signature_constants.REGRESS_OUTPUTS: output_tensor}
- signature_def = utils.build_signature_def(
+ signature_def = signature_def_utils.build_signature_def(
signature_inputs, signature_outputs,
signature_constants.REGRESS_METHOD_NAME)
diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py
index 0f8ddfc65b..bf5b186b80 100644
--- a/tensorflow/python/saved_model/saved_model_test.py
+++ b/tensorflow/python/saved_model/saved_model_test.py
@@ -27,8 +27,8 @@ from tensorflow.python.lib.io import file_io
from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import loader
+from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import tag_constants
-from tensorflow.python.saved_model import utils
from tensorflow.python.util import compat
@@ -315,7 +315,8 @@ class SavedModelTest(tf.test.TestCase):
with self.test_session(graph=tf.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 42)
# Build and populate an empty SignatureDef for testing.
- foo_signature = utils.build_signature_def(dict(), dict(), "foo")
+ foo_signature = signature_def_utils.build_signature_def(
+ dict(), dict(), "foo")
builder.add_meta_graph_and_variables(
sess, ["foo"], signature_def_map={"foo_key": foo_signature})
@@ -324,10 +325,12 @@ class SavedModelTest(tf.test.TestCase):
with self.test_session(graph=tf.Graph()) as sess:
self._init_and_validate_variable(sess, "v", 43)
# Build and populate a different SignatureDef for testing.
- bar_signature = utils.build_signature_def(dict(), dict(), "bar")
+ bar_signature = signature_def_utils.build_signature_def(
+ dict(), dict(), "bar")
# Also, build a different SignatureDef corresponding to "foo_key" defined
# in the previous graph.
- foo_new_signature = utils.build_signature_def(dict(), dict(), "foo_new")
+ foo_new_signature = signature_def_utils.build_signature_def(
+ dict(), dict(), "foo_new")
builder.add_meta_graph(
["bar"],
signature_def_map={"bar_key": bar_signature,
diff --git a/tensorflow/python/saved_model/signature_def_utils.py b/tensorflow/python/saved_model/signature_def_utils.py
new file mode 100644
index 0000000000..23e844adb2
--- /dev/null
+++ b/tensorflow/python/saved_model/signature_def_utils.py
@@ -0,0 +1,158 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""SignatureDef utility functions.
+
+Utility functions for constructing SignatureDef protos.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.core.protobuf import meta_graph_pb2
+from tensorflow.python.saved_model import signature_constants
+from tensorflow.python.saved_model import utils
+
+
+def build_signature_def(inputs=None, outputs=None, method_name=None):
+ """Utility function to build a SignatureDef protocol buffer.
+
+ Args:
+ inputs: Inputs of the SignatureDef defined as a proto map of string to
+ tensor info.
+ outputs: Outputs of the SignatureDef defined as a proto map of string to
+ tensor info.
+ method_name: Method name of the SignatureDef as a string.
+
+ Returns:
+ A SignatureDef protocol buffer constructed based on the supplied arguments.
+ """
+ signature_def = meta_graph_pb2.SignatureDef()
+ if inputs is not None:
+ for item in inputs:
+ signature_def.inputs[item].CopyFrom(inputs[item])
+ if outputs is not None:
+ for item in outputs:
+ signature_def.outputs[item].CopyFrom(outputs[item])
+ if method_name is not None:
+ signature_def.method_name = method_name
+ return signature_def
+
+
+def regression_signature_def(examples, predictions):
+ """Creates regression signature from given examples and predictions.
+
+ Args:
+ examples: `Tensor`.
+ predictions: `Tensor`.
+
+ Returns:
+ A regression-flavored signature_def.
+
+ Raises:
+ ValueError: If examples is `None`.
+ """
+ if examples is None:
+ raise ValueError('examples cannot be None for regression.')
+ if predictions is None:
+ raise ValueError('predictions cannot be None for regression.')
+
+ input_tensor_info = utils.build_tensor_info(examples)
+ signature_inputs = {signature_constants.REGRESS_INPUTS: input_tensor_info}
+
+ output_tensor_info = utils.build_tensor_info(predictions)
+ signature_outputs = {signature_constants.REGRESS_OUTPUTS: output_tensor_info}
+ signature_def = build_signature_def(
+ signature_inputs, signature_outputs,
+ signature_constants.REGRESS_METHOD_NAME)
+
+ return signature_def
+
+
+def classification_signature_def(examples, classes, scores):
+ """Creates classification signature from given examples and predictions.
+
+ Args:
+ examples: `Tensor`.
+ classes: `Tensor`.
+ scores: `Tensor`.
+
+ Returns:
+ A classification-flavored signature_def.
+
+ Raises:
+ ValueError: If examples is `None`.
+ """
+ if examples is None:
+ raise ValueError('examples cannot be None for classification.')
+ if classes is None and scores is None:
+ raise ValueError('classes and scores cannot both be None for '
+ 'classification.')
+
+ input_tensor_info = utils.build_tensor_info(examples)
+ signature_inputs = {signature_constants.CLASSIFY_INPUTS: input_tensor_info}
+
+ signature_outputs = {}
+ if classes is not None:
+ classes_tensor_info = utils.build_tensor_info(classes)
+ signature_outputs[signature_constants.CLASSIFY_OUTPUT_CLASSES] = (
+ classes_tensor_info)
+ if scores is not None:
+ scores_tensor_info = utils.build_tensor_info(scores)
+ signature_outputs[signature_constants.CLASSIFY_OUTPUT_SCORES] = (
+ scores_tensor_info)
+
+ signature_def = build_signature_def(
+ signature_inputs, signature_outputs,
+ signature_constants.CLASSIFY_METHOD_NAME)
+
+ return signature_def
+
+
+def predict_signature_def(inputs, outputs):
+ """Creates prediction signature from given inputs and outputs.
+
+ Args:
+ inputs: dict of string to `Tensor`.
+ outputs: dict of string to `Tensor`.
+
+ Returns:
+ A prediction-flavored signature_def.
+
+ Raises:
+ ValueError: If inputs or outputs is `None`.
+ """
+ if inputs is None or not inputs:
+ raise ValueError('inputs cannot be None or empty for prediction.')
+ if outputs is None:
+ raise ValueError('outputs cannot be None or empty for prediction.')
+
+ # If there's only one input or output, we can standardize keys
+ if len(inputs) == 1:
+ (_, value), = inputs.items()
+ inputs = {signature_constants.PREDICT_INPUTS: value}
+ if len(outputs) == 1:
+ (_, value), = outputs.items()
+ outputs = {signature_constants.PREDICT_OUTPUTS: value}
+
+ signature_inputs = {key: utils.build_tensor_info(tensor)
+ for key, tensor in inputs.items()}
+ signature_outputs = {key: utils.build_tensor_info(tensor)
+ for key, tensor in outputs.items()}
+
+ signature_def = build_signature_def(
+ signature_inputs, signature_outputs,
+ signature_constants.PREDICT_METHOD_NAME)
+
+ return signature_def
diff --git a/tensorflow/python/saved_model/signature_def_utils_test.py b/tensorflow/python/saved_model/signature_def_utils_test.py
new file mode 100644
index 0000000000..6dfc4b2cd6
--- /dev/null
+++ b/tensorflow/python/saved_model/signature_def_utils_test.py
@@ -0,0 +1,156 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for SignatureDef utils."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.core.framework import types_pb2
+from tensorflow.python.saved_model import signature_constants
+from tensorflow.python.saved_model import signature_def_utils
+from tensorflow.python.saved_model import utils
+
+
+class SignatureDefUtilsTest(tf.test.TestCase):
+
+ def testBuildSignatureDef(self):
+ x = tf.placeholder(tf.float32, 1, name="x")
+ x_tensor_info = utils.build_tensor_info(x)
+ inputs = dict()
+ inputs["foo-input"] = x_tensor_info
+
+ y = tf.placeholder(tf.float32, name="y")
+ y_tensor_info = utils.build_tensor_info(y)
+ outputs = dict()
+ outputs["foo-output"] = y_tensor_info
+
+ signature_def = signature_def_utils.build_signature_def(
+ inputs, outputs, "foo-method-name")
+ self.assertEqual("foo-method-name", signature_def.method_name)
+
+ # Check inputs in signature def.
+ self.assertEqual(1, len(signature_def.inputs))
+ x_tensor_info_actual = signature_def.inputs["foo-input"]
+ self.assertEqual("x:0", x_tensor_info_actual.name)
+ self.assertEqual(types_pb2.DT_FLOAT, x_tensor_info_actual.dtype)
+ self.assertEqual(1, len(x_tensor_info_actual.tensor_shape.dim))
+ self.assertEqual(1, x_tensor_info_actual.tensor_shape.dim[0].size)
+
+ # Check outputs in signature def.
+ self.assertEqual(1, len(signature_def.outputs))
+ y_tensor_info_actual = signature_def.outputs["foo-output"]
+ self.assertEqual("y:0", y_tensor_info_actual.name)
+ self.assertEqual(types_pb2.DT_FLOAT, y_tensor_info_actual.dtype)
+ self.assertEqual(0, len(y_tensor_info_actual.tensor_shape.dim))
+
+ def testRegressionSignatureDef(self):
+ input1 = tf.constant("a", name="input-1")
+ output1 = tf.constant("b", name="output-1")
+ signature_def = signature_def_utils.regression_signature_def(
+ input1, output1)
+
+ self.assertEqual(signature_constants.REGRESS_METHOD_NAME,
+ signature_def.method_name)
+
+ # Check inputs in signature def.
+ self.assertEqual(1, len(signature_def.inputs))
+ x_tensor_info_actual = (
+ signature_def.inputs[signature_constants.REGRESS_INPUTS])
+ self.assertEqual("input-1:0", x_tensor_info_actual.name)
+ self.assertEqual(types_pb2.DT_STRING, x_tensor_info_actual.dtype)
+ self.assertEqual(0, len(x_tensor_info_actual.tensor_shape.dim))
+
+ # Check outputs in signature def.
+ self.assertEqual(1, len(signature_def.outputs))
+ y_tensor_info_actual = (
+ signature_def.outputs[signature_constants.REGRESS_OUTPUTS])
+ self.assertEqual("output-1:0", y_tensor_info_actual.name)
+ self.assertEqual(types_pb2.DT_STRING, y_tensor_info_actual.dtype)
+ self.assertEqual(0, len(y_tensor_info_actual.tensor_shape.dim))
+
+ def testClassificationSignatureDef(self):
+ input1 = tf.constant("a", name="input-1")
+ output1 = tf.constant("b", name="output-1")
+ output2 = tf.constant("c", name="output-2")
+ signature_def = signature_def_utils.classification_signature_def(
+ input1, output1, output2)
+
+ self.assertEqual(signature_constants.CLASSIFY_METHOD_NAME,
+ signature_def.method_name)
+
+ # Check inputs in signature def.
+ self.assertEqual(1, len(signature_def.inputs))
+ x_tensor_info_actual = (
+ signature_def.inputs[signature_constants.CLASSIFY_INPUTS])
+ self.assertEqual("input-1:0", x_tensor_info_actual.name)
+ self.assertEqual(types_pb2.DT_STRING, x_tensor_info_actual.dtype)
+ self.assertEqual(0, len(x_tensor_info_actual.tensor_shape.dim))
+
+ # Check outputs in signature def.
+ self.assertEqual(2, len(signature_def.outputs))
+ classes_tensor_info_actual = (
+ signature_def.outputs[signature_constants.CLASSIFY_OUTPUT_CLASSES])
+ self.assertEqual("output-1:0", classes_tensor_info_actual.name)
+ self.assertEqual(types_pb2.DT_STRING, classes_tensor_info_actual.dtype)
+ self.assertEqual(0, len(classes_tensor_info_actual.tensor_shape.dim))
+ scores_tensor_info_actual = (
+ signature_def.outputs[signature_constants.CLASSIFY_OUTPUT_SCORES])
+ self.assertEqual("output-2:0", scores_tensor_info_actual.name)
+ self.assertEqual(types_pb2.DT_STRING, scores_tensor_info_actual.dtype)
+ self.assertEqual(0, len(scores_tensor_info_actual.tensor_shape.dim))
+
+ def testPredictionSignatureDef(self):
+ input1 = tf.constant("a", name="input-1")
+ input2 = tf.constant("b", name="input-2")
+ output1 = tf.constant("c", name="output-1")
+ output2 = tf.constant("d", name="output-2")
+ signature_def = signature_def_utils.predict_signature_def(
+ {"input-1": input1, "input-2": input2},
+ {"output-1": output1, "output-2": output2})
+
+ self.assertEqual(signature_constants.PREDICT_METHOD_NAME,
+ signature_def.method_name)
+
+ # Check inputs in signature def.
+ self.assertEqual(2, len(signature_def.inputs))
+ input1_tensor_info_actual = (
+ signature_def.inputs["input-1"])
+ self.assertEqual("input-1:0", input1_tensor_info_actual.name)
+ self.assertEqual(types_pb2.DT_STRING, input1_tensor_info_actual.dtype)
+ self.assertEqual(0, len(input1_tensor_info_actual.tensor_shape.dim))
+ input2_tensor_info_actual = (
+ signature_def.inputs["input-2"])
+ self.assertEqual("input-2:0", input2_tensor_info_actual.name)
+ self.assertEqual(types_pb2.DT_STRING, input2_tensor_info_actual.dtype)
+ self.assertEqual(0, len(input2_tensor_info_actual.tensor_shape.dim))
+
+ # Check outputs in signature def.
+ self.assertEqual(2, len(signature_def.outputs))
+ output1_tensor_info_actual = (
+ signature_def.outputs["output-1"])
+ self.assertEqual("output-1:0", output1_tensor_info_actual.name)
+ self.assertEqual(types_pb2.DT_STRING, output1_tensor_info_actual.dtype)
+ self.assertEqual(0, len(output1_tensor_info_actual.tensor_shape.dim))
+ output2_tensor_info_actual = (
+ signature_def.outputs["output-2"])
+ self.assertEqual("output-2:0", output2_tensor_info_actual.name)
+ self.assertEqual(types_pb2.DT_STRING, output2_tensor_info_actual.dtype)
+ self.assertEqual(0, len(output2_tensor_info_actual.tensor_shape.dim))
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/python/saved_model/utils.py b/tensorflow/python/saved_model/utils.py
index 550eed0fcc..ecc58fbc7a 100644
--- a/tensorflow/python/saved_model/utils.py
+++ b/tensorflow/python/saved_model/utils.py
@@ -23,6 +23,7 @@ from __future__ import print_function
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.framework import dtypes
+
# TensorInfo helpers.
@@ -40,30 +41,3 @@ def build_tensor_info(tensor):
name=tensor.name,
dtype=dtype_enum,
tensor_shape=tensor.get_shape().as_proto())
-
-# SignatureDef helpers.
-
-
-def build_signature_def(inputs=None, outputs=None, method_name=None):
- """Utility function to build a SignatureDef protocol buffer.
-
- Args:
- inputs: Inputs of the SignatureDef defined as a proto map of string to
- tensor info.
- outputs: Outputs of the SignatureDef defined as a proto map of string to
- tensor info.
- method_name: Method name of the SignatureDef as a string.
-
- Returns:
- A SignatureDef protocol buffer constructed based on the supplied arguments.
- """
- signature_def = meta_graph_pb2.SignatureDef()
- if inputs is not None:
- for item in inputs:
- signature_def.inputs[item].CopyFrom(inputs[item])
- if outputs is not None:
- for item in outputs:
- signature_def.outputs[item].CopyFrom(outputs[item])
- if method_name is not None:
- signature_def.method_name = method_name
- return signature_def
diff --git a/tensorflow/python/saved_model/utils_test.py b/tensorflow/python/saved_model/utils_test.py
index 8ce7d1dea1..74f2624773 100644
--- a/tensorflow/python/saved_model/utils_test.py
+++ b/tensorflow/python/saved_model/utils_test.py
@@ -34,36 +34,6 @@ class UtilsTest(tf.test.TestCase):
self.assertEqual(1, len(x_tensor_info.tensor_shape.dim))
self.assertEqual(1, x_tensor_info.tensor_shape.dim[0].size)
- def testBuildSignatureDef(self):
- x = tf.placeholder(tf.float32, 1, name="x")
- x_tensor_info = utils.build_tensor_info(x)
- inputs = dict()
- inputs["foo-input"] = x_tensor_info
-
- y = tf.placeholder(tf.float32, name="y")
- y_tensor_info = utils.build_tensor_info(y)
- outputs = dict()
- outputs["foo-output"] = y_tensor_info
-
- signature_def = utils.build_signature_def(inputs, outputs,
- "foo-method-name")
- self.assertEqual("foo-method-name", signature_def.method_name)
-
- # Check inputs in signature def.
- self.assertEqual(1, len(signature_def.inputs))
- x_tensor_info_actual = signature_def.inputs["foo-input"]
- self.assertEqual("x:0", x_tensor_info_actual.name)
- self.assertEqual(types_pb2.DT_FLOAT, x_tensor_info_actual.dtype)
- self.assertEqual(1, len(x_tensor_info_actual.tensor_shape.dim))
- self.assertEqual(1, x_tensor_info_actual.tensor_shape.dim[0].size)
-
- # Check outputs in signature def.
- self.assertEqual(1, len(signature_def.outputs))
- y_tensor_info_actual = signature_def.outputs["foo-output"]
- self.assertEqual("y:0", y_tensor_info_actual.name)
- self.assertEqual(types_pb2.DT_FLOAT, y_tensor_info_actual.dtype)
- self.assertEqual(0, len(y_tensor_info_actual.tensor_shape.dim))
-
if __name__ == "__main__":
tf.test.main()