aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/BUILD2
-rw-r--r--tensorflow/contrib/session_bundle/BUILD166
-rw-r--r--tensorflow/contrib/session_bundle/README.md243
-rw-r--r--tensorflow/contrib/session_bundle/example/BUILD52
-rw-r--r--tensorflow/contrib/session_bundle/example/export_half_plus_two.py115
-rw-r--r--tensorflow/contrib/session_bundle/exporter.py311
-rw-r--r--tensorflow/contrib/session_bundle/exporter_test.py218
-rw-r--r--tensorflow/contrib/session_bundle/gc.py204
-rw-r--r--tensorflow/contrib/session_bundle/gc_test.py115
-rw-r--r--tensorflow/contrib/session_bundle/manifest.proto70
-rw-r--r--tensorflow/contrib/session_bundle/session_bundle.cc181
-rw-r--r--tensorflow/contrib/session_bundle/session_bundle.h59
-rw-r--r--tensorflow/contrib/session_bundle/session_bundle_test.cc102
-rw-r--r--tensorflow/contrib/session_bundle/signature.cc270
-rw-r--r--tensorflow/contrib/session_bundle/signature.h123
-rw-r--r--tensorflow/contrib/session_bundle/signature_test.cc602
-rw-r--r--tensorflow/contrib/session_bundle/test_util.cc35
-rw-r--r--tensorflow/contrib/session_bundle/test_util.h38
-rw-r--r--tensorflow/tools/pip_package/BUILD3
-rw-r--r--tools/bazel.rc.template6
20 files changed, 2915 insertions, 0 deletions
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index a8153c06b6..9e5920efc4 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -84,6 +84,8 @@ filegroup(
"//tensorflow/contrib/quantization:all_files",
"//tensorflow/contrib/quantization/kernels:all_files",
"//tensorflow/contrib/quantization/tools:all_files",
+ "//tensorflow/contrib/session_bundle:all_files",
+ "//tensorflow/contrib/session_bundle/example:all_files",
"//tensorflow/contrib/skflow:all_files",
"//tensorflow/contrib/slim:all_files",
"//tensorflow/contrib/tensor_forest:all_files",
diff --git a/tensorflow/contrib/session_bundle/BUILD b/tensorflow/contrib/session_bundle/BUILD
new file mode 100644
index 0000000000..19a10b0fc0
--- /dev/null
+++ b/tensorflow/contrib/session_bundle/BUILD
@@ -0,0 +1,166 @@
+# Description: Tensorflow Serving session bundle.
+
+package(
+ default_visibility = ["//visibility:public"],
+ features = [
+ "-layering_check",
+ ],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ "g3doc/sitemap.md",
+ ],
+ ),
+)
+
+load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library")
+
+py_library(
+ name = "exporter",
+ srcs = ["exporter.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":gc",
+ ":manifest_proto_py",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_test(
+ name = "exporter_test",
+ size = "small",
+ srcs = [
+ "exporter_test.py",
+ ],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:private"],
+ deps = [
+ ":exporter",
+ ":gc",
+ ":manifest_proto_py",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_library(
+ name = "gc",
+ srcs = ["gc.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+py_test(
+ name = "gc_test",
+ srcs = [
+ "gc_test.py",
+ ],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:private"],
+ deps = [
+ ":gc",
+ "//tensorflow:tensorflow_py",
+ ],
+)
+
+cc_library(
+ name = "session_bundle",
+ srcs = ["session_bundle.cc"],
+ hdrs = ["session_bundle.h"],
+ deps = [
+ ":manifest_proto_cc",
+ ":signature",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:tensorflow_opensource",
+ ],
+)
+
+cc_test(
+ name = "session_bundle_test",
+ size = "small",
+ srcs = ["session_bundle_test.cc"],
+ data = [
+ "//tensorflow/contrib/session_bundle/example:half_plus_two",
+ ],
+ # Link in all registered kernels.
+ linkstatic = 1,
+ visibility = ["//visibility:private"],
+ deps = [
+ ":session_bundle",
+ ":test_util",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ],
+)
+
+cc_library(
+ name = "signature",
+ srcs = ["signature.cc"],
+ hdrs = ["signature.h"],
+ deps = [
+ ":manifest_proto_cc",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:tensorflow_opensource",
+ ],
+)
+
+cc_test(
+ name = "signature_test",
+ size = "small",
+ srcs = ["signature_test.cc"],
+ visibility = ["//visibility:private"],
+ deps = [
+ ":manifest_proto_cc",
+ ":signature",
+ ":test_util",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
+ "//tensorflow/core:tensorflow_opensource",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ "//tensorflow/core:testlib",
+ ],
+)
+
+cc_library(
+ name = "test_util",
+ testonly = 1,
+ srcs = ["test_util.cc"],
+ hdrs = ["test_util.h"],
+ visibility = ["//visibility:private"],
+ deps = [
+ "//tensorflow/core",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ ],
+)
+
+tf_proto_library(
+ name = "manifest_proto",
+ srcs = ["manifest.proto"],
+ cc_api_version = 2,
+ py_api_version = 2,
+ visibility = ["//visibility:public"],
+)
diff --git a/tensorflow/contrib/session_bundle/README.md b/tensorflow/contrib/session_bundle/README.md
new file mode 100644
index 0000000000..d9b6c0f9a2
--- /dev/null
+++ b/tensorflow/contrib/session_bundle/README.md
@@ -0,0 +1,243 @@
+# TensorFlow Inference Model Format
+
+[TOC]
+
+## Overview
+
+This document describes the data formats and layouts for exporting [TensorFlow]
+(https://www.tensorflow.org/) models for inference.
+
+These exports have the following properties,
+
+* Recoverable
+ * given an export the graph can easily be initialized and run
+* Hermetic
+ * an export directory is self-contained to facilitate distribution
+
+## Directory Structure
+
+~~~
+# Directory overview
+00000000/
+ assets/
+ export.meta
+ export-?????-of-?????
+~~~
+
+* `00000000` Export version
+ * Format `%08d`
+* `assets` Asset file directory
+ * Holds auxiliary files for the graph (e.g., vocabularies)
+* `export.meta` MetaGraph Definition
+ * Binary [`tensorflow::MetaGraphDef`]
+ (https://github.com/tensorflow/tensorflow/tree/master/tensorflow/core/protobuf/meta_graph.proto)
+* `export-?????-of-?????`
+ * Graph Variables
+ * Outputs from Python [`Saver`]
+ (https://github.com/tensorflow/tensorflow/tree/master/tensorflow/python/training/saver.py)
+ with `sharded=True`.
+
+## Python exporting code
+
+The [`Exporter`](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/session_bundle/exporter.py)
+class can be used to export a model in the above format from a Tensorflow python
+binary.
+
+## C++ initialization code
+
+The [`LoadSessionBundleFromPath`]
+(https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/session_bundle/session_bundle.h)
+function can be used to create a `tensorflow::Session` and initialize it from an
+export. This function takes options and the path to the export and returns a
+bundle of export data including a `tensorflow::Session` which can be run.
+
+## Signatures
+
+Graphs used for inference tasks typically have set of inputs and outputs used
+at inference time. We call this a signature.
+
+### Standard Signatures (standard usage)
+
+Graphs used for standard inference tasks have standard set of inputs and
+outputs. For example, a graph used for a regression task has an input tensor for
+the data and an output tensor for the regression values. The signature mechanism
+makes it easy to identify the relevant input and output tensors for common graph
+applications.
+
+The [`Manifest`](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/session_bundle/manifest.proto)
+contains a `Signature` message which contains the task specific inputs and
+outputs.
+
+~~~
+// A Signature specifies the inputs and outputs of commonly used graphs.
+message Signature {
+ oneof type {
+ RegressionSignature regression_signature = 1;
+ ClassificationSignature classification_signature = 2;
+ GenericSignature generic_signature = 3;
+ }
+};
+~~~
+
+Standard signature can be set at export time using the `Exporter` API
+
+~~~python
+# Run an export.
+signature = exporter.classification_signature(input_tensor=input,
+ classes_tensor=output)
+export = exporter.Exporter(saver)
+export.init(sess.graph.as_graph_def(),
+ default_graph_signature=signature)
+export.export(export_path,
+ global_step_tensor,
+ sess)
+~~~
+
+These can be recovered at serving time using utilities in [`signature.h`]
+(https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/session_bundle/signature.h)
+
+~~~c++
+// Get the a classification signature.
+ClassificationSignature signature;
+TF_CHECK_OK(GetClassificationSignature(bundle->meta_graph_def, &signature));
+
+// Run the graph.
+Tensor input_tensor = GetInputTensor();
+Tensor classes_tensor;
+Tensor scores_tensor;
+TF_CHECK_OK(RunClassification(signature, input_tensor, session,
+ &classes_tensor, &scores_tensor));
+~~~
+
+### Generic Signatures (custom or advanced usage)
+
+Generic Signatures enable fully custom usage of the `tensorflow::Session` API.
+They are recommended for when the standard Signatures do not satisfy a
+particular use-case. A general example of when to use these is for a model
+taking a single input and generating multiple outputs performing different
+inferences.
+
+~~~
+// GenericSignature specifies a map from logical name to Tensor name.
+// Typical application of GenericSignature is to use a single GenericSignature
+// that includes all of the Tensor nodes and target names that may be useful at
+// serving, analysis or debugging time. The recommended name for this signature
+// is "generic_bindings".
+message GenericSignature {
+ map<string, TensorBinding> map = 1;
+};
+~~~
+
+Generic Signatures can be used to compliment a standard signature, for example
+to support debugging. Here is an example usage including both the standard
+regression signature and a generic signature.
+
+~~~python
+named_tensor_bindings = {"logical_input_A": v0,
+ "logical_input_B": v1}
+signatures = {
+ "regression": exporter.regression_signature(input_tensor=v0,
+ output_tensor=v1),
+ "generic": exporter.generic_signature(named_tensor_bindings)}
+export = exporter.Exporter(saver)
+export.init(sess.graph.as_graph_def(),
+ named_graph_signature=signatures)
+export.export(export_path,
+ global_step_tensor,
+ sess)
+~~~
+
+Generic Signature does not differentiate between input and output tensors. It
+provides full flexibility to specify the input & output tensors you need.
+The benefit is preserving a mapping between names that you specify at export
+time (we call these the logical names), and the actual graph node names that may
+be less stable and/or auto-generated by TensorFlow.
+
+In [`signature.h`](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/session_bundle/signature.h),
+note that the generic signature methods BindGenericInputs and BindGenericNames
+are doing simple string to string mapping as a convenience. These methods map
+from the names used at training time to actual names in the graph. Use the bound
+results from those methods, e.g. `vector<pair<string, Tensor>>` and
+`vector<string>` respectively, as inputs to`tensorflow::Session->Run()`. For
+`Session->Run()`, map these into the first two parameters, `inputs` and
+`output_tensor_names` respectively. The next param, `target_node_names` is
+typically null at inference time. The last param outputs is for the results in
+the same order of your `output_tensor_names`.
+
+## Initialization
+
+Some graphs many require custom initialization after the variables have been
+restored. Such initialization, done through an arbitrary Op, can be added using
+the `Exporter` API. If set, `LoadSessionBundleFromPath` will automatically run
+the Op when restoring a `Session` following the loading of variables.
+
+## Assets
+
+In many cases we have Ops which depend on external files for initialization
+(such as vocabularies). These "assets" are not stored in the graph and are
+needed for both training and inference.
+
+In order to create hermetic exports these asset files need to be 1) copied to
+each export directory and 2) read when recovering a session from an export base
+directory.
+
+Copying assets to the export dir is handled with a callback mechanism.
+The callback function receives two parameters 1) the dictionary of source files
+to desired basename and 2) the export directory. The default callback uses
+`gfile.Copy` to perform the copy.
+
+The tensors that contains the filepath to be copied and be replaced for
+inference in specified by passing the collection of asset filepath tensor,
+which is usually extracted from the graph by `tf.get_collection(tf.GraphKeys.ASSET_FILEPATHS)`.
+
+~~~python
+ # Run an export.
+ export = exporter.Exporter(save)
+ export.init(
+ sess.graph.as_graph_def(),
+ asset_collection=tf.get_collection(tf.GraphKeys.ASSET_FILEPATHS))
+ export.export(export_path, global_step_tensor, sess)
+~~~
+
+Users can use their own callbacks as shown in the following example, with the
+requirement to keep the basename of the original files:
+
+~~~python
+ def my_custom_copy_callback(files_to_copy, export_dir_path):
+ # Copy all source files (keys) in files_to_copy to export_dir_path
+ # using the corresponging basename (value).
+ ...
+
+
+ # Run an export.
+ export = exporter.Exporter(save)
+ export.init(
+ sess.graph.as_graph_def(),
+ asset_collection=tf.get_collection(tf.GraphKeys.ASSET_FILEPATHS),
+ asset_callback=my_custom_copy_callback)
+ export.export(export_path, global_step_tensor, sess)
+~~~
+
+
+`AssetFile` binds the name of a tensor in the graph to the name of a file
+within the assets directory. `LoadSessionBundleFromPath` will handle the base
+path and asset directory swap/concatenation such that the tensor is set with
+the fully qualified filename upon return.
+
+# Notes of exporter usage
+
+The typical workflow of model exporting is:
+
+1. Build model graph G
+2. Train variables or load trained variables from checkpoint in session S
+3. [Optional] build inference graph I
+4. Export G
+
+The Exporter should be used as follows:
+
+1. The Saver used in Exporter(saver) should be created under the context of G
+2. Exporter.init() should be called under the context of G
+3. Exporter.export() should be called using session S
+4. If I is provided for Exporter.init(), an exact same Saver should be created
+ under I as the saver under G -- in the way that exact same Save/Restore ops
+ are created in both G and S
diff --git a/tensorflow/contrib/session_bundle/example/BUILD b/tensorflow/contrib/session_bundle/example/BUILD
new file mode 100644
index 0000000000..8fa0c0b020
--- /dev/null
+++ b/tensorflow/contrib/session_bundle/example/BUILD
@@ -0,0 +1,52 @@
+# Description: Tensorflow Serving session_bundle example.
+
+package(
+ default_visibility = ["//tensorflow/contrib/session_bundle:__subpackages__"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+# vardef("PYTHON_BIN_PATH", "/usr/bin/python")
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ "g3doc/sitemap.md",
+ ],
+ ),
+ visibility = ["//visibility:public"],
+)
+
+py_binary(
+ name = "export_half_plus_two",
+ srcs = [
+ "export_half_plus_two.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/contrib/session_bundle:exporter",
+ ],
+)
+
+genrule(
+ name = "half_plus_two",
+ outs = [
+ "half_plus_two/00000123/export.meta",
+ "half_plus_two/00000123/export-00000-of-00001",
+ ],
+ cmd =
+ "rm -rf /tmp/half_plus_two; " +
+ "$(PYTHON_BIN_PATH) $(locations :export_half_plus_two); " +
+ "cp -r /tmp/half_plus_two/* $(@D)/half_plus_two",
+ tools = [
+ ":export_half_plus_two",
+ ],
+ visibility = ["//visibility:public"],
+)
diff --git a/tensorflow/contrib/session_bundle/example/export_half_plus_two.py b/tensorflow/contrib/session_bundle/example/export_half_plus_two.py
new file mode 100644
index 0000000000..e4b1947e03
--- /dev/null
+++ b/tensorflow/contrib/session_bundle/example/export_half_plus_two.py
@@ -0,0 +1,115 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Exports a toy linear regression inference graph.
+
+Exports a TensorFlow graph to /tmp/half_plus_two/ based on the Exporter
+format, go/tf-exporter.
+
+This graph calculates,
+ y = a*x + b
+where a and b are variables with a=0.5 and b=2.
+
+Output from this program is typically used to exercise Session
+loading and execution code.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+from tensorflow.contrib.session_bundle import exporter
+
+
+def Export():
+ export_path = "/tmp/half_plus_two"
+ with tf.Session() as sess:
+ # Make model parameters a&b variables instead of constants to
+ # exercise the variable reloading mechanisms.
+ a = tf.Variable(0.5, name="a")
+ b = tf.Variable(2.0, name="b")
+
+ # Calculate, y = a*x + b
+ # here we use a placeholder 'x' which is fed at inference time.
+ x = tf.placeholder(tf.float32, name="x")
+ y = tf.add(tf.mul(a, x), b, name="y")
+
+ # Setup a standard Saver for our variables.
+ save = tf.train.Saver({"a": a, "b": b}, sharded=True)
+
+ # asset_path contains the base directory of assets used in training (e.g.
+ # vocabulary files).
+ original_asset_path = tf.constant("/tmp/original/export/assets")
+ # Ops reading asset files should reference the asset_path tensor
+ # which stores the original asset path at training time and the
+ # overridden assets directory at restore time.
+ asset_path = tf.Variable(original_asset_path,
+ name="asset_path",
+ trainable=False,
+ collections=[])
+ assign_asset_path = asset_path.assign(original_asset_path)
+
+ # Use a fixed global step number.
+ global_step_tensor = tf.Variable(123, name="global_step")
+
+ # Create a RegressionSignature for our input and output.
+ signature = exporter.regression_signature(input_tensor=x, output_tensor=y)
+
+ # Create two filename assets and corresponding tensors.
+ # TODO(b/26254158) Consider adding validation of file existance as well as
+ # hashes (e.g. sha1) for consistency.
+ original_filename1 = tf.constant("hello1.txt")
+ tf.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, original_filename1)
+ filename1 = tf.Variable(original_filename1,
+ name="filename1",
+ trainable=False,
+ collections=[])
+ assign_filename1 = filename1.assign(original_filename1)
+ original_filename2 = tf.constant("hello2.txt")
+ tf.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, original_filename2)
+ filename2 = tf.Variable(original_filename2,
+ name="filename2",
+ trainable=False,
+ collections=[])
+ assign_filename2 = filename2.assign(original_filename2)
+
+ # Init op contains a group of all variables that we assign.
+ init_op = tf.group(assign_asset_path, assign_filename1, assign_filename2)
+
+ # CopyAssets is used as a callback during export to copy files to the
+ # given export directory.
+ def CopyAssets(filepaths, export_path):
+ print("copying asset files to: %s" % export_path)
+ for filepath in filepaths:
+ print("copying asset file: %s" % filepath)
+
+ # Run an export.
+ tf.initialize_all_variables().run()
+ export = exporter.Exporter(save)
+ export.init(
+ sess.graph.as_graph_def(),
+ init_op=init_op,
+ default_graph_signature=signature,
+ assets_collection=tf.get_collection(tf.GraphKeys.ASSET_FILEPATHS),
+ assets_callback=CopyAssets)
+ export.export(export_path, global_step_tensor, sess)
+
+
+def main(_):
+ Export()
+
+
+if __name__ == "__main__":
+ tf.app.run()
diff --git a/tensorflow/contrib/session_bundle/exporter.py b/tensorflow/contrib/session_bundle/exporter.py
new file mode 100644
index 0000000000..fc7458db62
--- /dev/null
+++ b/tensorflow/contrib/session_bundle/exporter.py
@@ -0,0 +1,311 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Export a TensorFlow model.
+
+See: go/tf-exporter
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import re
+import six
+
+import tensorflow as tf
+from google.protobuf.any_pb2 import Any
+
+from tensorflow.contrib.session_bundle import gc
+from tensorflow.contrib.session_bundle import manifest_pb2
+from tensorflow.python.framework import ops
+from tensorflow.python.platform import gfile
+from tensorflow.python.training import training_util
+from tensorflow.python.util import compat
+
+# See: go/tf-exporter for these constants and directory structure.
+VERSION_FORMAT_SPECIFIER = "%08d"
+ASSETS_DIRECTORY = "assets"
+EXPORT_BASE_NAME = "export"
+EXPORT_SUFFIX_NAME = "meta"
+META_GRAPH_DEF_FILENAME = EXPORT_BASE_NAME + "." + EXPORT_SUFFIX_NAME
+VARIABLES_FILENAME = EXPORT_BASE_NAME
+VARIABLES_FILENAME_PATTERN = VARIABLES_FILENAME + "-?????-of-?????"
+INIT_OP_KEY = "serving_init_op"
+SIGNATURES_KEY = "serving_signatures"
+ASSETS_KEY = "serving_assets"
+GRAPH_KEY = "serving_graph"
+
+
+def gfile_copy_callback(files_to_copy, export_dir_path):
+ """Callback to copy files using `gfile.Copy` to an export directory.
+
+ This method is used as the default `assets_callback` in `Exporter.init` to
+ copy assets from the `assets_collection`. It can also be invoked directly to
+ copy additional supplementary files into the export directory (in which case
+ it is not a callback).
+
+ Args:
+ files_to_copy: A dictionary that maps original file paths to desired
+ basename in the export directory.
+ export_dir_path: Directory to copy the files to.
+ """
+ tf.logging.info("Write assest into: %s using gfile_copy.", export_dir_path)
+ gfile.MakeDirs(export_dir_path)
+ for source_filepath, basename in files_to_copy.items():
+ new_path = os.path.join(
+ compat.as_bytes(export_dir_path), compat.as_bytes(basename))
+ tf.logging.info("Copying asset %s to path %s.", source_filepath, new_path)
+
+ if gfile.Exists(new_path):
+ # Guard against being restarted while copying assets, and the file
+ # existing and being in an unknown state.
+ # TODO(b/28676216): Do some file checks before deleting.
+ tf.logging.info("Removing file %s.", new_path)
+ gfile.Remove(new_path)
+ tf.gfile.Copy(source_filepath, new_path)
+
+
+def regression_signature(input_tensor, output_tensor):
+ """Creates a regression signature.
+
+ Args:
+ input_tensor: Tensor specifying the input to a graph.
+ output_tensor: Tensor specifying the output of a graph.
+
+ Returns:
+ A Signature message.
+ """
+ signature = manifest_pb2.Signature()
+ signature.regression_signature.input.tensor_name = input_tensor.name
+ signature.regression_signature.output.tensor_name = output_tensor.name
+ return signature
+
+
+def classification_signature(input_tensor,
+ classes_tensor=None,
+ scores_tensor=None):
+ """Creates a classification signature.
+
+ Args:
+ input_tensor: Tensor specifying the input to a graph.
+ classes_tensor: Tensor specifying the output classes of a graph.
+ scores_tensor: Tensor specifying the scores of the output classes.
+
+ Returns:
+ A Signature message.
+ """
+ signature = manifest_pb2.Signature()
+ signature.classification_signature.input.tensor_name = input_tensor.name
+ if classes_tensor is not None:
+ signature.classification_signature.classes.tensor_name = classes_tensor.name
+ if scores_tensor is not None:
+ signature.classification_signature.scores.tensor_name = scores_tensor.name
+ return signature
+
+
+def generic_signature(name_tensor_map):
+ """Creates a generic signature of name to Tensor name.
+
+ Args:
+ name_tensor_map: Map from logical name to Tensor.
+
+ Returns:
+ A Signature message.
+ """
+ signature = manifest_pb2.Signature()
+ for name, tensor in six.iteritems(name_tensor_map):
+ signature.generic_signature.map[name].tensor_name = tensor.name
+ return signature
+
+
+class Exporter(object):
+ """Exporter helps package a TensorFlow model for serving.
+
+ Args:
+ saver: Saver object.
+ """
+
+ def __init__(self, saver):
+ self._saver = saver
+ self._has_init = False
+ self._assets_to_copy = {}
+
+ def init(self,
+ graph_def=None,
+ init_op=None,
+ clear_devices=False,
+ default_graph_signature=None,
+ named_graph_signatures=None,
+ assets_collection=None,
+ assets_callback=gfile_copy_callback):
+ """Initialization.
+
+ Args:
+ graph_def: A GraphDef message of the graph to be used in inference.
+ GraphDef of default graph is used when None.
+ init_op: Op to be used in initialization.
+ clear_devices: If device info of the graph should be cleared upon export.
+ default_graph_signature: Default signature of the graph.
+ named_graph_signatures: Map of named input/output signatures of the graph.
+ assets_collection: A collection of constant asset filepath tensors. If set
+ the assets will be exported into the asset directory.
+ assets_callback: callback with two argument called during export with the
+ list of files to copy and the asset path.
+ Raises:
+ RuntimeError: if init is called more than once.
+ TypeError: if init_op is not an Operation or None.
+ ValueError: if asset file path tensors are not non-empty constant string
+ scalar tensors.
+ """
+ # Avoid Dangerous default value []
+ if named_graph_signatures is None:
+ named_graph_signatures = {}
+ assets = []
+ if assets_collection:
+ for asset_tensor in assets_collection:
+ asset_filepath = self._file_path_value(asset_tensor)
+ if not asset_filepath:
+ raise ValueError("invalid asset filepath tensor %s" % asset_tensor)
+ basename = os.path.basename(asset_filepath)
+ assets.append((basename, asset_tensor))
+ self._assets_to_copy[asset_filepath] = basename
+
+ if self._has_init:
+ raise RuntimeError("init should be called only once")
+ self._has_init = True
+
+ if graph_def or clear_devices:
+ copy = tf.GraphDef()
+ if graph_def:
+ copy.CopyFrom(graph_def)
+ else:
+ copy.CopyFrom(tf.get_default_graph().as_graph_def())
+ if clear_devices:
+ for node in copy.node:
+ node.device = ""
+ graph_any_buf = Any()
+ graph_any_buf.Pack(copy)
+ tf.add_to_collection(GRAPH_KEY, graph_any_buf)
+
+ if init_op:
+ if not isinstance(init_op, ops.Operation):
+ raise TypeError("init_op needs to be an Operation: %s" % init_op)
+ tf.add_to_collection(INIT_OP_KEY, init_op)
+
+ signatures_proto = manifest_pb2.Signatures()
+ if default_graph_signature:
+ signatures_proto.default_signature.CopyFrom(default_graph_signature)
+ for signature_name, signature in six.iteritems(named_graph_signatures):
+ signatures_proto.named_signatures[signature_name].CopyFrom(signature)
+ signatures_any_buf = Any()
+ signatures_any_buf.Pack(signatures_proto)
+ tf.add_to_collection(SIGNATURES_KEY, signatures_any_buf)
+
+ for filename, tensor in assets:
+ asset = manifest_pb2.AssetFile()
+ asset.filename = filename
+ asset.tensor_binding.tensor_name = tensor.name
+ asset_any_buf = Any()
+ asset_any_buf.Pack(asset)
+ tf.add_to_collection(ASSETS_KEY, asset_any_buf)
+
+ self._assets_callback = assets_callback
+
+ def export(self,
+ export_dir_base,
+ global_step_tensor,
+ sess=None,
+ exports_to_keep=None):
+ """Exports the model.
+
+ Args:
+ export_dir_base: A string path to the base export dir.
+ global_step_tensor: An Tensor or tensor name providing the
+ global step counter to append to the export directory path and set
+ in the manifest version.
+ sess: A Session to use to save the parameters.
+ exports_to_keep: a gc.Path filter function used to determine the set of
+ exports to keep. If set to None, all versions will be kept.
+
+ Returns:
+ The string path to the exported directory.
+
+ Raises:
+ RuntimeError: if init is not called.
+ RuntimeError: if the export would overwrite an existing directory.
+ """
+ if not self._has_init:
+ raise RuntimeError("init must be called first")
+
+ global_step = training_util.global_step(sess, global_step_tensor)
+ export_dir = os.path.join(
+ compat.as_bytes(export_dir_base),
+ compat.as_bytes(VERSION_FORMAT_SPECIFIER % global_step))
+
+ # Prevent overwriting on existing exports which could lead to bad/corrupt
+ # storage and loading of models. This is an important check that must be
+ # done before any output files or directories are created.
+ if gfile.Exists(export_dir):
+ raise RuntimeError("Overwriting exports can cause corruption and are "
+ "not allowed. Duplicate export dir: %s" % export_dir)
+
+ # Output to a temporary directory which is atomically renamed to the final
+ # directory when complete.
+ tmp_export_dir = compat.as_text(export_dir) + "-tmp"
+ gfile.MakeDirs(tmp_export_dir)
+
+ self._saver.save(sess,
+ os.path.join(
+ compat.as_text(tmp_export_dir),
+ compat.as_text(EXPORT_BASE_NAME)),
+ meta_graph_suffix=EXPORT_SUFFIX_NAME)
+
+ # Run the asset callback.
+ if self._assets_callback and self._assets_to_copy:
+ assets_dir = os.path.join(
+ compat.as_bytes(tmp_export_dir), compat.as_bytes(ASSETS_DIRECTORY))
+ gfile.MakeDirs(assets_dir)
+ self._assets_callback(self._assets_to_copy, assets_dir)
+
+ # TODO(b/27794910): Delete *checkpoint* file before rename.
+ gfile.Rename(tmp_export_dir, export_dir)
+
+ if exports_to_keep:
+ # create a simple parser that pulls the export_version from the directory.
+ def parser(path):
+ match = re.match("^" + export_dir_base + "/(\\d{8})$", path.path)
+ if not match:
+ return None
+ return path._replace(export_version=int(match.group(1)))
+
+ paths_to_delete = gc.negation(exports_to_keep)
+ for p in paths_to_delete(gc.get_paths(export_dir_base, parser=parser)):
+ gfile.DeleteRecursively(p.path)
+
+ return export_dir
+
+ def _file_path_value(self, path_tensor):
+ """Returns the filepath value stored in constant `path_tensor`."""
+ if not isinstance(path_tensor, tf.Tensor):
+ raise TypeError("tensor is not a Tensor")
+ if path_tensor.op.type != "Const":
+ raise TypeError("Only constants tensor are supported")
+ if path_tensor.dtype != tf.string:
+ raise TypeError("File paths should be string")
+ str_value = path_tensor.op.get_attr("value").string_val
+ if len(str_value) != 1:
+ raise TypeError("Only scalar tensors are supported")
+ return str_value[0]
diff --git a/tensorflow/contrib/session_bundle/exporter_test.py b/tensorflow/contrib/session_bundle/exporter_test.py
new file mode 100644
index 0000000000..8b54933fc9
--- /dev/null
+++ b/tensorflow/contrib/session_bundle/exporter_test.py
@@ -0,0 +1,218 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for exporter.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os.path
+
+
+import tensorflow as tf
+
+from tensorflow.contrib.session_bundle import exporter
+from tensorflow.contrib.session_bundle import gc
+from tensorflow.contrib.session_bundle import manifest_pb2
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.platform import flags
+from tensorflow.python.platform import gfile
+
+
+FLAGS = flags.FLAGS
+
+GLOBAL_STEP = 222
+
+
+def tearDownModule():
+ gfile.DeleteRecursively(tf.test.get_temp_dir())
+
+
+class SaveRestoreShardedTest(tf.test.TestCase):
+
+ def doBasicsOneExportPath(self,
+ export_path,
+ clear_devices=False,
+ global_step=GLOBAL_STEP,
+ sharded=True):
+ # Build a graph with 2 parameter nodes on different devices.
+ tf.reset_default_graph()
+ with tf.Session(
+ target="",
+ config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
+ # v2 is an unsaved variable derived from v0 and v1. It is used to
+ # exercise the ability to run an init op when restoring a graph.
+ with sess.graph.device("/cpu:0"):
+ v0 = tf.Variable(10, name="v0")
+ with sess.graph.device("/cpu:1"):
+ v1 = tf.Variable(20, name="v1")
+ v2 = tf.Variable(1, name="v2", trainable=False, collections=[])
+ assign_v2 = tf.assign(v2, tf.add(v0, v1))
+ init_op = tf.group(assign_v2, name="init_op")
+
+ tf.add_to_collection("v", v0)
+ tf.add_to_collection("v", v1)
+ tf.add_to_collection("v", v2)
+
+ global_step_tensor = tf.Variable(global_step, name="global_step")
+ named_tensor_bindings = {"logical_input_A": v0, "logical_input_B": v1}
+ signatures = {
+ "foo": exporter.regression_signature(input_tensor=v0,
+ output_tensor=v1),
+ "generic": exporter.generic_signature(named_tensor_bindings)
+ }
+
+ asset_filepath_orig = os.path.join(tf.test.get_temp_dir(), "hello42.txt")
+ asset_file = tf.constant(asset_filepath_orig, name="filename42")
+ tf.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, asset_file)
+
+ with gfile.FastGFile(asset_filepath_orig, "w") as f:
+ f.write("your data here")
+ assets_collection = tf.get_collection(tf.GraphKeys.ASSET_FILEPATHS)
+
+ ignored_asset = os.path.join(tf.test.get_temp_dir(), "ignored.txt")
+ with gfile.FastGFile(ignored_asset, "w") as f:
+ f.write("additional data here")
+
+ tf.initialize_all_variables().run()
+
+ # Run an export.
+ save = tf.train.Saver({"v0": v0,
+ "v1": v1},
+ restore_sequentially=True,
+ sharded=sharded)
+ export = exporter.Exporter(save)
+ export.init(sess.graph.as_graph_def(),
+ init_op=init_op,
+ clear_devices=clear_devices,
+ default_graph_signature=exporter.classification_signature(
+ input_tensor=v0),
+ named_graph_signatures=signatures,
+ assets_collection=assets_collection)
+ export.export(export_path,
+ global_step_tensor,
+ sess,
+ exports_to_keep=gc.largest_export_versions(2))
+
+ # Restore graph.
+ compare_def = tf.get_default_graph().as_graph_def()
+ tf.reset_default_graph()
+ with tf.Session(
+ target="",
+ config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
+ save = tf.train.import_meta_graph(
+ os.path.join(export_path, exporter.VERSION_FORMAT_SPECIFIER %
+ global_step, exporter.META_GRAPH_DEF_FILENAME))
+ self.assertIsNotNone(save)
+ meta_graph_def = save.export_meta_graph()
+ collection_def = meta_graph_def.collection_def
+
+ # Validate custom graph_def.
+ graph_def_any = collection_def[exporter.GRAPH_KEY].any_list.value
+ self.assertEquals(len(graph_def_any), 1)
+ graph_def = tf.GraphDef()
+ graph_def_any[0].Unpack(graph_def)
+ if clear_devices:
+ for node in compare_def.node:
+ node.device = ""
+ self.assertProtoEquals(compare_def, graph_def)
+
+ # Validate init_op.
+ init_ops = collection_def[exporter.INIT_OP_KEY].node_list.value
+ self.assertEquals(len(init_ops), 1)
+ self.assertEquals(init_ops[0], "init_op")
+
+ # Validate signatures.
+ signatures_any = collection_def[exporter.SIGNATURES_KEY].any_list.value
+ self.assertEquals(len(signatures_any), 1)
+ signatures = manifest_pb2.Signatures()
+ signatures_any[0].Unpack(signatures)
+ default_signature = signatures.default_signature
+ self.assertEqual(
+ default_signature.classification_signature.input.tensor_name, "v0:0")
+ bindings = signatures.named_signatures["generic"].generic_signature.map
+ self.assertEquals(bindings["logical_input_A"].tensor_name, "v0:0")
+ self.assertEquals(bindings["logical_input_B"].tensor_name, "v1:0")
+ read_foo_signature = (
+ signatures.named_signatures["foo"].regression_signature)
+ self.assertEquals(read_foo_signature.input.tensor_name, "v0:0")
+ self.assertEquals(read_foo_signature.output.tensor_name, "v1:0")
+
+ # Validate the assets.
+ assets_any = collection_def[exporter.ASSETS_KEY].any_list.value
+ self.assertEquals(len(assets_any), 1)
+ asset = manifest_pb2.AssetFile()
+ assets_any[0].Unpack(asset)
+ assets_path = os.path.join(export_path,
+ exporter.VERSION_FORMAT_SPECIFIER %
+ global_step, exporter.ASSETS_DIRECTORY,
+ "hello42.txt")
+ asset_contents = gfile.GFile(assets_path).read()
+ self.assertEqual(asset_contents, "your data here")
+ self.assertEquals("hello42.txt", asset.filename)
+ self.assertEquals("filename42:0", asset.tensor_binding.tensor_name)
+ ignored_asset_path = os.path.join(export_path,
+ exporter.VERSION_FORMAT_SPECIFIER %
+ global_step, exporter.ASSETS_DIRECTORY,
+ "ignored.txt")
+ self.assertFalse(gfile.Exists(ignored_asset_path))
+
+ # Validate graph restoration.
+ if sharded:
+ save.restore(sess,
+ os.path.join(
+ export_path, exporter.VERSION_FORMAT_SPECIFIER %
+ global_step, exporter.VARIABLES_FILENAME_PATTERN))
+ else:
+ save.restore(sess,
+ os.path.join(
+ export_path, exporter.VERSION_FORMAT_SPECIFIER %
+ global_step, exporter.VARIABLES_FILENAME))
+ self.assertEqual(10, tf.get_collection("v")[0].eval())
+ self.assertEqual(20, tf.get_collection("v")[1].eval())
+ tf.get_collection(exporter.INIT_OP_KEY)[0].run()
+ self.assertEqual(30, tf.get_collection("v")[2].eval())
+
+ def testDuplicateExportRaisesError(self):
+ export_path = os.path.join(tf.test.get_temp_dir(), "export_duplicates")
+ self.doBasicsOneExportPath(export_path)
+ self.assertRaises(RuntimeError, self.doBasicsOneExportPath, export_path)
+
+ def testBasics(self):
+ export_path = os.path.join(tf.test.get_temp_dir(), "export")
+ self.doBasicsOneExportPath(export_path)
+
+ def testBasicsNoShard(self):
+ export_path = os.path.join(tf.test.get_temp_dir(), "export_no_shard")
+ self.doBasicsOneExportPath(export_path, sharded=False)
+
+ def testClearDevice(self):
+ export_path = os.path.join(tf.test.get_temp_dir(), "export_clear_device")
+ self.doBasicsOneExportPath(export_path, clear_devices=True)
+
+ def testGC(self):
+ export_path = os.path.join(tf.test.get_temp_dir(), "gc")
+ self.doBasicsOneExportPath(export_path, global_step=100)
+ self.assertEquals(gfile.ListDirectory(export_path), ["00000100"])
+ self.doBasicsOneExportPath(export_path, global_step=101)
+ self.assertEquals(
+ sorted(gfile.ListDirectory(export_path)), ["00000100", "00000101"])
+ self.doBasicsOneExportPath(export_path, global_step=102)
+ self.assertEquals(
+ sorted(gfile.ListDirectory(export_path)), ["00000101", "00000102"])
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/contrib/session_bundle/gc.py b/tensorflow/contrib/session_bundle/gc.py
new file mode 100644
index 0000000000..ad7389d96f
--- /dev/null
+++ b/tensorflow/contrib/session_bundle/gc.py
@@ -0,0 +1,204 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""System for specifying garbage collection (GC) of path based data.
+
+This framework allows for GC of data specified by path names, for example files
+on disk. gc.Path objects each represent a single item stored at a path and may
+be a base directory,
+ /tmp/exports/0/...
+ /tmp/exports/1/...
+ ...
+or a fully qualified file,
+ /tmp/train-1.ckpt
+ /tmp/train-2.ckpt
+ ...
+
+A gc filter function takes and returns a list of gc.Path items. Filter
+functions are responsible for selecting Path items for preservation or deletion.
+Note that functions should always return a sorted list.
+
+For example,
+ base_dir = "/tmp"
+ # create the directories
+ for e in xrange(10):
+ os.mkdir("%s/%d" % (base_dir, e), 0o755)
+
+ # create a simple parser that pulls the export_version from the directory
+ def parser(path):
+ match = re.match("^" + base_dir + "/(\\d+)$", path.path)
+ if not match:
+ return None
+ return path._replace(export_version=int(match.group(1)))
+
+ path_list = gc.get_paths("/tmp", parser) # contains all ten Paths
+
+ every_fifth = gc.mod_export_version(5)
+ print every_fifth(path_list) # shows ["/tmp/0", "/tmp/5"]
+
+ largest_three = gc.largest_export_versions(3)
+ print largest_three(all_paths) # shows ["/tmp/7", "/tmp/8", "/tmp/9"]
+
+ both = gc.union(every_fifth, largest_three)
+ print both(all_paths) # shows ["/tmp/0", "/tmp/5",
+ # "/tmp/7", "/tmp/8", "/tmp/9"]
+ # delete everything not in 'both'
+ to_delete = gc.negation(both)
+ for p in to_delete(all_paths):
+ gfile.DeleteRecursively(p.path) # deletes: "/tmp/1", "/tmp/2",
+ # "/tmp/3", "/tmp/4", "/tmp/6",
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import heapq
+import math
+import os
+
+from six.moves import xrange # pylint: disable=redefined-builtin
+from tensorflow.python.platform import gfile
+
+Path = collections.namedtuple('Path', 'path export_version')
+
+
+def largest_export_versions(n):
+ """Creates a filter that keeps the largest n export versions.
+
+ Args:
+ n: number of versions to keep.
+
+ Returns:
+ A filter function that keeps the n largest paths.
+ """
+ def keep(paths):
+ heap = []
+ for idx, path in enumerate(paths):
+ if path.export_version:
+ heapq.heappush(heap, (path.export_version, idx))
+ keepers = [paths[i] for _, i in heapq.nlargest(n, heap)]
+ return sorted(keepers)
+
+ return keep
+
+
+def one_of_every_n_export_versions(n):
+ """Creates a filter that keeps one of every n export versions.
+
+ Args:
+ n: interval size.
+
+ Returns:
+ A filter function that keeps exactly one path from each interval
+ [0, n], (n, 2n], (2n, 3n], etc... If more than one path exists in an
+ interval the largest is kept.
+ """
+ def keep(paths):
+ keeper_map = {} # map from interval to largest path seen in that interval
+ for p in paths:
+ if p.export_version is None:
+ # Skip missing export_versions.
+ continue
+ # Find the interval (with a special case to map export_version = 0 to
+ # interval 0.
+ interval = math.floor(
+ (p.export_version - 1) / n) if p.export_version else 0
+ existing = keeper_map.get(interval, None)
+ if (not existing) or (existing.export_version < p.export_version):
+ keeper_map[interval] = p
+ return sorted(keeper_map.values())
+
+ return keep
+
+
+def mod_export_version(n):
+ """Creates a filter that keeps every export that is a multiple of n.
+
+ Args:
+ n: step size.
+
+ Returns:
+ A filter function that keeps paths where export_version % n == 0.
+ """
+ def keep(paths):
+ keepers = []
+ for p in paths:
+ if p.export_version % n == 0:
+ keepers.append(p)
+ return sorted(keepers)
+ return keep
+
+
+def union(lf, rf):
+ """Creates a filter that keeps the union of two filters.
+
+ Args:
+ lf: first filter
+ rf: second filter
+
+ Returns:
+ A filter function that keeps the n largest paths.
+ """
+ def keep(paths):
+ l = set(lf(paths))
+ r = set(rf(paths))
+ return sorted(list(l|r))
+ return keep
+
+
+def negation(f):
+ """Negate a filter.
+
+ Args:
+ f: filter function to invert
+
+ Returns:
+ A filter function that returns the negation of f.
+ """
+ def keep(paths):
+ l = set(paths)
+ r = set(f(paths))
+ return sorted(list(l-r))
+ return keep
+
+
+def get_paths(base_dir, parser):
+ """Gets a list of Paths in a given directory.
+
+ Args:
+ base_dir: directory.
+ parser: a function which gets the raw Path and can augment it with
+ information such as the export_version, or ignore the path by returning
+ None. An example parser may extract the export version from a path
+ such as "/tmp/exports/100" an another may extract from a full file
+ name such as "/tmp/checkpoint-99.out".
+
+ Returns:
+ A list of Paths contained in the base directory with the parsing function
+ applied.
+ By default the following fields are populated,
+ - Path.path
+ The parsing function is responsible for populating,
+ - Path.export_version
+ """
+ raw_paths = gfile.ListDirectory(base_dir)
+ paths = []
+ for r in raw_paths:
+ p = parser(Path(os.path.join(base_dir, r), None))
+ if p:
+ paths.append(p)
+ return sorted(paths)
diff --git a/tensorflow/contrib/session_bundle/gc_test.py b/tensorflow/contrib/session_bundle/gc_test.py
new file mode 100644
index 0000000000..c71cbb9e5f
--- /dev/null
+++ b/tensorflow/contrib/session_bundle/gc_test.py
@@ -0,0 +1,115 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Tests for session_bundle.gc."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import re
+
+
+from six.moves import xrange # pylint: disable=redefined-builtin
+
+import tensorflow as tf
+
+from tensorflow.contrib.session_bundle import gc
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import gfile
+
+
+def tearDownModule():
+ gfile.DeleteRecursively(tf.test.get_temp_dir())
+
+
+class GcTest(test_util.TensorFlowTestCase):
+
+ def testLargestExportVersions(self):
+ paths = [gc.Path("/foo", 8), gc.Path("/foo", 9), gc.Path("/foo", 10)]
+ newest = gc.largest_export_versions(2)
+ n = newest(paths)
+ self.assertEquals(n, [gc.Path("/foo", 9), gc.Path("/foo", 10)])
+
+ def testModExportVersion(self):
+ paths = [gc.Path("/foo", 4), gc.Path("/foo", 5), gc.Path("/foo", 6),
+ gc.Path("/foo", 9)]
+ mod = gc.mod_export_version(2)
+ self.assertEquals(mod(paths), [gc.Path("/foo", 4), gc.Path("/foo", 6)])
+ mod = gc.mod_export_version(3)
+ self.assertEquals(mod(paths), [gc.Path("/foo", 6), gc.Path("/foo", 9)])
+
+ def testOneOfEveryNExportVersions(self):
+ paths = [gc.Path("/foo", 0), gc.Path("/foo", 1), gc.Path("/foo", 3),
+ gc.Path("/foo", 5), gc.Path("/foo", 6), gc.Path("/foo", 7),
+ gc.Path("/foo", 8), gc.Path("/foo", 33)]
+ one_of = gc.one_of_every_n_export_versions(3)
+ self.assertEquals(one_of(paths),
+ [gc.Path("/foo", 3), gc.Path("/foo", 6),
+ gc.Path("/foo", 8), gc.Path("/foo", 33)])
+
+ def testOneOfEveryNExportVersionsZero(self):
+ # Zero is a special case since it gets rolled into the first interval.
+ # Test that here.
+ paths = [gc.Path("/foo", 0), gc.Path("/foo", 4), gc.Path("/foo", 5)]
+ one_of = gc.one_of_every_n_export_versions(3)
+ self.assertEquals(one_of(paths),
+ [gc.Path("/foo", 0), gc.Path("/foo", 5)])
+
+ def testUnion(self):
+ paths = []
+ for i in xrange(10):
+ paths.append(gc.Path("/foo", i))
+ f = gc.union(gc.largest_export_versions(3), gc.mod_export_version(3))
+ self.assertEquals(
+ f(paths), [gc.Path("/foo", 0), gc.Path("/foo", 3),
+ gc.Path("/foo", 6), gc.Path("/foo", 7),
+ gc.Path("/foo", 8), gc.Path("/foo", 9)])
+
+ def testNegation(self):
+ paths = [gc.Path("/foo", 4), gc.Path("/foo", 5), gc.Path("/foo", 6),
+ gc.Path("/foo", 9)]
+ mod = gc.negation(gc.mod_export_version(2))
+ self.assertEquals(
+ mod(paths), [gc.Path("/foo", 5), gc.Path("/foo", 9)])
+ mod = gc.negation(gc.mod_export_version(3))
+ self.assertEquals(
+ mod(paths), [gc.Path("/foo", 4), gc.Path("/foo", 5)])
+
+ def testPathsWithParse(self):
+ base_dir = os.path.join(tf.test.get_temp_dir(), "paths_parse")
+ self.assertFalse(gfile.Exists(base_dir))
+ for p in xrange(3):
+ gfile.MakeDirs(os.path.join(base_dir, "%d" % p))
+ # add a base_directory to ignore
+ gfile.MakeDirs(os.path.join(base_dir, "ignore"))
+
+ # create a simple parser that pulls the export_version from the directory.
+ def parser(path):
+ match = re.match("^" + base_dir + "/(\\d+)$", path.path)
+ if not match:
+ return None
+ return path._replace(export_version=int(match.group(1)))
+
+ self.assertEquals(
+ gc.get_paths(base_dir, parser=parser),
+ [gc.Path(os.path.join(base_dir, "0"), 0),
+ gc.Path(os.path.join(base_dir, "1"), 1),
+ gc.Path(os.path.join(base_dir, "2"), 2)])
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/contrib/session_bundle/manifest.proto b/tensorflow/contrib/session_bundle/manifest.proto
new file mode 100644
index 0000000000..499c1bcfd8
--- /dev/null
+++ b/tensorflow/contrib/session_bundle/manifest.proto
@@ -0,0 +1,70 @@
+syntax = "proto3";
+
+package tensorflow.contrib;
+
+// Signatures of model export.
+message Signatures {
+ // Default signature of the graph.
+ // WARNING(break-tutorial-inline-code): The following code snippet is
+ // in-lined in tutorials, please update tutorial documents accordingly
+ // whenever code changes.
+ Signature default_signature = 1;
+
+ // Named signatures of the graph.
+ map<string, Signature> named_signatures = 2;
+};
+
+// A binding to a tensor including the name and, possibly in the future, type
+// or other metadata. For example, this may specify whether a tensor supports
+// batch vs single inference.
+message TensorBinding {
+ // The name of the tensor to bind to.
+ string tensor_name = 1;
+};
+
+// An asset file or set of sharded files with the same name that will be bound
+// to a tensor at init / session_bundle load time.
+message AssetFile {
+ // The tensor to bind the asset filename to.
+ TensorBinding tensor_binding = 1;
+ // The filename within the assets directory. Note: does not include the base
+ // path or asset directory prefix. Base paths can and will change when models
+ // are deployed for serving.
+ string filename = 2;
+}
+
+// A Signature specifies the inputs and outputs of commonly used graphs.
+message Signature {
+ oneof type {
+ RegressionSignature regression_signature = 1;
+ ClassificationSignature classification_signature = 2;
+ GenericSignature generic_signature = 3;
+ }
+};
+
+// RegressionSignature specifies a graph that takes an input and returns an
+// output.
+message RegressionSignature {
+ TensorBinding input = 1;
+ TensorBinding output = 2;
+};
+
+// ClassificationSignature specifies a graph that takes an input and returns
+// classes and their scores.
+// WARNING(break-tutorial-inline-code): The following code snippet is
+// in-lined in tutorials, please update tutorial documents accordingly
+// whenever code changes.
+message ClassificationSignature {
+ TensorBinding input = 1;
+ TensorBinding classes = 2;
+ TensorBinding scores = 3;
+};
+
+// GenericSignature specifies a map from logical name to Tensor name.
+// Typical application of GenericSignature is to use a single GenericSignature
+// that includes all of the Tensor nodes and target names that may be useful at
+// serving, analysis or debugging time. The recommended name for this signature
+// in the ModelManifest is "generic_bindings".
+message GenericSignature {
+ map<string, TensorBinding> map = 1;
+};
diff --git a/tensorflow/contrib/session_bundle/session_bundle.cc b/tensorflow/contrib/session_bundle/session_bundle.cc
new file mode 100644
index 0000000000..8715492af4
--- /dev/null
+++ b/tensorflow/contrib/session_bundle/session_bundle.cc
@@ -0,0 +1,181 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/session_bundle/session_bundle.h"
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "google/protobuf/any.pb.h"
+#include "tensorflow/contrib/session_bundle/manifest.pb.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/protobuf/meta_graph.pb.h"
+#include "tensorflow/core/protobuf/saver.pb.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+namespace contrib {
+namespace {
+
+// Create a session using the given options and load the graph.
+Status CreateSessionFromGraphDef(
+ const tensorflow::SessionOptions& options, const GraphDef& graph,
+ std::unique_ptr<tensorflow::Session>* session) {
+ session->reset(NewSession(options));
+ return (*session)->Create(graph);
+}
+
+Status GetMetaGraphDefFromExport(const StringPiece export_dir,
+ tensorflow::MetaGraphDef* meta_graph_def) {
+ const string meta_graph_def_path =
+ tensorflow::io::JoinPath(export_dir, kMetaGraphDefFilename);
+ return ReadBinaryProto(Env::Default(), meta_graph_def_path, meta_graph_def);
+}
+
+// Creates a string tensor.
+Tensor CreateStringTensor(const string& value) {
+ Tensor tensor(DT_STRING, TensorShape({}));
+ tensor.scalar<string>()() = value;
+ return tensor;
+}
+
+// Adds Assets related tensors (assets_dir and asset files) to the inputs.
+void AddAssetsTensorsToInputs(const StringPiece export_dir,
+ const std::vector<AssetFile>& asset_files,
+ std::vector<std::pair<string, Tensor>>* inputs) {
+ if (!asset_files.empty()) {
+ for (auto& asset : asset_files) {
+ Tensor assets_file_tensor = CreateStringTensor(tensorflow::io::JoinPath(
+ tensorflow::io::JoinPath(export_dir, kAssetsDirectory),
+ asset.filename()));
+ inputs->push_back(
+ {asset.tensor_binding().tensor_name(), assets_file_tensor});
+ }
+ }
+}
+
+// Historically, model exporter(exporter.py) takes only saver with
+// sharded=True, and therefore always exports checkpoint in pattern file names.
+// In practice, instead of training from scratch and export directly, we
+// usually want to restore from existing checkpoints and then export directly.
+// To support such case, model exporter now supports reusing saver object
+// restored from existing checkpoint, that may have sharded=False - it will
+// then export checkpoint file in plain file name.
+// This method is to support models exported by both types of saver object.
+// The change is backward-compatible, therefore no changes are needed for
+// existing model exports.
+string GetVariablesFilename(const StringPiece export_dir) {
+ const char kVariablesFilename[] = "export";
+ const char kVariablesFilenamePattern[] = "export-\?\?\?\?\?-of-\?\?\?\?\?";
+ if (Env::Default()->FileExists(
+ tensorflow::io::JoinPath(export_dir, kVariablesFilename))) {
+ return tensorflow::io::JoinPath(export_dir, kVariablesFilename);
+ } else {
+ return tensorflow::io::JoinPath(export_dir, kVariablesFilenamePattern);
+ }
+}
+
+Status RunRestoreOp(const StringPiece export_dir,
+ const std::vector<AssetFile>& asset_files,
+ const StringPiece restore_op_name,
+ const StringPiece variables_filename_const_op_name,
+ tensorflow::Session* session) {
+ LOG(INFO) << "Running restore op for SessionBundle";
+ Tensor variables_tensor = CreateStringTensor(
+ GetVariablesFilename(export_dir));
+ std::vector<std::pair<string, Tensor>> inputs = {
+ {variables_filename_const_op_name.ToString(), variables_tensor}};
+ AddAssetsTensorsToInputs(export_dir, asset_files, &inputs);
+ return session->Run(inputs, {}, {restore_op_name.ToString()}, nullptr);
+}
+
+Status RunInitOp(const StringPiece export_dir,
+ const std::vector<AssetFile>& asset_files,
+ const StringPiece init_op_name, tensorflow::Session* session) {
+ LOG(INFO) << "Running init op for SessionBundle";
+ std::vector<std::pair<string, Tensor>> inputs;
+ AddAssetsTensorsToInputs(export_dir, asset_files, &inputs);
+ return session->Run(inputs, {}, {init_op_name.ToString()}, nullptr);
+}
+
+} // namespace
+
+tensorflow::Status LoadSessionBundleFromPath(
+ const tensorflow::SessionOptions& options, const StringPiece export_dir,
+ SessionBundle* bundle) {
+ LOG(INFO) << "Attempting to load a SessionBundle from: " << export_dir;
+ TF_RETURN_IF_ERROR(
+ GetMetaGraphDefFromExport(export_dir, &(bundle->meta_graph_def)));
+
+ auto collection_def = bundle->meta_graph_def.collection_def();
+ if (collection_def.find(kGraphKey) != collection_def.end()) {
+ // Use serving graph_def in MetaGraphDef collection_def.
+ if (collection_def[kGraphKey].any_list().value_size() != 1) {
+ return errors::FailedPrecondition(
+ strings::StrCat("Expected exactly one serving GraphDef in : ",
+ bundle->meta_graph_def.DebugString()));
+ }
+ tensorflow::GraphDef graph_def;
+ collection_def[kGraphKey].any_list().value(0).UnpackTo(&graph_def);
+ TF_RETURN_IF_ERROR(
+ CreateSessionFromGraphDef(options, graph_def, &bundle->session));
+ } else {
+ // Fallback to use the graph_def in the MetaGraphDef.
+ const tensorflow::GraphDef& graph_def = bundle->meta_graph_def.graph_def();
+ TF_RETURN_IF_ERROR(
+ CreateSessionFromGraphDef(options, graph_def, &bundle->session));
+ }
+
+ std::vector<AssetFile> asset_files;
+ auto any_assets = collection_def[kAssetsKey].any_list().value();
+ for (const auto any_asset : any_assets) {
+ AssetFile asset_file;
+ any_asset.UnpackTo(&asset_file);
+ asset_files.push_back(asset_file);
+ }
+
+ TF_RETURN_IF_ERROR(
+ RunRestoreOp(export_dir, asset_files,
+ bundle->meta_graph_def.saver_def().restore_op_name(),
+ bundle->meta_graph_def.saver_def().filename_tensor_name(),
+ bundle->session.get()));
+
+ if (collection_def.find(kInitOpKey) != collection_def.end()) {
+ if (collection_def[kInitOpKey].node_list().value_size() != 1) {
+ return errors::FailedPrecondition(
+ strings::StrCat("Expected exactly one serving init op in : ",
+ bundle->meta_graph_def.DebugString()));
+ }
+ return RunInitOp(export_dir, asset_files,
+ collection_def[kInitOpKey].node_list().value(0),
+ bundle->session.get());
+ }
+
+ LOG(INFO) << "Done loading SessionBundle";
+ return Status::OK();
+}
+
+} // namespace contrib
+} // namespace tensorflow
diff --git a/tensorflow/contrib/session_bundle/session_bundle.h b/tensorflow/contrib/session_bundle/session_bundle.h
new file mode 100644
index 0000000000..92ed0132b8
--- /dev/null
+++ b/tensorflow/contrib/session_bundle/session_bundle.h
@@ -0,0 +1,59 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Low-level functionality for setting up a inference Session.
+
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_SESSION_BUNDLE_SESSION_BUNDLE_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_SESSION_BUNDLE_SESSION_BUNDLE_H_
+
+#include <memory>
+
+#include "tensorflow/contrib/session_bundle/manifest.pb.h"
+#include "tensorflow/contrib/session_bundle/signature.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/protobuf/meta_graph.pb.h"
+#include "tensorflow/core/protobuf/saver.pb.h"
+#include "tensorflow/core/public/session.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+namespace contrib {
+
+const char kMetaGraphDefFilename[] = "export.meta";
+const char kAssetsDirectory[] = "assets";
+const char kInitOpKey[] = "serving_init_op";
+const char kAssetsKey[] = "serving_assets";
+const char kGraphKey[] = "serving_graph";
+
+// Data and objects loaded from a python Exporter export.
+// WARNING(break-tutorial-inline-code): The following code snippet is
+// in-lined in tutorials, please update tutorial documents accordingly
+// whenever code changes.
+struct SessionBundle {
+ std::unique_ptr<tensorflow::Session> session;
+ tensorflow::MetaGraphDef meta_graph_def;
+};
+
+// Loads a manifest and initialized session using the output of an Exporter
+// using the format defined at go/tf-exporter.
+tensorflow::Status LoadSessionBundleFromPath(
+ const tensorflow::SessionOptions& options,
+ const tensorflow::StringPiece export_dir, SessionBundle* bundle);
+
+} // namespace contrib
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_SESSION_BUNDLE_SESSION_BUNDLE_H_
diff --git a/tensorflow/contrib/session_bundle/session_bundle_test.cc b/tensorflow/contrib/session_bundle/session_bundle_test.cc
new file mode 100644
index 0000000000..2bf09ad438
--- /dev/null
+++ b/tensorflow/contrib/session_bundle/session_bundle_test.cc
@@ -0,0 +1,102 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/session_bundle/session_bundle.h"
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "google/protobuf/any.pb.h"
+#include "tensorflow/contrib/session_bundle/test_util.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_shape.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/public/session.h"
+#include "tensorflow/core/public/session_options.h"
+
+namespace tensorflow {
+namespace contrib {
+namespace {
+
+TEST(LoadSessionBundleFromPath, Basic) {
+ const string export_path = test_util::TestSrcDirPath(
+ "session_bundle/example/half_plus_two/00000123");
+ tensorflow::SessionOptions options;
+ SessionBundle bundle;
+ TF_ASSERT_OK(LoadSessionBundleFromPath(options, export_path, &bundle));
+
+ const string asset_path =
+ tensorflow::io::JoinPath(export_path, kAssetsDirectory);
+ // Validate the assets behavior.
+ std::vector<Tensor> path_outputs;
+ TF_ASSERT_OK(bundle.session->Run({}, {"filename1:0", "filename2:0"}, {},
+ &path_outputs));
+ ASSERT_EQ(2, path_outputs.size());
+ // Validate the two asset file tensors are set by the init_op and include the
+ // base_path and asset directory.
+ test::ExpectTensorEqual<string>(
+ test::AsTensor<string>(
+ {tensorflow::io::JoinPath(asset_path, "hello1.txt")},
+ TensorShape({})),
+ path_outputs[0]);
+ test::ExpectTensorEqual<string>(
+ test::AsTensor<string>(
+ {tensorflow::io::JoinPath(asset_path, "hello2.txt")},
+ TensorShape({})),
+ path_outputs[1]);
+
+ // Validate the half plus two behavior.
+ Tensor input = test::AsTensor<float>({0, 1, 2, 3}, TensorShape({4, 1}));
+
+ // Recover the Tensor names of our inputs and outputs.
+ auto collection_def = bundle.meta_graph_def.collection_def();
+ Signatures signatures;
+ ASSERT_EQ(1, collection_def[kSignaturesKey].any_list().value_size());
+ collection_def[kSignaturesKey].any_list().value(0).UnpackTo(&signatures);
+ ASSERT_TRUE(signatures.default_signature().has_regression_signature());
+ const tensorflow::contrib::RegressionSignature regression_signature =
+ signatures.default_signature().regression_signature();
+
+ const string input_name = regression_signature.input().tensor_name();
+ const string output_name = regression_signature.output().tensor_name();
+
+ std::vector<Tensor> outputs;
+ TF_ASSERT_OK(
+ bundle.session->Run({{input_name, input}}, {output_name}, {}, &outputs));
+ ASSERT_EQ(outputs.size(), 1);
+ test::ExpectTensorEqual<float>(
+ outputs[0], test::AsTensor<float>({2, 2.5, 3, 3.5}, TensorShape({4, 1})));
+}
+
+TEST(LoadSessionBundleFromPath, BadExportPath) {
+ const string export_path = test_util::TestSrcDirPath("/tmp/bigfoot");
+ tensorflow::SessionOptions options;
+ options.target = "local";
+ SessionBundle bundle;
+ const auto status = LoadSessionBundleFromPath(options, export_path, &bundle);
+ ASSERT_FALSE(status.ok());
+ const string msg = status.ToString();
+ EXPECT_TRUE(msg.find("Not found") != std::string::npos) << msg;
+}
+
+} // namespace
+} // namespace contrib
+} // namespace tensorflow
diff --git a/tensorflow/contrib/session_bundle/signature.cc b/tensorflow/contrib/session_bundle/signature.cc
new file mode 100644
index 0000000000..3550a7d10d
--- /dev/null
+++ b/tensorflow/contrib/session_bundle/signature.cc
@@ -0,0 +1,270 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/session_bundle/signature.h"
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "google/protobuf/any.pb.h"
+#include "tensorflow/contrib/session_bundle/manifest.pb.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/protobuf/meta_graph.pb.h"
+#include "tensorflow/core/protobuf/saver.pb.h"
+#include "tensorflow/core/public/session.h"
+
+namespace tensorflow {
+namespace contrib {
+namespace {
+
+// Returns OK if the input and output batch sizes match.
+Status BatchSizesMatch(const Tensor& input, const Tensor& output) {
+ // Ensure the number of outputs match the number of inputs.
+ if (input.dim_size(0) != output.dim_size(0)) {
+ return errors::Internal(
+ strings::StrCat("Input batch size did not match output batch size: ",
+ input.dim_size(0), " vs. ", output.dim_size(0)));
+ }
+ return Status::OK();
+}
+} // namespace
+
+Status GetSignatures(const tensorflow::MetaGraphDef& meta_graph_def,
+ Signatures* signatures) {
+ auto collection_def = meta_graph_def.collection_def();
+ auto any_list = collection_def[kSignaturesKey].any_list();
+ if (any_list.value_size() != 1) {
+ return errors::FailedPrecondition(
+ strings::StrCat("Expected exactly one signatures proto in : ",
+ meta_graph_def.DebugString()));
+ }
+ any_list.value(0).UnpackTo(signatures);
+ return Status::OK();
+}
+
+Status SetSignatures(const Signatures& signatures,
+ tensorflow::MetaGraphDef* meta_graph_def) {
+ auto& collection_def = *(meta_graph_def->mutable_collection_def());
+ auto* any_list = collection_def[kSignaturesKey].mutable_any_list();
+ any_list->mutable_value()->Clear();
+ any_list->mutable_value()->Add()->PackFrom(signatures);
+ return Status::OK();
+}
+
+Status GetClassificationSignature(
+ const tensorflow::MetaGraphDef& meta_graph_def,
+ ClassificationSignature* signature) {
+ Signatures signatures;
+ TF_RETURN_IF_ERROR(GetSignatures(meta_graph_def, &signatures));
+ if (!signatures.has_default_signature()) {
+ return errors::FailedPrecondition(strings::StrCat(
+ "Expected a default signature in: ", signatures.DebugString()));
+ }
+ if (!signatures.default_signature().has_classification_signature()) {
+ return errors::FailedPrecondition(
+ strings::StrCat("Expected a classification signature in: ",
+ signatures.default_signature().DebugString()));
+ }
+ *signature = signatures.default_signature().classification_signature();
+ return Status::OK();
+}
+
+Status GetNamedClassificationSignature(
+ const string& name, const tensorflow::MetaGraphDef& meta_graph_def,
+ ClassificationSignature* signature) {
+ Signatures signatures;
+ TF_RETURN_IF_ERROR(GetSignatures(meta_graph_def, &signatures));
+ const auto& it = signatures.named_signatures().find(name);
+ if (it == signatures.named_signatures().end()) {
+ return errors::NotFound(strings::StrCat("Missing signature named \"", name,
+ "\" in: ",
+ signatures.DebugString()));
+ }
+ if (!it->second.has_classification_signature()) {
+ return errors::FailedPrecondition(
+ strings::StrCat("Expected a classification signature for name \"", name,
+ "\" in: ", it->second.DebugString()));
+ }
+ *signature = it->second.classification_signature();
+ return Status::OK();
+}
+
+Status RunClassification(const ClassificationSignature& signature,
+ const Tensor& input, Session* session, Tensor* classes,
+ Tensor* scores) {
+ std::vector<string> output_tensor_names;
+ if (classes) {
+ output_tensor_names.push_back(signature.classes().tensor_name());
+ }
+ if (scores) {
+ output_tensor_names.push_back(signature.scores().tensor_name());
+ }
+ // Run the graph with our inputs and outputs.
+ std::vector<Tensor> outputs;
+ const Status run_status =
+ session->Run({{signature.input().tensor_name(), input}},
+ output_tensor_names, {}, &outputs);
+ if (!run_status.ok()) {
+ return run_status;
+ }
+ // Ensure the output is shaped how we expect.
+ // There should be one string Tensor of shape,
+ // [batch_size, num_recommendations].
+ if (outputs.size() != output_tensor_names.size()) {
+ return errors::Internal(
+ strings::StrCat("Expected ", output_tensor_names.size(),
+ " output tensor(s). Got: ", outputs.size()));
+ }
+ if (classes) {
+ *classes = outputs[0];
+ TF_RETURN_IF_ERROR(BatchSizesMatch(input, *classes));
+ }
+ if (scores) {
+ *scores = outputs[classes ? 1 : 0];
+ TF_RETURN_IF_ERROR(BatchSizesMatch(input, *scores));
+ }
+ return Status::OK();
+}
+
+Status GetRegressionSignature(const tensorflow::MetaGraphDef& meta_graph_def,
+ RegressionSignature* signature) {
+ Signatures signatures;
+ TF_RETURN_IF_ERROR(GetSignatures(meta_graph_def, &signatures));
+ if (!signatures.has_default_signature()) {
+ return errors::FailedPrecondition(strings::StrCat(
+ "Expected a default signature in: ", signatures.DebugString()));
+ }
+ if (!signatures.default_signature().has_regression_signature()) {
+ return errors::FailedPrecondition(
+ strings::StrCat("Expected a regression signature in: ",
+ signatures.default_signature().DebugString()));
+ }
+ *signature = signatures.default_signature().regression_signature();
+ return Status::OK();
+}
+
+Status RunRegression(const RegressionSignature& signature,
+ const Tensor& regression_input, Session* session,
+ Tensor* regression_output) {
+ std::vector<string> output_tensor_names;
+ if (regression_output) {
+ output_tensor_names.push_back(signature.output().tensor_name());
+ }
+ // Run the graph with our inputs and outputs.
+ std::vector<Tensor> outputs;
+ const Status run_status =
+ session->Run({{signature.input().tensor_name(), regression_input}},
+ output_tensor_names, {}, &outputs);
+ if (!run_status.ok()) {
+ return run_status;
+ }
+ // Ensure the regression score output is shaped how we expect.
+ // There should be one float Tensor of shape,
+ // [batch_size, num_recommendations].
+ if (outputs.size() != output_tensor_names.size()) {
+ return errors::Internal(
+ strings::StrCat("Expected ", output_tensor_names.size(),
+ " output tensor(s). Got: ", outputs.size()));
+ }
+ if (regression_output) {
+ *regression_output = outputs[0];
+ TF_RETURN_IF_ERROR(BatchSizesMatch(regression_input, *regression_output));
+ }
+ return Status::OK();
+}
+
+Status GetGenericSignature(const string& name,
+ const tensorflow::MetaGraphDef& meta_graph_def,
+ GenericSignature* signature) {
+ Signatures signatures;
+ TF_RETURN_IF_ERROR(GetSignatures(meta_graph_def, &signatures));
+ const auto& it = signatures.named_signatures().find(name);
+ if (it == signatures.named_signatures().end()) {
+ return errors::InvalidArgument(
+ strings::StrCat("Missing generic signature named \"", name, "\" in ",
+ signatures.DebugString()));
+ }
+ if (!it->second.has_generic_signature()) {
+ return errors::InvalidArgument(strings::StrCat(
+ "Expected a generic signature: ", it->second.DebugString()));
+ }
+ *signature = it->second.generic_signature();
+ return Status::OK();
+}
+
+Status GetDefaultSignature(const tensorflow::MetaGraphDef& meta_graph_def,
+ Signature* default_signature) {
+ Signatures signatures;
+ TF_RETURN_IF_ERROR(GetSignatures(meta_graph_def, &signatures));
+ *default_signature = signatures.default_signature();
+ return Status::OK();
+}
+
+Status GetNamedSignature(const string& name,
+ const tensorflow::MetaGraphDef& meta_graph_def,
+ Signature* signature) {
+ Signatures signatures;
+ TF_RETURN_IF_ERROR(GetSignatures(meta_graph_def, &signatures));
+ const auto& it = signatures.named_signatures().find(name);
+ if (it == signatures.named_signatures().end()) {
+ return errors::NotFound(strings::StrCat("Missing signature named \"", name,
+ "\" in: ",
+ signatures.DebugString()));
+ }
+ *signature = it->second;
+ return Status::OK();
+}
+
+Status BindGenericInputs(const GenericSignature& signature,
+ const std::vector<std::pair<string, Tensor>>& inputs,
+ std::vector<std::pair<string, Tensor>>* bound_inputs) {
+ const protobuf::Map<string, contrib::TensorBinding>& bindings =
+ signature.map();
+
+ for (const auto& entry : inputs) {
+ const auto mapped = bindings.find(entry.first);
+ if (mapped == bindings.end()) {
+ return errors::NotFound(
+ strings::StrCat("Could not find generic binding for: ", entry.first));
+ }
+ bound_inputs->push_back({mapped->second.tensor_name(), entry.second});
+ }
+ return Status::OK();
+}
+
+Status BindGenericNames(const GenericSignature& signature,
+ const std::vector<string>& input_names,
+ std::vector<string>* bound_names) {
+ const protobuf::Map<string, contrib::TensorBinding>& bindings =
+ signature.map();
+
+ for (const string& entry : input_names) {
+ const auto mapped = bindings.find(entry);
+ if (mapped == bindings.end()) {
+ return errors::NotFound(
+ strings::StrCat("Could not find generic binding for: ", entry));
+ }
+ bound_names->push_back(mapped->second.tensor_name());
+ }
+ return Status::OK();
+}
+
+} // namespace contrib
+} // namespace tensorflow
diff --git a/tensorflow/contrib/session_bundle/signature.h b/tensorflow/contrib/session_bundle/signature.h
new file mode 100644
index 0000000000..98d6e56601
--- /dev/null
+++ b/tensorflow/contrib/session_bundle/signature.h
@@ -0,0 +1,123 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Helpers for working with TensorFlow exports and their signatures.
+
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_SESSION_BUNDLE_SIGNATURE_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_SESSION_BUNDLE_SIGNATURE_H_
+
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "tensorflow/contrib/session_bundle/manifest.pb.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/types.h"
+#include "tensorflow/core/protobuf/meta_graph.pb.h"
+#include "tensorflow/core/protobuf/saver.pb.h"
+#include "tensorflow/core/public/session.h"
+
+namespace tensorflow {
+namespace contrib {
+
+const char kSignaturesKey[] = "serving_signatures";
+
+// Get Signatures from a MetaGraphDef.
+Status GetSignatures(const tensorflow::MetaGraphDef& meta_graph_def,
+ Signatures* signatures);
+
+// (Re)set Signatures in a MetaGraphDef.
+Status SetSignatures(const Signatures& signatures,
+ tensorflow::MetaGraphDef* meta_graph_def);
+
+// Gets a ClassificationSignature from a MetaGraphDef's default signature.
+// Returns an error if the default signature is not a ClassificationSignature,
+// or does not exist.
+Status GetClassificationSignature(
+ const tensorflow::MetaGraphDef& meta_graph_def,
+ ClassificationSignature* signature);
+
+// Gets a named ClassificationSignature from a MetaGraphDef.
+// Returns an error if a ClassificationSignature with the given name does
+// not exist.
+Status GetNamedClassificationSignature(
+ const string& name, const tensorflow::MetaGraphDef& meta_graph_def,
+ ClassificationSignature* signature);
+
+// Gets a RegressionSignature from a MetaGraphDef's default signature.
+// Returns an error if the default signature is not a RegressionSignature,
+// or does not exist.
+Status GetRegressionSignature(const tensorflow::MetaGraphDef& meta_graph_def,
+ RegressionSignature* signature);
+
+// Runs a classification using the provided signature and initialized Session.
+// input: input batch of items to classify
+// classes: output batch of classes; may be null if not needed
+// scores: output batch of scores; may be null if not needed
+// Validates sizes of the inputs and outputs are consistent (e.g., input
+// batch size equals output batch sizes).
+// Does not do any type validation.
+Status RunClassification(const ClassificationSignature& signature,
+ const Tensor& input, Session* session, Tensor* classes,
+ Tensor* scores);
+
+// Runs regression using the provided signature and initialized Session.
+// input: input batch of items to run the regression model against
+// output: output targets
+// Validates sizes of the inputs and outputs are consistent (e.g., input
+// batch size equals output batch sizes).
+// Does not do any type validation.
+Status RunRegression(const RegressionSignature& signature, const Tensor& input,
+ Session* session, Tensor* output);
+
+// Gets the named GenericSignature from a MetaGraphDef.
+// Returns an error if a GenericSignature with the given name does not exist.
+Status GetGenericSignature(const string& name,
+ const tensorflow::MetaGraphDef& meta_graph_def,
+ GenericSignature* signature);
+
+// Gets the default signature from a MetaGraphDef.
+Status GetDefaultSignature(const tensorflow::MetaGraphDef& meta_graph_def,
+ Signature* default_signature);
+
+// Gets a named Signature from a MetaGraphDef.
+// Returns an error if a Signature with the given name does not exist.
+Status GetNamedSignature(const string& name,
+ const tensorflow::MetaGraphDef& meta_graph_def,
+ Signature* default_signature);
+
+// Binds TensorFlow inputs specified by the caller using the logical names
+// specified at Graph export time, to the actual Graph names.
+// Returns an error if any of the inputs do not have a binding in the export's
+// MetaGraphDef.
+Status BindGenericInputs(const GenericSignature& signature,
+ const std::vector<std::pair<string, Tensor>>& inputs,
+ std::vector<std::pair<string, Tensor>>* bound_inputs);
+
+// Binds the input names specified by the caller using the logical names
+// specified at Graph export time, to the actual Graph names. This is useful
+// for binding names of both the TensorFlow output tensors and target nodes,
+// with the latter (target nodes) being optional and rarely used (if ever) at
+// serving time.
+// Returns an error if any of the input names do not have a binding in the
+// export's MetaGraphDef.
+Status BindGenericNames(const GenericSignature& signature,
+ const std::vector<string>& input_names,
+ std::vector<string>* bound_names);
+} // namespace contrib
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_SESSION_BUNDLE_SIGNATURE_H_
diff --git a/tensorflow/contrib/session_bundle/signature_test.cc b/tensorflow/contrib/session_bundle/signature_test.cc
new file mode 100644
index 0000000000..abeaf23c0e
--- /dev/null
+++ b/tensorflow/contrib/session_bundle/signature_test.cc
@@ -0,0 +1,602 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/session_bundle/signature.h"
+
+#include <memory>
+
+#include "google/protobuf/any.pb.h"
+#include "tensorflow/contrib/session_bundle/manifest.pb.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_testutil.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/platform/test.h"
+#include "tensorflow/core/public/session.h"
+
+namespace tensorflow {
+namespace contrib {
+namespace {
+
+static bool HasSubstr(const string& base, const string& substr) {
+ bool ok = StringPiece(base).contains(substr);
+ EXPECT_TRUE(ok) << base << ", expected substring " << substr;
+ return ok;
+}
+
+TEST(GetClassificationSignature, Basic) {
+ tensorflow::MetaGraphDef meta_graph_def;
+ Signatures signatures;
+ ClassificationSignature* input_signature =
+ signatures.mutable_default_signature()
+ ->mutable_classification_signature();
+ input_signature->mutable_input()->set_tensor_name("flow");
+ (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
+ .mutable_any_list()
+ ->add_value()
+ ->PackFrom(signatures);
+
+ ClassificationSignature signature;
+ const Status status = GetClassificationSignature(meta_graph_def, &signature);
+ TF_ASSERT_OK(status);
+ EXPECT_EQ(signature.input().tensor_name(), "flow");
+}
+
+TEST(GetClassificationSignature, MissingSignature) {
+ tensorflow::MetaGraphDef meta_graph_def;
+ Signatures signatures;
+ signatures.mutable_default_signature();
+ (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
+ .mutable_any_list()
+ ->add_value()
+ ->PackFrom(signatures);
+
+ ClassificationSignature signature;
+ const Status status = GetClassificationSignature(meta_graph_def, &signature);
+ ASSERT_FALSE(status.ok());
+ EXPECT_TRUE(StringPiece(status.error_message())
+ .contains("Expected a classification signature"))
+ << status.error_message();
+}
+
+TEST(GetClassificationSignature, WrongSignatureType) {
+ tensorflow::MetaGraphDef meta_graph_def;
+ Signatures signatures;
+ signatures.mutable_default_signature()->mutable_regression_signature();
+ (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
+ .mutable_any_list()
+ ->add_value()
+ ->PackFrom(signatures);
+
+ ClassificationSignature signature;
+ const Status status = GetClassificationSignature(meta_graph_def, &signature);
+ ASSERT_FALSE(status.ok());
+ EXPECT_TRUE(StringPiece(status.error_message())
+ .contains("Expected a classification signature"))
+ << status.error_message();
+}
+
+TEST(GetNamedClassificationSignature, Basic) {
+ tensorflow::MetaGraphDef meta_graph_def;
+ Signatures signatures;
+ ClassificationSignature* input_signature =
+ (*signatures.mutable_named_signatures())["foo"]
+ .mutable_classification_signature();
+ input_signature->mutable_input()->set_tensor_name("flow");
+ (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
+ .mutable_any_list()
+ ->add_value()
+ ->PackFrom(signatures);
+
+ ClassificationSignature signature;
+ const Status status =
+ GetNamedClassificationSignature("foo", meta_graph_def, &signature);
+ TF_ASSERT_OK(status);
+ EXPECT_EQ(signature.input().tensor_name(), "flow");
+}
+
+TEST(GetNamedClassificationSignature, MissingSignature) {
+ tensorflow::MetaGraphDef meta_graph_def;
+ Signatures signatures;
+ (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
+ .mutable_any_list()
+ ->add_value()
+ ->PackFrom(signatures);
+
+ ClassificationSignature signature;
+ const Status status =
+ GetNamedClassificationSignature("foo", meta_graph_def, &signature);
+ ASSERT_FALSE(status.ok());
+ EXPECT_TRUE(StringPiece(status.error_message())
+ .contains("Missing signature named \"foo\""))
+ << status.error_message();
+}
+
+TEST(GetNamedClassificationSignature, WrongSignatureType) {
+ tensorflow::MetaGraphDef meta_graph_def;
+ Signatures signatures;
+ (*signatures.mutable_named_signatures())["foo"]
+ .mutable_regression_signature();
+ (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
+ .mutable_any_list()
+ ->add_value()
+ ->PackFrom(signatures);
+
+ ClassificationSignature signature;
+ const Status status =
+ GetNamedClassificationSignature("foo", meta_graph_def, &signature);
+ ASSERT_FALSE(status.ok());
+ EXPECT_TRUE(
+ StringPiece(status.error_message())
+ .contains("Expected a classification signature for name \"foo\""))
+ << status.error_message();
+}
+
+TEST(GetRegressionSignature, Basic) {
+ tensorflow::MetaGraphDef meta_graph_def;
+ Signatures signatures;
+ RegressionSignature* input_signature =
+ signatures.mutable_default_signature()->mutable_regression_signature();
+ input_signature->mutable_input()->set_tensor_name("flow");
+ (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
+ .mutable_any_list()
+ ->add_value()
+ ->PackFrom(signatures);
+
+ RegressionSignature signature;
+ const Status status = GetRegressionSignature(meta_graph_def, &signature);
+ TF_ASSERT_OK(status);
+ EXPECT_EQ(signature.input().tensor_name(), "flow");
+}
+
+TEST(GetRegressionSignature, MissingSignature) {
+ tensorflow::MetaGraphDef meta_graph_def;
+ Signatures signatures;
+ signatures.mutable_default_signature();
+ (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
+ .mutable_any_list()
+ ->add_value()
+ ->PackFrom(signatures);
+
+ RegressionSignature signature;
+ const Status status = GetRegressionSignature(meta_graph_def, &signature);
+ ASSERT_FALSE(status.ok());
+ EXPECT_TRUE(StringPiece(status.error_message())
+ .contains("Expected a regression signature"))
+ << status.error_message();
+}
+
+TEST(GetRegressionSignature, WrongSignatureType) {
+ tensorflow::MetaGraphDef meta_graph_def;
+ Signatures signatures;
+ signatures.mutable_default_signature()->mutable_classification_signature();
+ (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
+ .mutable_any_list()
+ ->add_value()
+ ->PackFrom(signatures);
+
+ RegressionSignature signature;
+ const Status status = GetRegressionSignature(meta_graph_def, &signature);
+ ASSERT_FALSE(status.ok());
+ EXPECT_TRUE(StringPiece(status.error_message())
+ .contains("Expected a regression signature"))
+ << status.error_message();
+}
+
+TEST(GetNamedSignature, Basic) {
+ tensorflow::MetaGraphDef meta_graph_def;
+ Signatures signatures;
+ ClassificationSignature* input_signature =
+ (*signatures.mutable_named_signatures())["foo"]
+ .mutable_classification_signature();
+ input_signature->mutable_input()->set_tensor_name("flow");
+ (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
+ .mutable_any_list()
+ ->add_value()
+ ->PackFrom(signatures);
+
+ Signature signature;
+ const Status status = GetNamedSignature("foo", meta_graph_def, &signature);
+ TF_ASSERT_OK(status);
+ EXPECT_EQ(signature.classification_signature().input().tensor_name(), "flow");
+}
+
+TEST(GetNamedSignature, MissingSignature) {
+ tensorflow::MetaGraphDef meta_graph_def;
+ Signatures signatures;
+ (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
+ .mutable_any_list()
+ ->add_value()
+ ->PackFrom(signatures);
+
+ Signature signature;
+ const Status status = GetNamedSignature("foo", meta_graph_def, &signature);
+ ASSERT_FALSE(status.ok());
+ EXPECT_TRUE(StringPiece(status.error_message())
+ .contains("Missing signature named \"foo\""))
+ << status.error_message();
+}
+
+// MockSession used to test input and output interactions with a
+// tensorflow::Session.
+struct MockSession : public tensorflow::Session {
+ ~MockSession() override = default;
+
+ Status Create(const GraphDef& graph) override {
+ return errors::Unimplemented("Not implemented for mock.");
+ }
+
+ Status Extend(const GraphDef& graph) override {
+ return errors::Unimplemented("Not implemented for mock.");
+ }
+
+ // Sets the input and output arguments.
+ Status Run(const std::vector<std::pair<string, Tensor>>& inputs_arg,
+ const std::vector<string>& output_tensor_names_arg,
+ const std::vector<string>& target_node_names_arg,
+ std::vector<Tensor>* outputs_arg) override {
+ inputs = inputs_arg;
+ output_tensor_names = output_tensor_names_arg;
+ target_node_names = target_node_names_arg;
+ *outputs_arg = outputs;
+ return status;
+ }
+
+ Status Close() override {
+ return errors::Unimplemented("Not implemented for mock.");
+ }
+
+ // Arguments stored on a Run call.
+ std::vector<std::pair<string, Tensor>> inputs;
+ std::vector<string> output_tensor_names;
+ std::vector<string> target_node_names;
+
+ // Output argument set by Run; should be set before calling.
+ std::vector<Tensor> outputs;
+
+ // Return value for Run; should be set before calling.
+ Status status;
+};
+
+constexpr char kInputName[] = "in:0";
+constexpr char kClassesName[] = "classes:0";
+constexpr char kScoresName[] = "scores:0";
+
+class RunClassificationTest : public ::testing::Test {
+ public:
+ void SetUp() override {
+ signature_.mutable_input()->set_tensor_name(kInputName);
+ signature_.mutable_classes()->set_tensor_name(kClassesName);
+ signature_.mutable_scores()->set_tensor_name(kScoresName);
+ }
+
+ protected:
+ ClassificationSignature signature_;
+ Tensor input_tensor_;
+ Tensor classes_tensor_;
+ Tensor scores_tensor_;
+ MockSession session_;
+};
+
+TEST_F(RunClassificationTest, Basic) {
+ input_tensor_ = test::AsTensor<int>({99});
+ session_.outputs = {test::AsTensor<int>({3}), test::AsTensor<int>({2})};
+ const Status status = RunClassification(signature_, input_tensor_, &session_,
+ &classes_tensor_, &scores_tensor_);
+
+ // Validate outputs.
+ TF_ASSERT_OK(status);
+ test::ExpectTensorEqual<int>(test::AsTensor<int>({3}), classes_tensor_);
+ test::ExpectTensorEqual<int>(test::AsTensor<int>({2}), scores_tensor_);
+
+ // Validate inputs.
+ ASSERT_EQ(1, session_.inputs.size());
+ EXPECT_EQ(kInputName, session_.inputs[0].first);
+ test::ExpectTensorEqual<int>(test::AsTensor<int>({99}),
+ session_.inputs[0].second);
+
+ ASSERT_EQ(2, session_.output_tensor_names.size());
+ EXPECT_EQ(kClassesName, session_.output_tensor_names[0]);
+ EXPECT_EQ(kScoresName, session_.output_tensor_names[1]);
+}
+
+TEST_F(RunClassificationTest, ClassesOnly) {
+ input_tensor_ = test::AsTensor<int>({99});
+ session_.outputs = {test::AsTensor<int>({3})};
+ const Status status = RunClassification(signature_, input_tensor_, &session_,
+ &classes_tensor_, nullptr);
+
+ // Validate outputs.
+ TF_ASSERT_OK(status);
+ test::ExpectTensorEqual<int>(test::AsTensor<int>({3}), classes_tensor_);
+
+ // Validate inputs.
+ ASSERT_EQ(1, session_.inputs.size());
+ EXPECT_EQ(kInputName, session_.inputs[0].first);
+ test::ExpectTensorEqual<int>(test::AsTensor<int>({99}),
+ session_.inputs[0].second);
+
+ ASSERT_EQ(1, session_.output_tensor_names.size());
+ EXPECT_EQ(kClassesName, session_.output_tensor_names[0]);
+}
+
+TEST_F(RunClassificationTest, ScoresOnly) {
+ input_tensor_ = test::AsTensor<int>({99});
+ session_.outputs = {test::AsTensor<int>({2})};
+ const Status status = RunClassification(signature_, input_tensor_, &session_,
+ nullptr, &scores_tensor_);
+
+ // Validate outputs.
+ TF_ASSERT_OK(status);
+ test::ExpectTensorEqual<int>(test::AsTensor<int>({2}), scores_tensor_);
+
+ // Validate inputs.
+ ASSERT_EQ(1, session_.inputs.size());
+ EXPECT_EQ(kInputName, session_.inputs[0].first);
+ test::ExpectTensorEqual<int>(test::AsTensor<int>({99}),
+ session_.inputs[0].second);
+
+ ASSERT_EQ(1, session_.output_tensor_names.size());
+ EXPECT_EQ(kScoresName, session_.output_tensor_names[0]);
+}
+
+TEST(RunClassification, RunNotOk) {
+ ClassificationSignature signature;
+ signature.mutable_input()->set_tensor_name("in:0");
+ signature.mutable_classes()->set_tensor_name("classes:0");
+ Tensor input_tensor = test::AsTensor<int>({99});
+ MockSession session;
+ session.status = errors::DataLoss("Data is gone");
+ Tensor classes_tensor;
+ const Status status = RunClassification(signature, input_tensor, &session,
+ &classes_tensor, nullptr);
+ ASSERT_FALSE(status.ok());
+ EXPECT_TRUE(StringPiece(status.error_message()).contains("Data is gone"))
+ << status.error_message();
+}
+
+TEST(RunClassification, TooManyOutputs) {
+ ClassificationSignature signature;
+ signature.mutable_input()->set_tensor_name("in:0");
+ signature.mutable_classes()->set_tensor_name("classes:0");
+ Tensor input_tensor = test::AsTensor<int>({99});
+ MockSession session;
+ session.outputs = {test::AsTensor<int>({3}), test::AsTensor<int>({4})};
+
+ Tensor classes_tensor;
+ const Status status = RunClassification(signature, input_tensor, &session,
+ &classes_tensor, nullptr);
+ ASSERT_FALSE(status.ok());
+ EXPECT_TRUE(StringPiece(status.error_message()).contains("Expected 1 output"))
+ << status.error_message();
+}
+
+TEST(RunClassification, WrongBatchOutputs) {
+ ClassificationSignature signature;
+ signature.mutable_input()->set_tensor_name("in:0");
+ signature.mutable_classes()->set_tensor_name("classes:0");
+ Tensor input_tensor = test::AsTensor<int>({99, 100});
+ MockSession session;
+ session.outputs = {test::AsTensor<int>({3})};
+
+ Tensor classes_tensor;
+ const Status status = RunClassification(signature, input_tensor, &session,
+ &classes_tensor, nullptr);
+ ASSERT_FALSE(status.ok());
+ EXPECT_TRUE(StringPiece(status.error_message())
+ .contains("Input batch size did not match output batch size"))
+ << status.error_message();
+}
+
+constexpr char kRegressionsName[] = "regressions:0";
+
+class RunRegressionTest : public ::testing::Test {
+ public:
+ void SetUp() override {
+ signature_.mutable_input()->set_tensor_name(kInputName);
+ signature_.mutable_output()->set_tensor_name(kRegressionsName);
+ }
+
+ protected:
+ RegressionSignature signature_;
+ Tensor input_tensor_;
+ Tensor output_tensor_;
+ MockSession session_;
+};
+
+TEST_F(RunRegressionTest, Basic) {
+ input_tensor_ = test::AsTensor<int>({99, 100});
+ session_.outputs = {test::AsTensor<float>({1, 2})};
+ const Status status =
+ RunRegression(signature_, input_tensor_, &session_, &output_tensor_);
+
+ // Validate outputs.
+ TF_ASSERT_OK(status);
+ test::ExpectTensorEqual<float>(test::AsTensor<float>({1, 2}), output_tensor_);
+
+ // Validate inputs.
+ ASSERT_EQ(1, session_.inputs.size());
+ EXPECT_EQ(kInputName, session_.inputs[0].first);
+ test::ExpectTensorEqual<int>(test::AsTensor<int>({99, 100}),
+ session_.inputs[0].second);
+
+ ASSERT_EQ(1, session_.output_tensor_names.size());
+ EXPECT_EQ(kRegressionsName, session_.output_tensor_names[0]);
+}
+
+TEST_F(RunRegressionTest, RunNotOk) {
+ input_tensor_ = test::AsTensor<int>({99});
+ session_.status = errors::DataLoss("Data is gone");
+ const Status status =
+ RunRegression(signature_, input_tensor_, &session_, &output_tensor_);
+ ASSERT_FALSE(status.ok());
+ EXPECT_TRUE(StringPiece(status.error_message()).contains("Data is gone"))
+ << status.error_message();
+}
+
+TEST_F(RunRegressionTest, MismatchedSizeForBatchInputAndOutput) {
+ input_tensor_ = test::AsTensor<int>({99, 100});
+ session_.outputs = {test::AsTensor<float>({3})};
+
+ const Status status =
+ RunRegression(signature_, input_tensor_, &session_, &output_tensor_);
+ ASSERT_FALSE(status.ok());
+ EXPECT_TRUE(StringPiece(status.error_message())
+ .contains("Input batch size did not match output batch size"))
+ << status.error_message();
+}
+
+TEST(SetAndGetSignatures, RoundTrip) {
+ tensorflow::MetaGraphDef meta_graph_def;
+ Signatures signatures;
+ signatures.mutable_default_signature()
+ ->mutable_classification_signature()
+ ->mutable_input()
+ ->set_tensor_name("in:0");
+ TF_ASSERT_OK(SetSignatures(signatures, &meta_graph_def));
+ Signatures read_signatures;
+ TF_ASSERT_OK(GetSignatures(meta_graph_def, &read_signatures));
+
+ EXPECT_EQ("in:0", read_signatures.default_signature()
+ .classification_signature()
+ .input()
+ .tensor_name());
+}
+
+// GenericSignature test fixture that contains a signature initialized with two
+// bound Tensors.
+class GenericSignatureTest : public ::testing::Test {
+ protected:
+ GenericSignatureTest() {
+ TensorBinding binding;
+ binding.set_tensor_name("graph_A");
+ signature_.mutable_map()->insert({"logical_A", binding});
+
+ binding.set_tensor_name("graph_B");
+ signature_.mutable_map()->insert({"logical_B", binding});
+ }
+
+ // GenericSignature that contains two bound Tensors.
+ GenericSignature signature_;
+};
+
+// GenericSignature tests.
+
+TEST_F(GenericSignatureTest, GetGenericSignatureBasic) {
+ Signature expected_signature;
+ expected_signature.mutable_generic_signature()->MergeFrom(signature_);
+
+ tensorflow::MetaGraphDef meta_graph_def;
+ Signatures signatures;
+ signatures.mutable_named_signatures()->insert(
+ {"generic_bindings", expected_signature});
+ (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
+ .mutable_any_list()
+ ->add_value()
+ ->PackFrom(signatures);
+
+ GenericSignature actual_signature;
+ TF_ASSERT_OK(GetGenericSignature("generic_bindings", meta_graph_def,
+ &actual_signature));
+ ASSERT_EQ("graph_A", actual_signature.map().at("logical_A").tensor_name());
+ ASSERT_EQ("graph_B", actual_signature.map().at("logical_B").tensor_name());
+}
+
+TEST(GetGenericSignature, MissingSignature) {
+ tensorflow::MetaGraphDef meta_graph_def;
+ Signatures signatures;
+ (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
+ .mutable_any_list()
+ ->add_value()
+ ->PackFrom(signatures);
+
+ GenericSignature signature;
+ const Status status =
+ GetGenericSignature("generic_bindings", meta_graph_def, &signature);
+ ASSERT_FALSE(status.ok());
+ EXPECT_TRUE(HasSubstr(status.error_message(),
+ "Missing generic signature named \"generic_bindings\""))
+ << status.error_message();
+}
+
+TEST(GetGenericSignature, WrongSignatureType) {
+ tensorflow::MetaGraphDef meta_graph_def;
+ Signatures signatures;
+ (*signatures.mutable_named_signatures())["generic_bindings"]
+ .mutable_regression_signature();
+ (*meta_graph_def.mutable_collection_def())[kSignaturesKey]
+ .mutable_any_list()
+ ->add_value()
+ ->PackFrom(signatures);
+
+ GenericSignature signature;
+ const Status status =
+ GetGenericSignature("generic_bindings", meta_graph_def, &signature);
+ ASSERT_FALSE(status.ok());
+ EXPECT_TRUE(StringPiece(status.error_message())
+ .contains("Expected a generic signature:"))
+ << status.error_message();
+}
+
+// BindGeneric Tests.
+
+TEST_F(GenericSignatureTest, BindGenericInputsBasic) {
+ const std::vector<std::pair<string, Tensor>> inputs = {
+ {"logical_A", test::AsTensor<float>({-1.0})},
+ {"logical_B", test::AsTensor<float>({-2.0})}};
+
+ std::vector<std::pair<string, Tensor>> bound_inputs;
+ TF_ASSERT_OK(BindGenericInputs(signature_, inputs, &bound_inputs));
+
+ EXPECT_EQ("graph_A", bound_inputs[0].first);
+ EXPECT_EQ("graph_B", bound_inputs[1].first);
+ test::ExpectTensorEqual<float>(test::AsTensor<float>({-1.0}),
+ bound_inputs[0].second);
+ test::ExpectTensorEqual<float>(test::AsTensor<float>({-2.0}),
+ bound_inputs[1].second);
+}
+
+TEST_F(GenericSignatureTest, BindGenericInputsMissingBinding) {
+ const std::vector<std::pair<string, Tensor>> inputs = {
+ {"logical_A", test::AsTensor<float>({-42.0})},
+ {"logical_MISSING", test::AsTensor<float>({-43.0})}};
+
+ std::vector<std::pair<string, Tensor>> bound_inputs;
+ const Status status = BindGenericInputs(signature_, inputs, &bound_inputs);
+ ASSERT_FALSE(status.ok());
+}
+
+TEST_F(GenericSignatureTest, BindGenericNamesBasic) {
+ const std::vector<string> input_names = {"logical_B", "logical_A"};
+ std::vector<string> bound_names;
+ TF_ASSERT_OK(BindGenericNames(signature_, input_names, &bound_names));
+
+ EXPECT_EQ("graph_B", bound_names[0]);
+ EXPECT_EQ("graph_A", bound_names[1]);
+}
+
+TEST_F(GenericSignatureTest, BindGenericNamesMissingBinding) {
+ const std::vector<string> input_names = {"logical_B", "logical_MISSING"};
+ std::vector<string> bound_names;
+ const Status status = BindGenericNames(signature_, input_names, &bound_names);
+ ASSERT_FALSE(status.ok());
+}
+
+} // namespace
+} // namespace contrib
+} // namespace tensorflow
diff --git a/tensorflow/contrib/session_bundle/test_util.cc b/tensorflow/contrib/session_bundle/test_util.cc
new file mode 100644
index 0000000000..90870a69a0
--- /dev/null
+++ b/tensorflow/contrib/session_bundle/test_util.cc
@@ -0,0 +1,35 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/session_bundle/test_util.h"
+
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace contrib {
+namespace test_util {
+
+string TestSrcDirPath(const string& relative_path) {
+ const string base_path = tensorflow::testing::TensorFlowSrcRoot();
+ const string contrib_path = tensorflow::io::JoinPath(
+ tensorflow::testing::TensorFlowSrcRoot(), "/contrib");
+ return tensorflow::io::JoinPath(contrib_path, relative_path);
+}
+
+} // namespace test_util
+} // namespace contrib
+} // namespace tensorflow
diff --git a/tensorflow/contrib/session_bundle/test_util.h b/tensorflow/contrib/session_bundle/test_util.h
new file mode 100644
index 0000000000..87ddf28d60
--- /dev/null
+++ b/tensorflow/contrib/session_bundle/test_util.h
@@ -0,0 +1,38 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_SESSION_BUNDLE_TEST_UTIL_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_SESSION_BUNDLE_TEST_UTIL_H_
+
+#include <string>
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace contrib {
+namespace test_util {
+
+// Creates an absolute test srcdir path to the linked in runfiles given a path
+// relative to third_party/tensorflow/contrib/.
+// e.g. relative path = "session_bundle/example".
+string TestSrcDirPath(const string& relative_path);
+
+} // namespace test_util
+} // namespace contrib
+} // namespace tensorflow
+
+#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_SESSION_BUNDLE_TEST_UTIL_H_
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index a14bd88f69..a581923acf 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -30,6 +30,9 @@ sh_binary(
":other_headers",
":simple_console",
"//tensorflow:tensorflow_py",
+ "//tensorflow/contrib/session_bundle:all_files",
+ "//tensorflow/contrib/session_bundle:manifest_proto_py",
+ "//tensorflow/contrib/session_bundle/example:half_plus_two",
"//tensorflow/contrib/slim:all_files",
"//tensorflow/core:framework_headers",
"//tensorflow/examples/tutorials/mnist:package",
diff --git a/tools/bazel.rc.template b/tools/bazel.rc.template
index d3e70e7a4f..d4dddb5211 100644
--- a/tools/bazel.rc.template
+++ b/tools/bazel.rc.template
@@ -6,6 +6,12 @@ build --python$PYTHON_MAJOR_VERSION_path=$PYTHON_BINARY
build --define=use_fast_cpp_protos=true
build --define=allow_oversize_protos=true
+build --define PYTHON_BIN_PATH=$PYTHON_BINARY
+test --define PYTHON_BIN_PATH=$PYTHON_BINARY
+test --force_python=py$PYTHON_MAJOR_VERSION
+test --host_force_python=py$PYTHON_MAJOR_VERSION
+run --define PYTHON_BIN_PATH=$PYTHON_BINARY
+
build --spawn_strategy=standalone
test --spawn_strategy=standalone
run --spawn_strategy=standalone