diff options
author | David Soergel <soergel@google.com> | 2016-11-29 11:16:34 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-11-29 11:26:20 -0800 |
commit | e4d9c5f20d28df94cc8d1245c685f03098b811c0 (patch) | |
tree | 1ed707a9485fe9bfe03b9dcba4750cf596c667a9 | |
parent | d753ed4eda5d291e0d88f0b5babf7df4b691a066 (diff) |
Export SavedModel from tf.Learn.
Change: 140502426
23 files changed, 1707 insertions, 77 deletions
diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD index 5b7a2d76d8..764971935f 100644 --- a/tensorflow/contrib/learn/BUILD +++ b/tensorflow/contrib/learn/BUILD @@ -21,6 +21,10 @@ py_library( "//tensorflow/contrib/session_bundle:exporter", "//tensorflow/contrib/tensor_forest:client_lib", "//tensorflow/python:framework", + "//tensorflow/python/saved_model:builder", + "//tensorflow/python/saved_model:loader", + "//tensorflow/python/saved_model:signature_def_utils", + "//tensorflow/python/saved_model:tag_constants", ], ) @@ -663,6 +667,32 @@ py_test( ) py_test( + name = "gc_test", + size = "small", + srcs = ["python/learn/utils/gc_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":learn", + "//tensorflow:tensorflow_py", + "//tensorflow/python:framework", + "//tensorflow/python:framework_test_lib", + ], +) + +py_test( + name = "saved_model_export_utils_test", + size = "small", + srcs = ["python/learn/utils/saved_model_export_utils_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":learn", + "//tensorflow:tensorflow_py", + "//tensorflow/python:framework", + "//tensorflow/python:framework_test_lib", + ], +) + +py_test( name = "stability_test", size = "small", srcs = ["python/learn/estimators/stability_test.py"], diff --git a/tensorflow/contrib/learn/python/learn/estimators/constants.py b/tensorflow/contrib/learn/python/learn/estimators/constants.py new file mode 100644 index 0000000000..aee4541627 --- /dev/null +++ b/tensorflow/contrib/learn/python/learn/estimators/constants.py @@ -0,0 +1,26 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Constants regarding Estimators.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + + +class ProblemType(object): + UNSPECIFIED = 0 + CLASSIFICATION = 1 + LINEAR_REGRESSION = 2 + LOGISTIC_REGRESSION = 3 diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn.py b/tensorflow/contrib/learn/python/learn/estimators/dnn.py index b821d842cd..98947cc6d4 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn.py @@ -24,6 +24,7 @@ import six from tensorflow.contrib import layers from tensorflow.contrib.framework import deprecated from tensorflow.contrib.framework import deprecated_arg_values +from tensorflow.contrib.framework.python.framework import experimental from tensorflow.contrib.framework.python.ops import variables as contrib_variables from tensorflow.contrib.layers.python.layers import optimizers from tensorflow.contrib.learn.python.learn import evaluable @@ -452,6 +453,22 @@ class DNNClassifier(evaluable.Evaluable, trainable.Trainable): default_batch_size=default_batch_size, exports_to_keep=exports_to_keep) + @experimental + def export_savedmodel(self, + export_dir_base, + input_fn, + default_output_alternative_key=None, + assets_extra=None, + as_text=False, + exports_to_keep=None): + return self._estimator.export_savedmodel( + export_dir_base, + input_fn, + default_output_alternative_key=default_output_alternative_key, + 
assets_extra=assets_extra, + as_text=as_text, + exports_to_keep=exports_to_keep) + @property def model_dir(self): return self._estimator.model_dir diff --git a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py index 5a1bc931f8..256e074079 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py +++ b/tensorflow/contrib/learn/python/learn/estimators/dnn_linear_combined.py @@ -26,6 +26,7 @@ import six from tensorflow.contrib import layers from tensorflow.contrib.framework import deprecated from tensorflow.contrib.framework import deprecated_arg_values +from tensorflow.contrib.framework.python.framework import experimental from tensorflow.contrib.framework.python.ops import variables as contrib_variables from tensorflow.contrib.layers.python.layers import feature_column as feature_column_lib from tensorflow.contrib.layers.python.layers import feature_column_ops @@ -868,6 +869,22 @@ class DNNLinearCombinedClassifier(evaluable.Evaluable, trainable.Trainable): default_batch_size=default_batch_size, exports_to_keep=exports_to_keep) + @experimental + def export_savedmodel(self, + export_dir_base, + input_fn, + default_output_alternative_key=None, + assets_extra=None, + as_text=False, + exports_to_keep=None): + return self._estimator.export_savedmodel( + export_dir_base, + input_fn, + default_output_alternative_key=default_output_alternative_key, + assets_extra=assets_extra, + as_text=as_text, + exports_to_keep=exports_to_keep) + @property def model_dir(self): return self._estimator.model_dir diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py index 6b5b241936..ea43fba6b3 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py @@ -38,6 +38,8 @@ from tensorflow.contrib.framework 
import deprecated_arg_values from tensorflow.contrib.framework import deprecated_args from tensorflow.contrib.framework import list_variables from tensorflow.contrib.framework import load_variable +from tensorflow.contrib.framework.python.framework import experimental +from tensorflow.contrib.framework.python.ops import variables as contrib_variables from tensorflow.contrib.learn.python.learn import evaluable from tensorflow.contrib.learn.python.learn import graph_actions from tensorflow.contrib.learn.python.learn import metric_spec @@ -51,14 +53,21 @@ from tensorflow.contrib.learn.python.learn.estimators import tensor_signature from tensorflow.contrib.learn.python.learn.estimators._sklearn import NotFittedError from tensorflow.contrib.learn.python.learn.learn_io import data_feeder from tensorflow.contrib.learn.python.learn.utils import export - +from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils +from tensorflow.python.client import session as tf_session from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.framework import random_seed from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import data_flow_ops +from tensorflow.python.ops import variables +from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.saved_model import builder as saved_model_builder +from tensorflow.python.saved_model import tag_constants from tensorflow.python.training import device_setter from tensorflow.python.training import saver +from tensorflow.python.util import compat AS_ITERABLE_DATE = '2016-09-15' @@ -553,13 +562,12 @@ class BaseEstimator( use_deprecated_input_fn=use_deprecated_input_fn, default_batch_size=default_batch_size, exports_to_keep=exports_to_keep) - # pylint: enable=protected-access @abc.abstractproperty def _get_train_ops(self, features, labels): """Method that builds model graph 
and returns trainer ops. - Expected to be overriden by sub-classes that require custom support. + Expected to be overridden by sub-classes that require custom support. Args: features: `Tensor` or `dict` of `Tensor` objects. @@ -1106,6 +1114,106 @@ class Estimator(BaseEstimator): self._labels_info) return self._call_model_fn(features, labels, model_fn_lib.ModeKeys.INFER) + @experimental + def export_savedmodel( + self, export_dir_base, input_fn, + default_output_alternative_key=None, + assets_extra=None, + as_text=False, + exports_to_keep=None): + """Exports inference graph as a SavedModel into given dir. + + Args: + export_dir_base: A string containing a directory to write the exported + graph and checkpoints. + input_fn: A function that takes no argument and + returns an `InputFnOps`. + default_output_alternative_key: the name of the head to serve when none is + specified. + assets_extra: A dict specifying how to populate the assets.extra directory + within the exported SavedModel. Each key should give the destination + path (including the filename) relative to the assets.extra directory. + The corresponding value gives the full path of the source file to be + copied. For example, the simple case of copying a single file without + renaming it is specified as + `{'my_asset_file.txt': '/path/to/my_asset_file.txt'}`. + as_text: whether to write the SavedModel proto in text format. + exports_to_keep: Number of exports to keep. + + Returns: + The string path to the exported directory. + + Raises: + ValueError: if `input_fn` is None. + """ + if input_fn is None: + raise ValueError('input_fn must be defined.') + + with ops.Graph().as_default() as g: + contrib_variables.create_global_step(g) + + # Call the input_fn and collect the input alternatives. + input_ops = input_fn() + input_alternatives, features = ( + saved_model_export_utils.get_input_alternatives(input_ops)) + + # Call the model_fn and collect the output alternatives.
+ model_fn_ops = self._call_model_fn(features, None, + model_fn_lib.ModeKeys.INFER) + output_alternatives, actual_default_output_alternative_key = ( + saved_model_export_utils.get_output_alternatives( + model_fn_ops, default_output_alternative_key)) + + # Build the SignatureDefs from all pairs of input and output signatures + signature_def_map = saved_model_export_utils.build_all_signature_defs( + input_alternatives, output_alternatives, + actual_default_output_alternative_key) + + # Locate the latest checkpoint + # TODO(soergel): does it help that we know we have one from this step? + checkpoint_path = saver.latest_checkpoint(self._model_dir) + if not checkpoint_path: + raise NotFittedError("Couldn't find trained model at %s." + % self._model_dir) + + export_dir = saved_model_export_utils.get_timestamped_export_dir( + export_dir_base) + + with tf_session.Session('') as session: + variables.initialize_local_variables() + data_flow_ops.initialize_all_tables() + saver_for_restore = saver.Saver( + variables.global_variables(), + sharded=True) + saver_for_restore.restore(session, checkpoint_path) + + init_op = control_flow_ops.group( + variables.local_variables_initializer(), + data_flow_ops.initialize_all_tables()) + + # Perform the export + builder = saved_model_builder.SavedModelBuilder(export_dir) + builder.add_meta_graph_and_variables( + session, [tag_constants.SERVING], + signature_def_map=signature_def_map, + assets_collection=ops.get_collection( + ops.GraphKeys.ASSET_FILEPATHS), + legacy_init_op=init_op) + builder.save(as_text) + + # Add the extra assets + if assets_extra: + assets_extra_path = os.path.join(compat.as_bytes(export_dir), + compat.as_bytes('assets.extra')) + for dest_relative, source in assets_extra.items(): + dest_absolute = os.path.join(compat.as_bytes(assets_extra_path), + compat.as_bytes(dest_relative)) + dest_path = os.path.dirname(dest_absolute) + gfile.MakeDirs(dest_path) + gfile.Copy(source, dest_absolute) + + return export_dir + # For 
time of deprecation x,y from Estimator allow direct access. # pylint: disable=protected-access diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py index 7d500bce19..49c917a3de 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py +++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py @@ -22,6 +22,7 @@ from __future__ import print_function import functools import itertools import json +import os import tempfile import numpy as np @@ -33,6 +34,11 @@ from tensorflow.contrib.learn.python.learn import metric_spec from tensorflow.contrib.learn.python.learn.estimators import _sklearn from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import model_fn +from tensorflow.contrib.learn.python.learn.utils import input_fn_utils +from tensorflow.python.framework import ops +from tensorflow.python.saved_model import loader +from tensorflow.python.saved_model import tag_constants +from tensorflow.python.util import compat _BOSTON_INPUT_DIM = 13 @@ -92,6 +98,8 @@ def linear_model_fn(features, labels, mode): tf.contrib.learn.ModeKeys.TRAIN, tf.contrib.learn.ModeKeys.EVAL, tf.contrib.learn.ModeKeys.INFER) + if isinstance(features, dict): + (_, features), = features.items() prediction, loss = ( tf.contrib.learn.models.linear_regression_zero_init(features, labels) ) @@ -131,6 +139,45 @@ def logistic_model_no_mode_fn(features, labels): learning_rate=0.1) return {'class': tf.argmax(prediction, 1), 'prob': prediction}, loss, train_op +VOCAB_FILE_CONTENT = 'emerson\nlake\npalmer\n' +EXTRA_FILE_CONTENT = 'kermit\npiggy\nralph\n' + + +def _build_estimator_for_export_tests(tmpdir): + def _input_fn(): + iris = tf.contrib.learn.datasets.load_iris() + return { + 'feature': tf.constant(iris.data, dtype=tf.float32) + }, tf.constant(iris.target, shape=[150], dtype=tf.int32) + + 
feature_columns = [tf.contrib.layers.real_valued_column('feature', + dimension=4)] + + est = tf.contrib.learn.LinearRegressor(feature_columns) + est.fit(input_fn=_input_fn, steps=20) + + feature_spec = tf.contrib.layers.create_feature_spec_for_parsing( + feature_columns) + export_input_fn = input_fn_utils.build_parsing_serving_input_fn(feature_spec) + + # hack in an op that uses an asset, in order to test asset export. + # this is not actually valid, of course. + def export_input_fn_with_asset(): + features, labels, inputs = export_input_fn() + + vocab_file_name = os.path.join(tmpdir, 'my_vocab_file') + vocab_file = tf.gfile.GFile(vocab_file_name, mode='w') + vocab_file.write(VOCAB_FILE_CONTENT) + vocab_file.close() + hashtable = tf.contrib.lookup.HashTable( + tf.contrib.lookup.TextFileStringTableInitializer(vocab_file_name), 'x') + features['bogus_lookup'] = hashtable.lookup( + tf.to_int64(features['feature'])) + + return input_fn_utils.InputFnOps(features, labels, inputs) + + return est, export_input_fn_with_asset + class CheckCallsMonitor(tf.contrib.learn.monitors.BaseMonitor): @@ -503,6 +550,76 @@ class EstimatorTest(tf.test.TestCase): self.assertEquals(expected, actual) + def test_export_savedmodel(self): + tmpdir = tempfile.mkdtemp() + est, export_input_fn = _build_estimator_for_export_tests(tmpdir) + + extra_file_name = os.path.join(compat.as_bytes(tmpdir), + compat.as_bytes('my_extra_file')) + extra_file = tf.gfile.GFile(extra_file_name, mode='w') + extra_file.write(EXTRA_FILE_CONTENT) + extra_file.close() + assets_extra = {'some/sub/directory/my_extra_file': extra_file_name} + + export_dir_base = os.path.join(compat.as_bytes(tmpdir), + compat.as_bytes('export')) + export_dir = est.export_savedmodel(export_dir_base, export_input_fn, + assets_extra=assets_extra) + + self.assertTrue(tf.gfile.Exists(export_dir_base)) + self.assertTrue(tf.gfile.Exists(export_dir)) + self.assertTrue(tf.gfile.Exists( + os.path.join(compat.as_bytes(export_dir), + 
compat.as_bytes('saved_model.pb')))) + self.assertTrue(tf.gfile.Exists( + os.path.join(compat.as_bytes(export_dir), + compat.as_bytes('variables')))) + self.assertTrue(tf.gfile.Exists( + os.path.join(compat.as_bytes(export_dir), + compat.as_bytes('variables/variables.index')))) + self.assertTrue(tf.gfile.Exists(os.path.join( + compat.as_bytes(export_dir), + compat.as_bytes('variables/variables.data-00000-of-00001')))) + + self.assertTrue(tf.gfile.Exists( + os.path.join(compat.as_bytes(export_dir), compat.as_bytes('assets')))) + self.assertTrue(tf.gfile.Exists( + os.path.join(compat.as_bytes(export_dir), + compat.as_bytes('assets/my_vocab_file')))) + self.assertEqual( + compat.as_bytes(VOCAB_FILE_CONTENT), + compat.as_bytes(tf.gfile.GFile( + os.path.join(compat.as_bytes(export_dir), + compat.as_bytes('assets/my_vocab_file'))).read())) + + expected_extra_path = os.path.join( + compat.as_bytes(export_dir), + compat.as_bytes('assets.extra/some/sub/directory/my_extra_file')) + self.assertTrue(tf.gfile.Exists( + os.path.join(compat.as_bytes(export_dir), + compat.as_bytes('assets.extra')))) + self.assertTrue(tf.gfile.Exists(expected_extra_path)) + self.assertEqual( + compat.as_bytes(EXTRA_FILE_CONTENT), + compat.as_bytes(tf.gfile.GFile(expected_extra_path).read())) + + expected_vocab_file = os.path.join(compat.as_bytes(tmpdir), + compat.as_bytes('my_vocab_file')) + # Restore, to validate that the export was well-formed. 
+ with tf.Graph().as_default() as graph: + with tf.Session(graph=graph) as sess: + loader.load(sess, [tag_constants.SERVING], export_dir) + assets = [x.eval() + for x in graph.get_collection(ops.GraphKeys.ASSET_FILEPATHS)] + self.assertItemsEqual([expected_vocab_file], assets) + graph_ops = [x.name for x in graph.get_operations()] + self.assertTrue('input_example_tensor' in graph_ops) + self.assertTrue('ParseExample/ParseExample' in graph_ops) + self.assertTrue('linear/linear/feature/matmul' in graph_ops) + + # cleanup + tf.gfile.DeleteRecursively(tmpdir) + class InferRealValuedColumnsTest(tf.test.TestCase): diff --git a/tensorflow/contrib/learn/python/learn/estimators/head.py b/tensorflow/contrib/learn/python/learn/estimators/head.py index 2b713efb85..8e0a4de19c 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/head.py +++ b/tensorflow/contrib/learn/python/learn/estimators/head.py @@ -24,6 +24,7 @@ import six from tensorflow.contrib import losses from tensorflow.contrib import metrics as metrics_lib from tensorflow.contrib.learn.python.learn import metric_spec +from tensorflow.contrib.learn.python.learn.estimators import constants from tensorflow.contrib.learn.python.learn.estimators import estimator from tensorflow.contrib.learn.python.learn.estimators import metric_key from tensorflow.contrib.learn.python.learn.estimators import model_fn @@ -200,6 +201,9 @@ class _Head(object): """ __metaclass__ = abc.ABCMeta + def __init__(self, head_name): + self._head_name = head_name + @abc.abstractproperty def logits_dimension(self): raise NotImplementedError("Calling an abstract method.") @@ -227,6 +231,25 @@ class _Head(object): """ raise NotImplementedError("Calling an abstract method.") + def _create_output_alternatives(self, predictions): + """Creates output alternative for the Head. 
+ + Args: + predictions: a dict of {tensor_name: Tensor}, where 'tensor_name' is a + symbolic name for an output Tensor possibly but not necessarily taken from + `PredictionKey`, and 'Tensor' is the corresponding output Tensor itself. + + Returns: + a dict of {submodel_name: (problem_type, {tensor_name: Tensor})}, where + 'submodel_name' is a submodel identifier that should be consistent across + the pipeline (here likely taken from the head_name), + 'problem_type' is a `ProblemType`, + 'tensor_name' is a symbolic name for an output Tensor possibly but not + necessarily taken from `PredictionKey`, and + 'Tensor' is the corresponding output Tensor itself. + """ + return {self._head_name: (self._problem_type, predictions)} + class _RegressionHead(_Head): """_Head for regression.""" @@ -249,14 +272,16 @@ class _RegressionHead(_Head): head_name: name of the head. If provided, predictions, summary and metrics keys will be prefixed by the head_name and an underscore. """ + super(_RegressionHead, self).__init__(head_name=head_name) + self._loss_fn = loss_fn self._logits_dimension = label_dimension self._label_name = label_name self._weight_column_name = weight_column_name - self._head_name = head_name self._enable_centered_bias = enable_centered_bias self._centered_bias_weight_collection = _head_prefixed(head_name, "centered_bias") + self._problem_type = constants.ProblemType.LINEAR_REGRESSION @property def logits_dimension(self): @@ -289,7 +314,8 @@ class _RegressionHead(_Head): loss=loss, train_op=train_op, eval_metric_ops=eval_metric_ops, - signature_fn=self._signature_fn()) + signature_fn=self._signature_fn(), + output_alternatives=self._create_output_alternatives(predictions)) def _training_loss(self, features, labels, logits, name=None): """Returns training loss tensor for this head. @@ -450,6 +476,8 @@ class _BinaryLogisticHead(_Head): Raises: ValueError: if n_classes is invalid. 
""" + super(_BinaryLogisticHead, self).__init__(head_name=head_name) + self._thresholds = thresholds if thresholds else [.5] self._label_name = label_name self._weight_column_name = weight_column_name @@ -705,6 +733,8 @@ class _MultiClassHead(_Head): Raises: ValueError: if n_classes is invalid. """ + super(_MultiClassHead, self).__init__(head_name=head_name) + if (n_classes is None) or (n_classes <= 2): raise ValueError("n_classes must be > 2: %s." % n_classes) self._thresholds = thresholds if thresholds else [.5] @@ -712,11 +742,11 @@ class _MultiClassHead(_Head): self._logits_dimension = n_classes self._label_name = label_name self._weight_column_name = weight_column_name - self._head_name = head_name self._loss_fn = loss_fn self._enable_centered_bias = enable_centered_bias self._centered_bias_weight_collection = _head_prefixed(head_name, "centered_bias") + self._problem_type = constants.ProblemType.CLASSIFICATION @property def logits_dimension(self): @@ -749,7 +779,8 @@ class _MultiClassHead(_Head): loss=loss, train_op=train_op, eval_metric_ops=eval_metric_ops, - signature_fn=self._signature_fn()) + signature_fn=self._signature_fn(), + output_alternatives=self._create_output_alternatives(predictions)) def _training_loss(self, features, labels, logits=None, name=None): """Returns training loss tensor for this head. 
diff --git a/tensorflow/contrib/learn/python/learn/estimators/linear.py b/tensorflow/contrib/learn/python/learn/estimators/linear.py index 763d70ddaf..7e0b9c36f6 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/linear.py +++ b/tensorflow/contrib/learn/python/learn/estimators/linear.py @@ -27,6 +27,7 @@ import six from tensorflow.contrib import layers from tensorflow.contrib.framework import deprecated from tensorflow.contrib.framework import deprecated_arg_values +from tensorflow.contrib.framework.python.framework import experimental from tensorflow.contrib.framework.python.ops import variables as contrib_variables from tensorflow.contrib.learn.python.learn import evaluable from tensorflow.contrib.learn.python.learn import monitors as monitor_lib @@ -519,6 +520,22 @@ class LinearClassifier(evaluable.Evaluable, trainable.Trainable): default_batch_size=default_batch_size, exports_to_keep=exports_to_keep) + @experimental + def export_savedmodel(self, + export_dir_base, + input_fn, + default_output_alternative_key=None, + assets_extra=None, + as_text=False, + exports_to_keep=None): + return self._estimator.export_savedmodel( + export_dir_base, + input_fn, + default_output_alternative_key=default_output_alternative_key, + assets_extra=assets_extra, + as_text=as_text, + exports_to_keep=exports_to_keep) + @property @deprecated("2016-10-30", "This method will be removed after the deprecation date. 
" @@ -760,6 +777,22 @@ class LinearRegressor(evaluable.Evaluable, trainable.Trainable): default_batch_size=default_batch_size, exports_to_keep=exports_to_keep) + @experimental + def export_savedmodel(self, + export_dir_base, + input_fn, + default_output_alternative_key=None, + assets_extra=None, + as_text=False, + exports_to_keep=None): + return self._estimator.export_savedmodel( + export_dir_base, + input_fn, + default_output_alternative_key=default_output_alternative_key, + assets_extra=assets_extra, + as_text=as_text, + exports_to_keep=exports_to_keep) + @property @deprecated("2016-10-30", "This method will be removed after the deprecation date. " diff --git a/tensorflow/contrib/learn/python/learn/estimators/model_fn.py b/tensorflow/contrib/learn/python/learn/estimators/model_fn.py index 3f9351ce22..42f21bd196 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/model_fn.py +++ b/tensorflow/contrib/learn/python/learn/estimators/model_fn.py @@ -49,13 +49,33 @@ class ModeKeys(object): # TODO(roumposg): Pass output_signature_fn instead of signature_fn. class ModelFnOps(collections.namedtuple( 'ModelFnOps', - ['predictions', 'loss', 'train_op', 'eval_metric_ops', 'signature_fn'])): + ['predictions', 'loss', 'train_op', 'eval_metric_ops', 'signature_fn', + 'output_alternatives'])): """Ops returned from a model_fn.""" + # TODO(soergel): remove signature_fn once sessionbundle export is deprecated. + def __new__(cls, mode, predictions=None, loss=None, train_op=None, - eval_metric_ops=None, signature_fn=None): + eval_metric_ops=None, signature_fn=None, + output_alternatives=None): """Creates a validated `ModelFnOps` instance. + For a multi-headed model, the predictions dict here will contain the outputs + of all of the heads. However: at serving time, requests will be made + specifically for one or more heads, and the RPCs used for these requests may + differ by problem type (i.e., regression, classification, other). 
The + purpose of the output_alternatives dict is to aid in exporting a SavedModel + from which such head-specific queries can be served. These + output_alternatives will be combined with input_alternatives (see + `saved_model_export_utils`) to produce a set of `SignatureDef`s specifying + the valid requests that can be served from this model. + + For a single-headed model, it is still advisable to provide + output_alternatives with a single entry, because this is how the problem + type is communicated for export and serving. If output_alternatives is not + given, the resulting SavedModel will support only one head of unspecified + type. + + Args: mode: One of `ModeKeys`. Specifies if this training, evaluation or prediction. @@ -65,6 +85,14 @@ class ModelFnOps(collections.namedtuple( eval_metric_ops: Dict of metric results keyed by name. The values of the dict are the results of calling a metric function, such as `Tensor`. signature_fn: The signature_fn used for exporting. + output_alternatives: a dict of + `{submodel_name: (problem_type, {tensor_name: Tensor})}`, where + `submodel_name` is a submodel identifier that should be consistent + across the pipeline (here likely taken from the name of each `Head`, + for models that use them), `problem_type` is a `ProblemType`, + `tensor_name` is a symbolic name for an output Tensor possibly but not + necessarily taken from `PredictionKey`, and `Tensor` is the + corresponding output Tensor itself. Returns: A validated `ModelFnOps` object.
@@ -122,4 +150,5 @@ class ModelFnOps(collections.namedtuple( raise ValueError('signature_fn is not callable.') return super(ModelFnOps, cls).__new__(cls, predictions, loss, train_op, - eval_metric_ops, signature_fn) + eval_metric_ops, signature_fn, + output_alternatives) diff --git a/tensorflow/contrib/learn/python/learn/estimators/random_forest.py b/tensorflow/contrib/learn/python/learn/estimators/random_forest.py index c2c41255c9..deb55efc9f 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/random_forest.py +++ b/tensorflow/contrib/learn/python/learn/estimators/random_forest.py @@ -19,6 +19,7 @@ from __future__ import print_function from tensorflow.contrib import framework as contrib_framework from tensorflow.contrib.framework import deprecated_arg_values +from tensorflow.contrib.framework.python.framework import experimental from tensorflow.contrib.learn.python.learn import evaluable from tensorflow.contrib.learn.python.learn import trainable @@ -352,3 +353,19 @@ class TensorForestEstimator(evaluable.Evaluable, trainable.Trainable): self._estimator._model_fn = orig_model_fn # pylint: enable=protected-access return result + + @experimental + def export_savedmodel(self, + export_dir_base, + input_fn, + default_output_alternative_key=None, + assets_extra=None, + as_text=False, + exports_to_keep=None): + return self._estimator.export_savedmodel( + export_dir_base, + input_fn, + default_output_alternative_key=default_output_alternative_key, + assets_extra=assets_extra, + as_text=as_text, + exports_to_keep=exports_to_keep) diff --git a/tensorflow/contrib/learn/python/learn/estimators/svm.py b/tensorflow/contrib/learn/python/learn/estimators/svm.py index eeee673c5a..a6e4e7b6a3 100644 --- a/tensorflow/contrib/learn/python/learn/estimators/svm.py +++ b/tensorflow/contrib/learn/python/learn/estimators/svm.py @@ -19,12 +19,14 @@ from __future__ import division from __future__ import print_function import inspect +import re import tempfile from 
tensorflow.contrib import layers from tensorflow.contrib.framework import deprecated_arg_values from tensorflow.contrib.framework import list_variables from tensorflow.contrib.framework import load_variable +from tensorflow.contrib.framework.python.framework import experimental from tensorflow.contrib.learn.python.learn import evaluable from tensorflow.contrib.learn.python.learn import trainable from tensorflow.contrib.learn.python.learn.estimators import estimator @@ -235,6 +237,22 @@ class SVM(trainable.Trainable, evaluable.Evaluable): default_batch_size=default_batch_size, exports_to_keep=exports_to_keep) + @experimental + def export_savedmodel(self, + export_dir_base, + input_fn, + default_output_alternative_key=None, + assets_extra=None, + as_text=False, + exports_to_keep=None): + return self._estimator.export_savedmodel( + export_dir_base, + input_fn, + default_output_alternative_key=default_output_alternative_key, + assets_extra=assets_extra, + as_text=as_text, + exports_to_keep=exports_to_keep) + @property def weights_(self): values = {} diff --git a/tensorflow/contrib/learn/python/learn/utils/gc.py b/tensorflow/contrib/learn/python/learn/utils/gc.py new file mode 100644 index 0000000000..dd4376f051 --- /dev/null +++ b/tensorflow/contrib/learn/python/learn/utils/gc.py @@ -0,0 +1,205 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +r"""System for specifying garbage collection (GC) of path based data. + +This framework allows for GC of data specified by path names, for example files +on disk. gc.Path objects each represent a single item stored at a path and may +be a base directory, + /tmp/exports/0/... + /tmp/exports/1/... + ... +or a fully qualified file, + /tmp/train-1.ckpt + /tmp/train-2.ckpt + ... + +A gc filter function takes and returns a list of gc.Path items. Filter +functions are responsible for selecting Path items for preservation or deletion. +Note that functions should always return a sorted list. + +For example, + base_dir = "/tmp" + # create the directories + for e in xrange(10): + os.mkdir("%s/%d" % (base_dir, e), 0o755) + + # create a simple parser that pulls the export_version from the directory + def parser(path): + match = re.match("^" + base_dir + "/(\\d+)$", path.path) + if not match: + return None + return path._replace(export_version=int(match.group(1))) + + path_list = gc.get_paths("/tmp", parser) # contains all ten Paths + + every_fifth = gc.mod_export_version(5) + print every_fifth(path_list) # shows ["/tmp/0", "/tmp/5"] + + largest_three = gc.largest_export_versions(3) + print largest_three(all_paths) # shows ["/tmp/7", "/tmp/8", "/tmp/9"] + + both = gc.union(every_fifth, largest_three) + print both(all_paths) # shows ["/tmp/0", "/tmp/5", + # "/tmp/7", "/tmp/8", "/tmp/9"] + # delete everything not in 'both' + to_delete = gc.negation(both) + for p in to_delete(all_paths): + gfile.DeleteRecursively(p.path) # deletes: "/tmp/1", "/tmp/2", + # "/tmp/3", "/tmp/4", "/tmp/6", +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import heapq +import math +import os + +from tensorflow.python.platform import gfile + +Path = collections.namedtuple('Path', 'path export_version') + + +def 
largest_export_versions(n): + """Creates a filter that keeps the largest n export versions. + + Args: + n: number of versions to keep. + + Returns: + A filter function that keeps the n largest paths. + """ + def keep(paths): + heap = [] + for idx, path in enumerate(paths): + if path.export_version is not None: + heapq.heappush(heap, (path.export_version, idx)) + keepers = [paths[i] for _, i in heapq.nlargest(n, heap)] + return sorted(keepers) + + return keep + + +def one_of_every_n_export_versions(n): + """Creates a filter that keeps one of every n export versions. + + Args: + n: interval size. + + Returns: + A filter function that keeps exactly one path from each interval + [0, n], (n, 2n], (2n, 3n], etc... If more than one path exists in an + interval the largest is kept. + """ + def keep(paths): + """A filter function that keeps exactly one out of every n paths.""" + + keeper_map = {} # map from interval to largest path seen in that interval + for p in paths: + if p.export_version is None: + # Skip missing export_versions. + continue + # Find the interval (with a special case to map export_version = 0 to + # interval 0. + interval = math.floor( + (p.export_version - 1) / n) if p.export_version else 0 + existing = keeper_map.get(interval, None) + if (not existing) or (existing.export_version < p.export_version): + keeper_map[interval] = p + return sorted(keeper_map.values()) + + return keep + + +def mod_export_version(n): + """Creates a filter that keeps every export that is a multiple of n. + + Args: + n: step size. + + Returns: + A filter function that keeps paths where export_version % n == 0. + """ + def keep(paths): + keepers = [] + for p in paths: + if p.export_version % n == 0: + keepers.append(p) + return sorted(keepers) + return keep + + +def union(lf, rf): + """Creates a filter that keeps the union of two filters. + + Args: + lf: first filter + rf: second filter + + Returns: + A filter function that keeps the n largest paths. 
+ """ + def keep(paths): + l = set(lf(paths)) + r = set(rf(paths)) + return sorted(list(l|r)) + return keep + + +def negation(f): + """Negate a filter. + + Args: + f: filter function to invert + + Returns: + A filter function that returns the negation of f. + """ + def keep(paths): + l = set(paths) + r = set(f(paths)) + return sorted(list(l-r)) + return keep + + +def get_paths(base_dir, parser): + """Gets a list of Paths in a given directory. + + Args: + base_dir: directory. + parser: a function which gets the raw Path and can augment it with + information such as the export_version, or ignore the path by returning + None. An example parser may extract the export version from a path + such as "/tmp/exports/100" an another may extract from a full file + name such as "/tmp/checkpoint-99.out". + + Returns: + A list of Paths contained in the base directory with the parsing function + applied. + By default the following fields are populated, + - Path.path + The parsing function is responsible for populating, + - Path.export_version + """ + raw_paths = gfile.ListDirectory(base_dir) + paths = [] + for r in raw_paths: + p = parser(Path(os.path.join(base_dir, r), None)) + if p: + paths.append(p) + return sorted(paths) diff --git a/tensorflow/contrib/learn/python/learn/utils/gc_test.py b/tensorflow/contrib/learn/python/learn/utils/gc_test.py new file mode 100644 index 0000000000..dbe3304f21 --- /dev/null +++ b/tensorflow/contrib/learn/python/learn/utils/gc_test.py @@ -0,0 +1,120 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests for learn.utils.gc.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import re + +from six.moves import xrange # pylint: disable=redefined-builtin + +import tensorflow as tf + +from tensorflow.contrib.learn.python.learn.utils import gc +from tensorflow.python.framework import test_util +from tensorflow.python.platform import gfile + + +def tearDownModule(): + gfile.DeleteRecursively(tf.test.get_temp_dir()) + + +class GcTest(test_util.TensorFlowTestCase): + + def testLargestExportVersions(self): + paths = [gc.Path("/foo", 8), gc.Path("/foo", 9), gc.Path("/foo", 10)] + newest = gc.largest_export_versions(2) + n = newest(paths) + self.assertEquals(n, [gc.Path("/foo", 9), gc.Path("/foo", 10)]) + + def testLargestExportVersionsDoesNotDeleteZeroFolder(self): + paths = [gc.Path("/foo", 0), gc.Path("/foo", 3)] + newest = gc.largest_export_versions(2) + n = newest(paths) + self.assertEquals(n, [gc.Path("/foo", 0), gc.Path("/foo", 3)]) + + def testModExportVersion(self): + paths = [gc.Path("/foo", 4), gc.Path("/foo", 5), gc.Path("/foo", 6), + gc.Path("/foo", 9)] + mod = gc.mod_export_version(2) + self.assertEquals(mod(paths), [gc.Path("/foo", 4), gc.Path("/foo", 6)]) + mod = gc.mod_export_version(3) + self.assertEquals(mod(paths), [gc.Path("/foo", 6), gc.Path("/foo", 9)]) + + def testOneOfEveryNExportVersions(self): + paths = [gc.Path("/foo", 0), gc.Path("/foo", 1), gc.Path("/foo", 3), + 
gc.Path("/foo", 5), gc.Path("/foo", 6), gc.Path("/foo", 7), + gc.Path("/foo", 8), gc.Path("/foo", 33)] + one_of = gc.one_of_every_n_export_versions(3) + self.assertEquals(one_of(paths), + [gc.Path("/foo", 3), gc.Path("/foo", 6), + gc.Path("/foo", 8), gc.Path("/foo", 33)]) + + def testOneOfEveryNExportVersionsZero(self): + # Zero is a special case since it gets rolled into the first interval. + # Test that here. + paths = [gc.Path("/foo", 0), gc.Path("/foo", 4), gc.Path("/foo", 5)] + one_of = gc.one_of_every_n_export_versions(3) + self.assertEquals(one_of(paths), + [gc.Path("/foo", 0), gc.Path("/foo", 5)]) + + def testUnion(self): + paths = [] + for i in xrange(10): + paths.append(gc.Path("/foo", i)) + f = gc.union(gc.largest_export_versions(3), gc.mod_export_version(3)) + self.assertEquals( + f(paths), [gc.Path("/foo", 0), gc.Path("/foo", 3), + gc.Path("/foo", 6), gc.Path("/foo", 7), + gc.Path("/foo", 8), gc.Path("/foo", 9)]) + + def testNegation(self): + paths = [gc.Path("/foo", 4), gc.Path("/foo", 5), gc.Path("/foo", 6), + gc.Path("/foo", 9)] + mod = gc.negation(gc.mod_export_version(2)) + self.assertEquals( + mod(paths), [gc.Path("/foo", 5), gc.Path("/foo", 9)]) + mod = gc.negation(gc.mod_export_version(3)) + self.assertEquals( + mod(paths), [gc.Path("/foo", 4), gc.Path("/foo", 5)]) + + def testPathsWithParse(self): + base_dir = os.path.join(tf.test.get_temp_dir(), "paths_parse") + self.assertFalse(gfile.Exists(base_dir)) + for p in xrange(3): + gfile.MakeDirs(os.path.join(base_dir, "%d" % p)) + # add a base_directory to ignore + gfile.MakeDirs(os.path.join(base_dir, "ignore")) + + # create a simple parser that pulls the export_version from the directory. 
+ def parser(path): + match = re.match("^" + base_dir + "/(\\d+)$", path.path) + if not match: + return None + return path._replace(export_version=int(match.group(1))) + + self.assertEquals( + gc.get_paths(base_dir, parser=parser), + [gc.Path(os.path.join(base_dir, "0"), 0), + gc.Path(os.path.join(base_dir, "1"), 1), + gc.Path(os.path.join(base_dir, "2"), 2)]) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow/contrib/learn/python/learn/utils/input_fn_utils.py b/tensorflow/contrib/learn/python/learn/utils/input_fn_utils.py new file mode 100644 index 0000000000..2cb7173d5a --- /dev/null +++ b/tensorflow/contrib/learn/python/learn/utils/input_fn_utils.py @@ -0,0 +1,97 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Utilities for creating input_fns.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import parsing_ops + + +# A return type allowing input_fns to return multiple values in a well- +# defined way (analogous to ModelFnOps). 
+# The expected return values are: +# features: a dict of string to Tensor, giving the features to be passed to +# the model. +# labels: a dict of string to Tensor, giving labels (aka targets) for training. +# default_inputs: a dict of string to Tensor, giving the input Tensors (if +# any) that this input_fn expects to be fed. +InputFnOps = collections.namedtuple('InputFnOps', + ['features', + 'labels', + 'default_inputs']) + + +def build_parsing_serving_input_fn(feature_spec, default_batch_size=1): + """Build an input_fn appropriate for serving, expecting fed tf.Examples. + + Creates an input_fn that expects a serialized tf.Example fed into a string + placeholder. The function parses the tf.Example according to the provided + feature_spec, and returns all parsed Tensors as features. This input_fn is + for use at serving time, so the labels return value is always None. + + Args: + feature_spec: a dict of string to `VarLenFeature`/`FixedLenFeature`. + default_batch_size: the number of query examples expected per batch. + + Returns: + An input_fn suitable for use in serving. + """ + def input_fn(): + """An input_fn that expects a serialized tf.Example.""" + serialized_tf_example = array_ops.placeholder(dtype=dtypes.string, + shape=[default_batch_size], + name='input_example_tensor') + inputs = {'examples': serialized_tf_example} + features = parsing_ops.parse_example(serialized_tf_example, feature_spec) + labels = None # these are not known in serving! + return InputFnOps(features, labels, inputs) + return input_fn + + +def build_default_serving_input_fn(features, default_batch_size=1): + """Build an input_fn appropriate for serving, expecting feature Tensors. + + Creates an input_fn that expects all features to be fed directly. + This input_fn is for use at serving time, so the labels return value is always + None. + + Args: + features: a dict of string to `Tensor`. + default_batch_size: the number of query examples expected per batch. 
+ + Returns: + An input_fn suitable for use in serving. + """ + def input_fn(): + """an input_fn that expects all features to be fed directly.""" + features_placeholders = {} + for name, t in features.items(): + shape_list = t.get_shape().as_list() + shape_list[0] = default_batch_size + shape = tensor_shape.TensorShape(shape_list) + + features_placeholders[name] = array_ops.placeholder(dtype=t.dtype, + shape=shape, + name=t.name) + labels = None # these are not known in serving! + return InputFnOps(features_placeholders, labels, features_placeholders) + return input_fn diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py new file mode 100644 index 0000000000..54bb0fb3d7 --- /dev/null +++ b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py @@ -0,0 +1,248 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""Utilities supporting export to SavedModel.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import copy +import os +import re +import time + +from tensorflow.contrib.learn.python.learn.estimators import constants +from tensorflow.contrib.learn.python.learn.estimators import prediction_key +from tensorflow.contrib.learn.python.learn.utils import gc +from tensorflow.contrib.learn.python.learn.utils import input_fn_utils +from tensorflow.python.platform import gfile +from tensorflow.python.saved_model import signature_constants +from tensorflow.python.saved_model import signature_def_utils + +from tensorflow.python.util import compat + +# A key for use in the input_alternatives dict indicating the default input. +# This is the input that will be expected when a serving request does not +# specify a specific signature. +# The default input alternative specifies placeholders that the input_fn +# requires to be fed (in the typical case, a single placeholder for a +# serialized tf.Example). +DEFAULT_INPUT_ALTERNATIVE_KEY = 'default_input_alternative' + +# A key for use in the input_alternatives dict indicating the features input. +# The features inputs alternative specifies the feature Tensors provided as +# input to the model_fn, i.e. the outputs of the input_fn. +FEATURES_INPUT_ALTERNATIVE_KEY = 'features_input_alternative' + +# A key for use in the output_alternatives dict indicating the default output. +# This is the output that will be provided when a serving request does not +# specify a specific signature. +# In a single-headed model, the single output is automatically the default. +# In a multi-headed model, the name of the desired default head should be +# provided to get_output_alternatives. 
+DEFAULT_OUTPUT_ALTERNATIVE_KEY = 'default_output_alternative' + + +def build_standardized_signature_def( + input_tensors, output_tensors, problem_type): + """Build a SignatureDef using problem type and input and output Tensors. + + Note that this delegates the actual creation of the signatures to methods in + //third_party/tensorflow/python/saved_model/signature_def_utils.py, which may + assign names to the input and output tensors (depending on the problem type) + that are standardized in the context of SavedModel. + + Args: + input_tensors: a dict of string key to `Tensor` + output_tensors: a dict of string key to `Tensor` + problem_type: an instance of constants.ProblemType, specifying + classification, regression, etc. + + Returns: + A SignatureDef using SavedModel standard keys where possible. + + Raises: + ValueError: if input_tensors or output_tensors is None or empty. + """ + + if not input_tensors: + raise ValueError('input_tensors must be provided.') + if not output_tensors: + raise ValueError('output_tensors must be provided.') + + # Per-method signature_def functions will standardize the keys if possible + if _is_classification_problem(problem_type, input_tensors, output_tensors): + (_, examples), = input_tensors.items() + classes = output_tensors.get(prediction_key.PredictionKey.CLASSES) + scores = output_tensors.get(prediction_key.PredictionKey.SCORES) + if not (classes or scores): + (_, classes), = output_tensors.items() + return signature_def_utils.classification_signature_def( + examples, classes, scores) + elif _is_regression_problem(problem_type, input_tensors, output_tensors): + (_, examples), = input_tensors.items() + (_, predictions), = output_tensors.items() + return signature_def_utils.regression_signature_def(examples, predictions) + else: + return signature_def_utils.predict_signature_def( + input_tensors, output_tensors) + + +def _is_classification_problem(problem_type, input_tensors, output_tensors): + classes = 
output_tensors.get(prediction_key.PredictionKey.CLASSES) + scores = output_tensors.get(prediction_key.PredictionKey.SCORES) + return ((problem_type == constants.ProblemType.CLASSIFICATION or + problem_type == constants.ProblemType.LOGISTIC_REGRESSION) + and len(input_tensors) == 1 + and (classes or scores or len(output_tensors) == 1)) + + +def _is_regression_problem(problem_type, input_tensors, output_tensors): + return (problem_type == constants.ProblemType.LINEAR_REGRESSION + and len(input_tensors) == 1 + and len(output_tensors) == 1) + + +def get_input_alternatives(input_ops): + """Obtain all input alternatives using the input_fn output and heuristics.""" + input_alternatives = {} + if isinstance(input_ops, input_fn_utils.InputFnOps): + features, unused_labels, default_inputs = input_ops + input_alternatives[DEFAULT_INPUT_ALTERNATIVE_KEY] = default_inputs + else: + features, unused_labels = input_ops + + if not features: + raise ValueError('Features must be defined.') + + # Add the "features" input_signature in any case. + # Note defensive copy because model_fns alter the features dict. + input_alternatives[FEATURES_INPUT_ALTERNATIVE_KEY] = ( + copy.copy(features)) + + return input_alternatives, features + + +def get_output_alternatives( + model_fn_ops, + default_output_alternative_key=DEFAULT_OUTPUT_ALTERNATIVE_KEY): + """Obtain all output alternatives using the model_fn output and heuristics.""" + output_alternatives = model_fn_ops.output_alternatives + + # Identify the default outputs, creating them if needed. + if (output_alternatives + and default_output_alternative_key not in output_alternatives): + raise ValueError('default_output_alternative_key not in ' + 'output_alternatives: %s' % default_output_alternative_key) + + if (output_alternatives + and default_output_alternative_key in output_alternatives): + # If a default head is provided, use it. 
+ actual_default_output_alternative_key = default_output_alternative_key + return output_alternatives, actual_default_output_alternative_key + + if output_alternatives and len(output_alternatives) == 1: + # If there is only one head, use it as the default. + (actual_default_output_alternative_key, _), = output_alternatives.items() + return output_alternatives, actual_default_output_alternative_key + + # Lacking provided output alternatives, the best we can do is to + # interpret the model as single-headed of unknown type. + default_problem_type = constants.ProblemType.UNSPECIFIED + default_outputs = model_fn_ops.predictions + actual_default_output_alternative_key = DEFAULT_OUTPUT_ALTERNATIVE_KEY + output_alternatives = {actual_default_output_alternative_key: + (default_problem_type, default_outputs)} + return output_alternatives, actual_default_output_alternative_key + + +def build_all_signature_defs(input_alternatives, output_alternatives, + actual_default_output_alternative_key): + """Build `SignatureDef`s from all pairs of input and output alternatives.""" + + signature_def_map = { + ('%s:%s' % (input_key, output_key or 'None')): + build_standardized_signature_def( + inputs, outputs, problem_type) + for input_key, inputs in input_alternatives.items() + for output_key, (problem_type, outputs) + in output_alternatives.items()} + + # Add the default SignatureDef + default_inputs = input_alternatives[DEFAULT_INPUT_ALTERNATIVE_KEY] + if not default_inputs: + default_inputs = input_alternatives[FEATURES_INPUT_ALTERNATIVE_KEY] + # default outputs are guaranteed to exist above + (default_problem_type, default_outputs) = ( + output_alternatives[actual_default_output_alternative_key]) + signature_def_map[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = ( + build_standardized_signature_def( + default_inputs, default_outputs, default_problem_type)) + + return signature_def_map + + +def get_timestamped_export_dir(export_dir_base): + """Builds a path to a new 
subdirectory within the base directory. + + Each export is written into a new subdirectory named using the + current time. This guarantees monotonically increasing version + numbers even across multiple runs of the pipeline. + The timestamp used is the number of milliseconds since epoch UTC. + + Args: + export_dir_base: A string containing a directory to write the exported + graph and checkpoints. + Returns: + The full path of the new subdirectory (which is not actually created yet). + """ + export_timestamp = int(time.time() * 1e3) + + export_dir = os.path.join( + compat.as_bytes(export_dir_base), + compat.as_bytes(str(export_timestamp))) + return export_dir + + +def garbage_collect_exports(export_dir_base, exports_to_keep): + """Deletes older exports, retaining only a given number of the most recent. + + Export subdirectories are assumed to be named with monotonically increasing + integers; the most recent are taken to be those with the largest values. + + Args: + export_dir_base: the base directory under which each export is in a + versioned subdirectory. + exports_to_keep: the number of recent exports to retain. + """ + if exports_to_keep is None: + return + + keep_filter = gc.largest_export_versions(exports_to_keep) + delete_filter = gc.negation(keep_filter) + + # Export dir must not end with / or it will break the re match below. + if export_dir_base.endswith('/'): + export_dir_base = export_dir_base[:-1] + + # create a simple parser that pulls the export_version from the directory. 
+ def parser(path): + match = re.match('^' + export_dir_base + '/(\\d{13})$', path.path) + if not match: + return None + return path._replace(export_version=int(match.group(1))) + + for p in delete_filter(gc.get_paths(export_dir_base, parser=parser)): + gfile.DeleteRecursively(p.path) diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py new file mode 100644 index 0000000000..538e0ab104 --- /dev/null +++ b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py @@ -0,0 +1,228 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +"""Tests of utilities supporting export to SavedModel.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import tempfile +import time + +import tensorflow as tf + +from tensorflow.contrib.learn.python.learn.estimators import constants +from tensorflow.contrib.learn.python.learn.estimators import model_fn +from tensorflow.contrib.learn.python.learn.utils import input_fn_utils +from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils +from tensorflow.core.framework import tensor_shape_pb2 +from tensorflow.core.framework import types_pb2 +from tensorflow.core.protobuf import meta_graph_pb2 +from tensorflow.python.saved_model import signature_constants +from tensorflow.python.saved_model import signature_def_utils + + +class SavedModelExportUtilsTest(tf.test.TestCase): + + def test_build_standardized_signature_def(self): + input_tensors = { + "input-1": tf.placeholder(tf.float32, 1, name="input-tensor-1")} + output_tensors = { + "output-1": tf.placeholder(tf.float32, 1, name="output-tensor-1")} + problem_type = constants.ProblemType.LINEAR_REGRESSION + regression_signature_def = ( + saved_model_export_utils.build_standardized_signature_def( + input_tensors, output_tensors, problem_type)) + expected_regression_signature_def = meta_graph_pb2.SignatureDef() + shape = tensor_shape_pb2.TensorShapeProto( + dim=[tensor_shape_pb2.TensorShapeProto.Dim(size=1)]) + dtype = types_pb2.DataType.Value("DT_FLOAT") + expected_regression_signature_def.inputs[ + signature_constants.REGRESS_INPUTS].CopyFrom( + meta_graph_pb2.TensorInfo(name="input-tensor-1:0", + dtype=dtype, + tensor_shape=shape)) + expected_regression_signature_def.outputs[ + signature_constants.REGRESS_OUTPUTS].CopyFrom( + meta_graph_pb2.TensorInfo(name="output-tensor-1:0", + dtype=dtype, + tensor_shape=shape)) + + 
expected_regression_signature_def.method_name = ( + signature_constants.REGRESS_METHOD_NAME) + self.assertEqual(regression_signature_def, + expected_regression_signature_def) + + def test_get_input_alternatives(self): + input_ops = input_fn_utils.InputFnOps("bogus features dict", None, + "bogus default input dict") + + input_alternatives, _ = saved_model_export_utils.get_input_alternatives( + input_ops) + self.assertEqual( + input_alternatives[ + saved_model_export_utils.DEFAULT_INPUT_ALTERNATIVE_KEY], + "bogus default input dict") + self.assertEqual( + input_alternatives[ + saved_model_export_utils.FEATURES_INPUT_ALTERNATIVE_KEY], + "bogus features dict") + + def test_get_output_alternatives_explicit(self): + provided_output_alternatives = { + "head-1": (constants.ProblemType.LINEAR_REGRESSION, + "bogus output dict"), + "head-2": (constants.ProblemType.CLASSIFICATION, + "bogus output dict 2"), + "head-3": (constants.ProblemType.UNSPECIFIED, + "bogus output dict 3"), + } + model_fn_ops = model_fn.ModelFnOps( + model_fn.ModeKeys.INFER, + predictions={"some_output": "bogus_tensor"}, + output_alternatives=provided_output_alternatives) + output_alternatives, _ = saved_model_export_utils.get_output_alternatives( + model_fn_ops, "head-1") + + self.assertEqual(provided_output_alternatives, output_alternatives) + + def test_get_output_alternatives_implicit(self): + prediction_tensor = tf.constant(["bogus"]) + model_fn_ops = model_fn.ModelFnOps( + model_fn.ModeKeys.INFER, + predictions={"some_output": prediction_tensor}, + output_alternatives=None) + + output_alternatives, _ = saved_model_export_utils.get_output_alternatives( + model_fn_ops, "some_output") + self.assertEqual( + {"default_output_alternative": (constants.ProblemType.UNSPECIFIED, + {"some_output": prediction_tensor})}, + output_alternatives) + + def test_build_all_signature_defs(self): + input_features = tf.constant(["10"]) + input_example = tf.constant(["11"]) + input_ops = input_fn_utils.InputFnOps( + 
{"features": input_features}, + None, + {"default input": input_example}) + input_alternatives, _ = ( + saved_model_export_utils.get_input_alternatives(input_ops)) + output_1 = tf.constant(["1"]) + output_2 = tf.constant(["2"]) + output_3 = tf.constant(["3"]) + provided_output_alternatives = { + "head-1": (constants.ProblemType.LINEAR_REGRESSION, + {"some_output_1": output_1}), + "head-2": (constants.ProblemType.CLASSIFICATION, + {"some_output_2": output_2}), + "head-3": (constants.ProblemType.UNSPECIFIED, + {"some_output_3": output_3}), + } + model_fn_ops = model_fn.ModelFnOps( + model_fn.ModeKeys.INFER, + predictions={"some_output": tf.constant(["4"])}, + output_alternatives=provided_output_alternatives) + output_alternatives, _ = ( + saved_model_export_utils.get_output_alternatives(model_fn_ops, + "head-1")) + + signature_defs = saved_model_export_utils.build_all_signature_defs( + input_alternatives, output_alternatives, "head-1") + + expected_signature_defs = { + "serving_default": + signature_def_utils.regression_signature_def( + input_example, output_1), + "default_input_alternative:head-1": + signature_def_utils.regression_signature_def( + input_example, output_1), + "default_input_alternative:head-2": + signature_def_utils.classification_signature_def( + input_example, output_2, None), + "default_input_alternative:head-3": + signature_def_utils.predict_signature_def( + {"input": input_example}, {"output": output_3}), + "features_input_alternative:head-1": + signature_def_utils.regression_signature_def( + input_features, output_1), + "features_input_alternative:head-2": + signature_def_utils.classification_signature_def( + input_features, output_2, None), + "features_input_alternative:head-3": + signature_def_utils.predict_signature_def( + {"input": input_features}, {"output": output_3}), + } + + self.assertDictEqual(expected_signature_defs, signature_defs) + + def test_get_timestamped_export_dir(self): + export_dir_base = tempfile.mkdtemp() + "export/" + 
export_dir_1 = saved_model_export_utils.get_timestamped_export_dir( + export_dir_base) + time.sleep(0.001) + export_dir_2 = saved_model_export_utils.get_timestamped_export_dir( + export_dir_base) + time.sleep(0.001) + export_dir_3 = saved_model_export_utils.get_timestamped_export_dir( + export_dir_base) + + # Export directories should be named using a timestamp that is milliseconds + # since epoch. Such a timestamp is 13 digits long. + time_1 = os.path.basename(export_dir_1) + self.assertEqual(13, len(time_1)) + time_2 = os.path.basename(export_dir_2) + self.assertEqual(13, len(time_2)) + time_3 = os.path.basename(export_dir_3) + self.assertEqual(13, len(time_3)) + + self.assertTrue(int(time_1) < int(time_2)) + self.assertTrue(int(time_2) < int(time_3)) + + def test_garbage_collect_exports(self): + export_dir_base = tempfile.mkdtemp() + "export/" + tf.gfile.MkDir(export_dir_base) + export_dir_1 = _create_test_export_dir(export_dir_base) + export_dir_2 = _create_test_export_dir(export_dir_base) + export_dir_3 = _create_test_export_dir(export_dir_base) + export_dir_4 = _create_test_export_dir(export_dir_base) + + self.assertTrue(tf.gfile.Exists(export_dir_1)) + self.assertTrue(tf.gfile.Exists(export_dir_2)) + self.assertTrue(tf.gfile.Exists(export_dir_3)) + self.assertTrue(tf.gfile.Exists(export_dir_4)) + + # Garbage collect all but the most recent 2 exports, + # where recency is determined based on the timestamp directory names. 
+ saved_model_export_utils.garbage_collect_exports(export_dir_base, 2) + + self.assertFalse(tf.gfile.Exists(export_dir_1)) + self.assertFalse(tf.gfile.Exists(export_dir_2)) + self.assertTrue(tf.gfile.Exists(export_dir_3)) + self.assertTrue(tf.gfile.Exists(export_dir_4)) + + +def _create_test_export_dir(export_dir_base): + export_dir = saved_model_export_utils.get_timestamped_export_dir( + export_dir_base) + tf.gfile.MkDir(export_dir) + time.sleep(0.001) + return export_dir + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD index db5768acb8..1663a1f251 100644 --- a/tensorflow/python/saved_model/BUILD +++ b/tensorflow/python/saved_model/BUILD @@ -2,7 +2,10 @@ # TensorFlow SavedModel. package( - default_visibility = ["//tensorflow/python/saved_model:__subpackages__"], + default_visibility = [ + "//tensorflow/contrib/learn:__subpackages__", + "//tensorflow/python/saved_model:__subpackages__", + ], ) licenses(["notice"]) # Apache 2.0 @@ -33,7 +36,6 @@ py_library( srcs_version = "PY2AND3", deps = [ ":constants", - "//tensorflow:tensorflow_py", "//tensorflow/core:protos_all_py", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:platform", @@ -48,7 +50,6 @@ py_library( srcs_version = "PY2AND3", deps = [ ":constants", - "//tensorflow:tensorflow_py", "//tensorflow/core:protos_all_py", "//tensorflow/python:platform", "//tensorflow/python:training", @@ -94,7 +95,9 @@ py_library( name = "utils", srcs = ["utils.py"], srcs_version = "PY2AND3", - deps = ["//tensorflow/core:protos_all_py"], + deps = [ + "//tensorflow/core:protos_all_py", + ], ) py_test( @@ -111,6 +114,31 @@ py_test( ], ) +py_library( + name = "signature_def_utils", + srcs = ["signature_def_utils.py"], + srcs_version = "PY2AND3", + deps = [ + ":signature_constants", + ":utils", + "//tensorflow/core:protos_all_py", + ], +) + +py_test( + name = "signature_def_utils_test", + size = "small", + srcs = [ + 
"signature_def_utils_test.py", + ], + srcs_version = "PY2AND3", + visibility = ["//visibility:private"], + deps = [ + ":signature_def_utils", + "//tensorflow:tensorflow_py", + ], +) + # ----------------------------------------------------------------------------- # Google-internal targets. These must be at the end for syncrepo. diff --git a/tensorflow/python/saved_model/example/saved_model_half_plus_two.py b/tensorflow/python/saved_model/example/saved_model_half_plus_two.py index 0a05ff09ca..7927dfe632 100644 --- a/tensorflow/python/saved_model/example/saved_model_half_plus_two.py +++ b/tensorflow/python/saved_model/example/saved_model_half_plus_two.py @@ -36,8 +36,8 @@ from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.python.lib.io import file_io from tensorflow.python.saved_model import builder as saved_model_builder from tensorflow.python.saved_model import signature_constants +from tensorflow.python.saved_model import signature_def_utils from tensorflow.python.saved_model import tag_constants -from tensorflow.python.saved_model import utils from tensorflow.python.util import compat tf.app.flags.DEFINE_string("output_dir", "/tmp/saved_model_half_plus_two", @@ -121,7 +121,7 @@ def _generate_saved_model_for_half_plus_two(export_dir, as_text=False): output_tensor = meta_graph_pb2.TensorInfo() output_tensor.name = tf.identity(y).name signature_outputs = {signature_constants.REGRESS_OUTPUTS: output_tensor} - signature_def = utils.build_signature_def( + signature_def = signature_def_utils.build_signature_def( signature_inputs, signature_outputs, signature_constants.REGRESS_METHOD_NAME) diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py index 0f8ddfc65b..bf5b186b80 100644 --- a/tensorflow/python/saved_model/saved_model_test.py +++ b/tensorflow/python/saved_model/saved_model_test.py @@ -27,8 +27,8 @@ from tensorflow.python.lib.io import file_io from tensorflow.python.saved_model import 
builder as saved_model_builder from tensorflow.python.saved_model import constants from tensorflow.python.saved_model import loader +from tensorflow.python.saved_model import signature_def_utils from tensorflow.python.saved_model import tag_constants -from tensorflow.python.saved_model import utils from tensorflow.python.util import compat @@ -315,7 +315,8 @@ class SavedModelTest(tf.test.TestCase): with self.test_session(graph=tf.Graph()) as sess: self._init_and_validate_variable(sess, "v", 42) # Build and populate an empty SignatureDef for testing. - foo_signature = utils.build_signature_def(dict(), dict(), "foo") + foo_signature = signature_def_utils.build_signature_def( + dict(), dict(), "foo") builder.add_meta_graph_and_variables( sess, ["foo"], signature_def_map={"foo_key": foo_signature}) @@ -324,10 +325,12 @@ class SavedModelTest(tf.test.TestCase): with self.test_session(graph=tf.Graph()) as sess: self._init_and_validate_variable(sess, "v", 43) # Build and populate a different SignatureDef for testing. - bar_signature = utils.build_signature_def(dict(), dict(), "bar") + bar_signature = signature_def_utils.build_signature_def( + dict(), dict(), "bar") # Also, build a different SignatureDef corresponding to "foo_key" defined # in the previous graph. - foo_new_signature = utils.build_signature_def(dict(), dict(), "foo_new") + foo_new_signature = signature_def_utils.build_signature_def( + dict(), dict(), "foo_new") builder.add_meta_graph( ["bar"], signature_def_map={"bar_key": bar_signature, diff --git a/tensorflow/python/saved_model/signature_def_utils.py b/tensorflow/python/saved_model/signature_def_utils.py new file mode 100644 index 0000000000..23e844adb2 --- /dev/null +++ b/tensorflow/python/saved_model/signature_def_utils.py @@ -0,0 +1,158 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""SignatureDef utility functions. + +Utility functions for constructing SignatureDef protos. +""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.core.protobuf import meta_graph_pb2 +from tensorflow.python.saved_model import signature_constants +from tensorflow.python.saved_model import utils + + +def build_signature_def(inputs=None, outputs=None, method_name=None): + """Utility function to build a SignatureDef protocol buffer. + + Args: + inputs: Inputs of the SignatureDef defined as a proto map of string to + tensor info. + outputs: Outputs of the SignatureDef defined as a proto map of string to + tensor info. + method_name: Method name of the SignatureDef as a string. + + Returns: + A SignatureDef protocol buffer constructed based on the supplied arguments. + """ + signature_def = meta_graph_pb2.SignatureDef() + if inputs is not None: + for item in inputs: + signature_def.inputs[item].CopyFrom(inputs[item]) + if outputs is not None: + for item in outputs: + signature_def.outputs[item].CopyFrom(outputs[item]) + if method_name is not None: + signature_def.method_name = method_name + return signature_def + + +def regression_signature_def(examples, predictions): + """Creates regression signature from given examples and predictions. + + Args: + examples: `Tensor`. + predictions: `Tensor`. + + Returns: + A regression-flavored signature_def. + + Raises: + ValueError: If examples is `None`. 
+ """ + if examples is None: + raise ValueError('examples cannot be None for regression.') + if predictions is None: + raise ValueError('predictions cannot be None for regression.') + + input_tensor_info = utils.build_tensor_info(examples) + signature_inputs = {signature_constants.REGRESS_INPUTS: input_tensor_info} + + output_tensor_info = utils.build_tensor_info(predictions) + signature_outputs = {signature_constants.REGRESS_OUTPUTS: output_tensor_info} + signature_def = build_signature_def( + signature_inputs, signature_outputs, + signature_constants.REGRESS_METHOD_NAME) + + return signature_def + + +def classification_signature_def(examples, classes, scores): + """Creates classification signature from given examples and predictions. + + Args: + examples: `Tensor`. + classes: `Tensor`. + scores: `Tensor`. + + Returns: + A classification-flavored signature_def. + + Raises: + ValueError: If examples is `None`. + """ + if examples is None: + raise ValueError('examples cannot be None for classification.') + if classes is None and scores is None: + raise ValueError('classes and scores cannot both be None for ' + 'classification.') + + input_tensor_info = utils.build_tensor_info(examples) + signature_inputs = {signature_constants.CLASSIFY_INPUTS: input_tensor_info} + + signature_outputs = {} + if classes is not None: + classes_tensor_info = utils.build_tensor_info(classes) + signature_outputs[signature_constants.CLASSIFY_OUTPUT_CLASSES] = ( + classes_tensor_info) + if scores is not None: + scores_tensor_info = utils.build_tensor_info(scores) + signature_outputs[signature_constants.CLASSIFY_OUTPUT_SCORES] = ( + scores_tensor_info) + + signature_def = build_signature_def( + signature_inputs, signature_outputs, + signature_constants.CLASSIFY_METHOD_NAME) + + return signature_def + + +def predict_signature_def(inputs, outputs): + """Creates prediction signature from given inputs and outputs. + + Args: + inputs: dict of string to `Tensor`. 
+ outputs: dict of string to `Tensor`. + + Returns: + A prediction-flavored signature_def. + + Raises: + ValueError: If inputs or outputs is `None`. + """ + if inputs is None or not inputs: + raise ValueError('inputs cannot be None or empty for prediction.') + if outputs is None: + raise ValueError('outputs cannot be None or empty for prediction.') + + # If there's only one input or output, we can standardize keys + if len(inputs) == 1: + (_, value), = inputs.items() + inputs = {signature_constants.PREDICT_INPUTS: value} + if len(outputs) == 1: + (_, value), = outputs.items() + outputs = {signature_constants.PREDICT_OUTPUTS: value} + + signature_inputs = {key: utils.build_tensor_info(tensor) + for key, tensor in inputs.items()} + signature_outputs = {key: utils.build_tensor_info(tensor) + for key, tensor in outputs.items()} + + signature_def = build_signature_def( + signature_inputs, signature_outputs, + signature_constants.PREDICT_METHOD_NAME) + + return signature_def diff --git a/tensorflow/python/saved_model/signature_def_utils_test.py b/tensorflow/python/saved_model/signature_def_utils_test.py new file mode 100644 index 0000000000..6dfc4b2cd6 --- /dev/null +++ b/tensorflow/python/saved_model/signature_def_utils_test.py @@ -0,0 +1,156 @@ +# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for SignatureDef utils.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +from tensorflow.core.framework import types_pb2 +from tensorflow.python.saved_model import signature_constants +from tensorflow.python.saved_model import signature_def_utils +from tensorflow.python.saved_model import utils + + +class SignatureDefUtilsTest(tf.test.TestCase): + + def testBuildSignatureDef(self): + x = tf.placeholder(tf.float32, 1, name="x") + x_tensor_info = utils.build_tensor_info(x) + inputs = dict() + inputs["foo-input"] = x_tensor_info + + y = tf.placeholder(tf.float32, name="y") + y_tensor_info = utils.build_tensor_info(y) + outputs = dict() + outputs["foo-output"] = y_tensor_info + + signature_def = signature_def_utils.build_signature_def( + inputs, outputs, "foo-method-name") + self.assertEqual("foo-method-name", signature_def.method_name) + + # Check inputs in signature def. + self.assertEqual(1, len(signature_def.inputs)) + x_tensor_info_actual = signature_def.inputs["foo-input"] + self.assertEqual("x:0", x_tensor_info_actual.name) + self.assertEqual(types_pb2.DT_FLOAT, x_tensor_info_actual.dtype) + self.assertEqual(1, len(x_tensor_info_actual.tensor_shape.dim)) + self.assertEqual(1, x_tensor_info_actual.tensor_shape.dim[0].size) + + # Check outputs in signature def. 
+ self.assertEqual(1, len(signature_def.outputs)) + y_tensor_info_actual = signature_def.outputs["foo-output"] + self.assertEqual("y:0", y_tensor_info_actual.name) + self.assertEqual(types_pb2.DT_FLOAT, y_tensor_info_actual.dtype) + self.assertEqual(0, len(y_tensor_info_actual.tensor_shape.dim)) + + def testRegressionSignatureDef(self): + input1 = tf.constant("a", name="input-1") + output1 = tf.constant("b", name="output-1") + signature_def = signature_def_utils.regression_signature_def( + input1, output1) + + self.assertEqual(signature_constants.REGRESS_METHOD_NAME, + signature_def.method_name) + + # Check inputs in signature def. + self.assertEqual(1, len(signature_def.inputs)) + x_tensor_info_actual = ( + signature_def.inputs[signature_constants.REGRESS_INPUTS]) + self.assertEqual("input-1:0", x_tensor_info_actual.name) + self.assertEqual(types_pb2.DT_STRING, x_tensor_info_actual.dtype) + self.assertEqual(0, len(x_tensor_info_actual.tensor_shape.dim)) + + # Check outputs in signature def. + self.assertEqual(1, len(signature_def.outputs)) + y_tensor_info_actual = ( + signature_def.outputs[signature_constants.REGRESS_OUTPUTS]) + self.assertEqual("output-1:0", y_tensor_info_actual.name) + self.assertEqual(types_pb2.DT_STRING, y_tensor_info_actual.dtype) + self.assertEqual(0, len(y_tensor_info_actual.tensor_shape.dim)) + + def testClassificationSignatureDef(self): + input1 = tf.constant("a", name="input-1") + output1 = tf.constant("b", name="output-1") + output2 = tf.constant("c", name="output-2") + signature_def = signature_def_utils.classification_signature_def( + input1, output1, output2) + + self.assertEqual(signature_constants.CLASSIFY_METHOD_NAME, + signature_def.method_name) + + # Check inputs in signature def. 
+ self.assertEqual(1, len(signature_def.inputs)) + x_tensor_info_actual = ( + signature_def.inputs[signature_constants.CLASSIFY_INPUTS]) + self.assertEqual("input-1:0", x_tensor_info_actual.name) + self.assertEqual(types_pb2.DT_STRING, x_tensor_info_actual.dtype) + self.assertEqual(0, len(x_tensor_info_actual.tensor_shape.dim)) + + # Check outputs in signature def. + self.assertEqual(2, len(signature_def.outputs)) + classes_tensor_info_actual = ( + signature_def.outputs[signature_constants.CLASSIFY_OUTPUT_CLASSES]) + self.assertEqual("output-1:0", classes_tensor_info_actual.name) + self.assertEqual(types_pb2.DT_STRING, classes_tensor_info_actual.dtype) + self.assertEqual(0, len(classes_tensor_info_actual.tensor_shape.dim)) + scores_tensor_info_actual = ( + signature_def.outputs[signature_constants.CLASSIFY_OUTPUT_SCORES]) + self.assertEqual("output-2:0", scores_tensor_info_actual.name) + self.assertEqual(types_pb2.DT_STRING, scores_tensor_info_actual.dtype) + self.assertEqual(0, len(scores_tensor_info_actual.tensor_shape.dim)) + + def testPredictionSignatureDef(self): + input1 = tf.constant("a", name="input-1") + input2 = tf.constant("b", name="input-2") + output1 = tf.constant("c", name="output-1") + output2 = tf.constant("d", name="output-2") + signature_def = signature_def_utils.predict_signature_def( + {"input-1": input1, "input-2": input2}, + {"output-1": output1, "output-2": output2}) + + self.assertEqual(signature_constants.PREDICT_METHOD_NAME, + signature_def.method_name) + + # Check inputs in signature def. 
+ self.assertEqual(2, len(signature_def.inputs)) + input1_tensor_info_actual = ( + signature_def.inputs["input-1"]) + self.assertEqual("input-1:0", input1_tensor_info_actual.name) + self.assertEqual(types_pb2.DT_STRING, input1_tensor_info_actual.dtype) + self.assertEqual(0, len(input1_tensor_info_actual.tensor_shape.dim)) + input2_tensor_info_actual = ( + signature_def.inputs["input-2"]) + self.assertEqual("input-2:0", input2_tensor_info_actual.name) + self.assertEqual(types_pb2.DT_STRING, input2_tensor_info_actual.dtype) + self.assertEqual(0, len(input2_tensor_info_actual.tensor_shape.dim)) + + # Check outputs in signature def. + self.assertEqual(2, len(signature_def.outputs)) + output1_tensor_info_actual = ( + signature_def.outputs["output-1"]) + self.assertEqual("output-1:0", output1_tensor_info_actual.name) + self.assertEqual(types_pb2.DT_STRING, output1_tensor_info_actual.dtype) + self.assertEqual(0, len(output1_tensor_info_actual.tensor_shape.dim)) + output2_tensor_info_actual = ( + signature_def.outputs["output-2"]) + self.assertEqual("output-2:0", output2_tensor_info_actual.name) + self.assertEqual(types_pb2.DT_STRING, output2_tensor_info_actual.dtype) + self.assertEqual(0, len(output2_tensor_info_actual.tensor_shape.dim)) + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow/python/saved_model/utils.py b/tensorflow/python/saved_model/utils.py index 550eed0fcc..ecc58fbc7a 100644 --- a/tensorflow/python/saved_model/utils.py +++ b/tensorflow/python/saved_model/utils.py @@ -23,6 +23,7 @@ from __future__ import print_function from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.python.framework import dtypes + # TensorInfo helpers. @@ -40,30 +41,3 @@ def build_tensor_info(tensor): name=tensor.name, dtype=dtype_enum, tensor_shape=tensor.get_shape().as_proto()) - -# SignatureDef helpers. - - -def build_signature_def(inputs=None, outputs=None, method_name=None): - """Utility function to build a SignatureDef protocol buffer. 
- - Args: - inputs: Inputs of the SignatureDef defined as a proto map of string to - tensor info. - outputs: Outputs of the SignatureDef defined as a proto map of string to - tensor info. - method_name: Method name of the SignatureDef as a string. - - Returns: - A SignatureDef protocol buffer constructed based on the supplied arguments. - """ - signature_def = meta_graph_pb2.SignatureDef() - if inputs is not None: - for item in inputs: - signature_def.inputs[item].CopyFrom(inputs[item]) - if outputs is not None: - for item in outputs: - signature_def.outputs[item].CopyFrom(outputs[item]) - if method_name is not None: - signature_def.method_name = method_name - return signature_def diff --git a/tensorflow/python/saved_model/utils_test.py b/tensorflow/python/saved_model/utils_test.py index 8ce7d1dea1..74f2624773 100644 --- a/tensorflow/python/saved_model/utils_test.py +++ b/tensorflow/python/saved_model/utils_test.py @@ -34,36 +34,6 @@ class UtilsTest(tf.test.TestCase): self.assertEqual(1, len(x_tensor_info.tensor_shape.dim)) self.assertEqual(1, x_tensor_info.tensor_shape.dim[0].size) - def testBuildSignatureDef(self): - x = tf.placeholder(tf.float32, 1, name="x") - x_tensor_info = utils.build_tensor_info(x) - inputs = dict() - inputs["foo-input"] = x_tensor_info - - y = tf.placeholder(tf.float32, name="y") - y_tensor_info = utils.build_tensor_info(y) - outputs = dict() - outputs["foo-output"] = y_tensor_info - - signature_def = utils.build_signature_def(inputs, outputs, - "foo-method-name") - self.assertEqual("foo-method-name", signature_def.method_name) - - # Check inputs in signature def. 
- self.assertEqual(1, len(signature_def.inputs)) - x_tensor_info_actual = signature_def.inputs["foo-input"] - self.assertEqual("x:0", x_tensor_info_actual.name) - self.assertEqual(types_pb2.DT_FLOAT, x_tensor_info_actual.dtype) - self.assertEqual(1, len(x_tensor_info_actual.tensor_shape.dim)) - self.assertEqual(1, x_tensor_info_actual.tensor_shape.dim[0].size) - - # Check outputs in signature def. - self.assertEqual(1, len(signature_def.outputs)) - y_tensor_info_actual = signature_def.outputs["foo-output"] - self.assertEqual("y:0", y_tensor_info_actual.name) - self.assertEqual(types_pb2.DT_FLOAT, y_tensor_info_actual.dtype) - self.assertEqual(0, len(y_tensor_info_actual.tensor_shape.dim)) - if __name__ == "__main__": tf.test.main() |