From aaefa058f5fbb7a08834861643ab28ad00704683 Mon Sep 17 00:00:00 2001 From: Noah Fiedel Date: Fri, 3 Jun 2016 08:26:32 -0800 Subject: Phase 1 of moving TensorFlow Serving SessionBundle to tensorflow/contrib. Change: 123975418 --- tensorflow/BUILD | 2 + tensorflow/contrib/session_bundle/BUILD | 166 ++++++ tensorflow/contrib/session_bundle/README.md | 243 +++++++++ tensorflow/contrib/session_bundle/example/BUILD | 52 ++ .../session_bundle/example/export_half_plus_two.py | 115 ++++ tensorflow/contrib/session_bundle/exporter.py | 311 +++++++++++ tensorflow/contrib/session_bundle/exporter_test.py | 218 ++++++++ tensorflow/contrib/session_bundle/gc.py | 204 +++++++ tensorflow/contrib/session_bundle/gc_test.py | 115 ++++ tensorflow/contrib/session_bundle/manifest.proto | 70 +++ .../contrib/session_bundle/session_bundle.cc | 181 +++++++ tensorflow/contrib/session_bundle/session_bundle.h | 59 ++ .../contrib/session_bundle/session_bundle_test.cc | 102 ++++ tensorflow/contrib/session_bundle/signature.cc | 270 +++++++++ tensorflow/contrib/session_bundle/signature.h | 123 +++++ .../contrib/session_bundle/signature_test.cc | 602 +++++++++++++++++++++ tensorflow/contrib/session_bundle/test_util.cc | 35 ++ tensorflow/contrib/session_bundle/test_util.h | 38 ++ tensorflow/tools/pip_package/BUILD | 3 + tools/bazel.rc.template | 6 + 20 files changed, 2915 insertions(+) create mode 100644 tensorflow/contrib/session_bundle/BUILD create mode 100644 tensorflow/contrib/session_bundle/README.md create mode 100644 tensorflow/contrib/session_bundle/example/BUILD create mode 100644 tensorflow/contrib/session_bundle/example/export_half_plus_two.py create mode 100644 tensorflow/contrib/session_bundle/exporter.py create mode 100644 tensorflow/contrib/session_bundle/exporter_test.py create mode 100644 tensorflow/contrib/session_bundle/gc.py create mode 100644 tensorflow/contrib/session_bundle/gc_test.py create mode 100644 tensorflow/contrib/session_bundle/manifest.proto create mode 100644 
tensorflow/contrib/session_bundle/session_bundle.cc create mode 100644 tensorflow/contrib/session_bundle/session_bundle.h create mode 100644 tensorflow/contrib/session_bundle/session_bundle_test.cc create mode 100644 tensorflow/contrib/session_bundle/signature.cc create mode 100644 tensorflow/contrib/session_bundle/signature.h create mode 100644 tensorflow/contrib/session_bundle/signature_test.cc create mode 100644 tensorflow/contrib/session_bundle/test_util.cc create mode 100644 tensorflow/contrib/session_bundle/test_util.h diff --git a/tensorflow/BUILD b/tensorflow/BUILD index a8153c06b6..9e5920efc4 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -84,6 +84,8 @@ filegroup( "//tensorflow/contrib/quantization:all_files", "//tensorflow/contrib/quantization/kernels:all_files", "//tensorflow/contrib/quantization/tools:all_files", + "//tensorflow/contrib/session_bundle:all_files", + "//tensorflow/contrib/session_bundle/example:all_files", "//tensorflow/contrib/skflow:all_files", "//tensorflow/contrib/slim:all_files", "//tensorflow/contrib/tensor_forest:all_files", diff --git a/tensorflow/contrib/session_bundle/BUILD b/tensorflow/contrib/session_bundle/BUILD new file mode 100644 index 0000000000..19a10b0fc0 --- /dev/null +++ b/tensorflow/contrib/session_bundle/BUILD @@ -0,0 +1,166 @@ +# Description: Tensorflow Serving session bundle. 
+ +package( + default_visibility = ["//visibility:public"], + features = [ + "-layering_check", + ], +) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + "g3doc/sitemap.md", + ], + ), +) + +load("//tensorflow/core:platform/default/build_config.bzl", "tf_proto_library") + +py_library( + name = "exporter", + srcs = ["exporter.py"], + srcs_version = "PY2AND3", + deps = [ + ":gc", + ":manifest_proto_py", + "//tensorflow:tensorflow_py", + ], +) + +py_test( + name = "exporter_test", + size = "small", + srcs = [ + "exporter_test.py", + ], + srcs_version = "PY2AND3", + visibility = ["//visibility:private"], + deps = [ + ":exporter", + ":gc", + ":manifest_proto_py", + "//tensorflow:tensorflow_py", + ], +) + +py_library( + name = "gc", + srcs = ["gc.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + ], +) + +py_test( + name = "gc_test", + srcs = [ + "gc_test.py", + ], + srcs_version = "PY2AND3", + visibility = ["//visibility:private"], + deps = [ + ":gc", + "//tensorflow:tensorflow_py", + ], +) + +cc_library( + name = "session_bundle", + srcs = ["session_bundle.cc"], + hdrs = ["session_bundle.h"], + deps = [ + ":manifest_proto_cc", + ":signature", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow_opensource", + ], +) + +cc_test( + name = "session_bundle_test", + size = "small", + srcs = ["session_bundle_test.cc"], + data = [ + "//tensorflow/contrib/session_bundle/example:half_plus_two", + ], + # Link in all registered kernels. 
+ linkstatic = 1, + visibility = ["//visibility:private"], + deps = [ + ":session_bundle", + ":test_util", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + +cc_library( + name = "signature", + srcs = ["signature.cc"], + hdrs = ["signature.h"], + deps = [ + ":manifest_proto_cc", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:tensorflow_opensource", + ], +) + +cc_test( + name = "signature_test", + size = "small", + srcs = ["signature_test.cc"], + visibility = ["//visibility:private"], + deps = [ + ":manifest_proto_cc", + ":signature", + ":test_util", + "//tensorflow/core:core_cpu", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:tensorflow_opensource", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core:testlib", + ], +) + +cc_library( + name = "test_util", + testonly = 1, + srcs = ["test_util.cc"], + hdrs = ["test_util.h"], + visibility = ["//visibility:private"], + deps = [ + "//tensorflow/core", + "//tensorflow/core:lib", + "//tensorflow/core:test", + ], +) + +tf_proto_library( + name = "manifest_proto", + srcs = ["manifest.proto"], + cc_api_version = 2, + py_api_version = 2, + visibility = ["//visibility:public"], +) diff --git a/tensorflow/contrib/session_bundle/README.md b/tensorflow/contrib/session_bundle/README.md new file mode 100644 index 0000000000..d9b6c0f9a2 --- /dev/null +++ b/tensorflow/contrib/session_bundle/README.md @@ -0,0 +1,243 @@ +# TensorFlow Inference Model Format + +[TOC] + +## Overview + +This document describes the data formats and layouts for exporting [TensorFlow] +(https://www.tensorflow.org/) models for inference. 
+ +These exports have the following properties, + +* Recoverable + * given an export the graph can easily be initialized and run +* Hermetic + * an export directory is self-contained to facilitate distribution + +## Directory Structure + +~~~ +# Directory overview +00000000/ + assets/ + export.meta + export-?????-of-????? +~~~ + +* `00000000` Export version + * Format `%08d` +* `assets` Asset file directory + * Holds auxiliary files for the graph (e.g., vocabularies) +* `export.meta` MetaGraph Definition + * Binary [`tensorflow::MetaGraphDef`] + (https://github.com/tensorflow/tensorflow/tree/master/tensorflow/core/protobuf/meta_graph.proto) +* `export-?????-of-?????` + * Graph Variables + * Outputs from Python [`Saver`] + (https://github.com/tensorflow/tensorflow/tree/master/tensorflow/python/training/saver.py) + with `sharded=True`. + +## Python exporting code + +The [`Exporter`](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/session_bundle/exporter.py) +class can be used to export a model in the above format from a Tensorflow python +binary. + +## C++ initialization code + +The [`LoadSessionBundleFromPath`] +(https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/session_bundle/session_bundle.h) +function can be used to create a `tensorflow::Session` and initialize it from an +export. This function takes options and the path to the export and returns a +bundle of export data including a `tensorflow::Session` which can be run. + +## Signatures + +Graphs used for inference tasks typically have a set of inputs and outputs used +at inference time. We call this a signature. + +### Standard Signatures (standard usage) + +Graphs used for standard inference tasks have a standard set of inputs and +outputs. For example, a graph used for a regression task has an input tensor for +the data and an output tensor for the regression values. 
The signature mechanism +makes it easy to identify the relevant input and output tensors for common graph +applications. + +The [`Manifest`](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/session_bundle/manifest.proto) +contains a `Signature` message which contains the task-specific inputs and +outputs. + +~~~ +// A Signature specifies the inputs and outputs of commonly used graphs. +message Signature { + oneof type { + RegressionSignature regression_signature = 1; + ClassificationSignature classification_signature = 2; + GenericSignature generic_signature = 3; + } +}; +~~~ + +Standard signatures can be set at export time using the `Exporter` API: + +~~~python +# Run an export. +signature = exporter.classification_signature(input_tensor=input, + classes_tensor=output) +export = exporter.Exporter(saver) +export.init(sess.graph.as_graph_def(), + default_graph_signature=signature) +export.export(export_path, + global_step_tensor, + sess) +~~~ + +These can be recovered at serving time using utilities in [`signature.h`] +(https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/session_bundle/signature.h) + +~~~c++ +// Get a classification signature. +ClassificationSignature signature; +TF_CHECK_OK(GetClassificationSignature(bundle->meta_graph_def, &signature)); + +// Run the graph. +Tensor input_tensor = GetInputTensor(); +Tensor classes_tensor; +Tensor scores_tensor; +TF_CHECK_OK(RunClassification(signature, input_tensor, session, + &classes_tensor, &scores_tensor)); +~~~ + +### Generic Signatures (custom or advanced usage) + +Generic Signatures enable fully custom usage of the `tensorflow::Session` API. +They are recommended for when the standard Signatures do not satisfy a +particular use-case. A general example of when to use these is for a model +taking a single input and generating multiple outputs performing different +inferences. + +~~~ +// GenericSignature specifies a map from logical name to Tensor name. 
+// Typical application of GenericSignature is to use a single GenericSignature +// that includes all of the Tensor nodes and target names that may be useful at +// serving, analysis or debugging time. The recommended name for this signature +// is "generic_bindings". +message GenericSignature { + map<string, TensorBinding> map = 1; +}; +~~~ + +Generic Signatures can be used to complement a standard signature, for example +to support debugging. Here is an example usage including both the standard +regression signature and a generic signature. + +~~~python +named_tensor_bindings = {"logical_input_A": v0, + "logical_input_B": v1} +signatures = { + "regression": exporter.regression_signature(input_tensor=v0, + output_tensor=v1), + "generic": exporter.generic_signature(named_tensor_bindings)} +export = exporter.Exporter(saver) +export.init(sess.graph.as_graph_def(), + named_graph_signatures=signatures) +export.export(export_path, + global_step_tensor, + sess) +~~~ + +Generic Signature does not differentiate between input and output tensors. It +provides full flexibility to specify the input & output tensors you need. +The benefit is preserving a mapping between names that you specify at export +time (we call these the logical names), and the actual graph node names that may +be less stable and/or auto-generated by TensorFlow. + +In [`signature.h`](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/session_bundle/signature.h), +note that the generic signature methods BindGenericInputs and BindGenericNames +are doing simple string to string mapping as a convenience. These methods map +from the names used at training time to actual names in the graph. Use the bound +results from those methods, e.g. `vector<pair<string, Tensor>>` and +`vector<string>` respectively, as inputs to `tensorflow::Session->Run()`. For +`Session->Run()`, map these into the first two parameters, `inputs` and +`output_tensor_names` respectively. The next param, `target_node_names` is +typically null at inference time. 
The last param, `outputs`, is for the results in +the same order as your `output_tensor_names`. + +## Initialization + +Some graphs may require custom initialization after the variables have been +restored. Such initialization, done through an arbitrary Op, can be added using +the `Exporter` API. If set, `LoadSessionBundleFromPath` will automatically run +the Op when restoring a `Session` following the loading of variables. + +## Assets + +In many cases we have Ops which depend on external files for initialization +(such as vocabularies). These "assets" are not stored in the graph and are +needed for both training and inference. + +In order to create hermetic exports these asset files need to be 1) copied to +each export directory and 2) read when recovering a session from an export base +directory. + +Copying assets to the export dir is handled with a callback mechanism. +The callback function receives two parameters 1) the dictionary of source files +to desired basename and 2) the export directory. The default callback uses +`gfile.Copy` to perform the copy. + +The tensors that contain the filepaths to be copied and replaced for +inference are specified by passing the collection of asset filepath tensors, +which is usually extracted from the graph by `tf.get_collection(tf.GraphKeys.ASSET_FILEPATHS)`. + +~~~python + # Run an export. + export = exporter.Exporter(save) + export.init( + sess.graph.as_graph_def(), + assets_collection=tf.get_collection(tf.GraphKeys.ASSET_FILEPATHS)) + export.export(export_path, global_step_tensor, sess) +~~~ + +Users can use their own callbacks as shown in the following example, with the +requirement to keep the basename of the original files: + +~~~python + def my_custom_copy_callback(files_to_copy, export_dir_path): + # Copy all source files (keys) in files_to_copy to export_dir_path + # using the corresponding basename (value). + ... + + + # Run an export. 
+ export = exporter.Exporter(save) + export.init( + sess.graph.as_graph_def(), + assets_collection=tf.get_collection(tf.GraphKeys.ASSET_FILEPATHS), + assets_callback=my_custom_copy_callback) + export.export(export_path, global_step_tensor, sess) +~~~ + + +`AssetFile` binds the name of a tensor in the graph to the name of a file +within the assets directory. `LoadSessionBundleFromPath` will handle the base +path and asset directory swap/concatenation such that the tensor is set with +the fully qualified filename upon return. + +# Notes on exporter usage + +The typical workflow of model exporting is: + +1. Build model graph G +2. Train variables or load trained variables from checkpoint in session S +3. [Optional] build inference graph I +4. Export G + +The Exporter should be used as follows: + +1. The Saver used in Exporter(saver) should be created under the context of G +2. Exporter.init() should be called under the context of G +3. Exporter.export() should be called using session S +4. If I is provided for Exporter.init(), an exact same Saver should be created + under I as the saver under G -- in the way that exact same Save/Restore ops + are created in both G and S diff --git a/tensorflow/contrib/session_bundle/example/BUILD b/tensorflow/contrib/session_bundle/example/BUILD new file mode 100644 index 0000000000..8fa0c0b020 --- /dev/null +++ b/tensorflow/contrib/session_bundle/example/BUILD @@ -0,0 +1,52 @@ +# Description: Tensorflow Serving session_bundle example. 
+ +package( + default_visibility = ["//tensorflow/contrib/session_bundle:__subpackages__"], +) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +# vardef("PYTHON_BIN_PATH", "/usr/bin/python") + +filegroup( + name = "all_files", + srcs = glob( + ["**/*"], + exclude = [ + "**/METADATA", + "**/OWNERS", + "g3doc/sitemap.md", + ], + ), + visibility = ["//visibility:public"], +) + +py_binary( + name = "export_half_plus_two", + srcs = [ + "export_half_plus_two.py", + ], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow:tensorflow_py", + "//tensorflow/contrib/session_bundle:exporter", + ], +) + +genrule( + name = "half_plus_two", + outs = [ + "half_plus_two/00000123/export.meta", + "half_plus_two/00000123/export-00000-of-00001", + ], + cmd = + "rm -rf /tmp/half_plus_two; " + + "$(PYTHON_BIN_PATH) $(locations :export_half_plus_two); " + + "cp -r /tmp/half_plus_two/* $(@D)/half_plus_two", + tools = [ + ":export_half_plus_two", + ], + visibility = ["//visibility:public"], +) diff --git a/tensorflow/contrib/session_bundle/example/export_half_plus_two.py b/tensorflow/contrib/session_bundle/example/export_half_plus_two.py new file mode 100644 index 0000000000..e4b1947e03 --- /dev/null +++ b/tensorflow/contrib/session_bundle/example/export_half_plus_two.py @@ -0,0 +1,115 @@ +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Exports a toy linear regression inference graph. + +Exports a TensorFlow graph to /tmp/half_plus_two/ based on the Exporter +format, go/tf-exporter. + +This graph calculates, + y = a*x + b +where a and b are variables with a=0.5 and b=2. + +Output from this program is typically used to exercise Session +loading and execution code. +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +from tensorflow.contrib.session_bundle import exporter + + +def Export(): + export_path = "/tmp/half_plus_two" + with tf.Session() as sess: + # Make model parameters a&b variables instead of constants to + # exercise the variable reloading mechanisms. + a = tf.Variable(0.5, name="a") + b = tf.Variable(2.0, name="b") + + # Calculate, y = a*x + b + # here we use a placeholder 'x' which is fed at inference time. + x = tf.placeholder(tf.float32, name="x") + y = tf.add(tf.mul(a, x), b, name="y") + + # Setup a standard Saver for our variables. + save = tf.train.Saver({"a": a, "b": b}, sharded=True) + + # asset_path contains the base directory of assets used in training (e.g. + # vocabulary files). + original_asset_path = tf.constant("/tmp/original/export/assets") + # Ops reading asset files should reference the asset_path tensor + # which stores the original asset path at training time and the + # overridden assets directory at restore time. + asset_path = tf.Variable(original_asset_path, + name="asset_path", + trainable=False, + collections=[]) + assign_asset_path = asset_path.assign(original_asset_path) + + # Use a fixed global step number. + global_step_tensor = tf.Variable(123, name="global_step") + + # Create a RegressionSignature for our input and output. + signature = exporter.regression_signature(input_tensor=x, output_tensor=y) + + # Create two filename assets and corresponding tensors. 
+ # TODO(b/26254158) Consider adding validation of file existance as well as + # hashes (e.g. sha1) for consistency. + original_filename1 = tf.constant("hello1.txt") + tf.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, original_filename1) + filename1 = tf.Variable(original_filename1, + name="filename1", + trainable=False, + collections=[]) + assign_filename1 = filename1.assign(original_filename1) + original_filename2 = tf.constant("hello2.txt") + tf.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, original_filename2) + filename2 = tf.Variable(original_filename2, + name="filename2", + trainable=False, + collections=[]) + assign_filename2 = filename2.assign(original_filename2) + + # Init op contains a group of all variables that we assign. + init_op = tf.group(assign_asset_path, assign_filename1, assign_filename2) + + # CopyAssets is used as a callback during export to copy files to the + # given export directory. + def CopyAssets(filepaths, export_path): + print("copying asset files to: %s" % export_path) + for filepath in filepaths: + print("copying asset file: %s" % filepath) + + # Run an export. + tf.initialize_all_variables().run() + export = exporter.Exporter(save) + export.init( + sess.graph.as_graph_def(), + init_op=init_op, + default_graph_signature=signature, + assets_collection=tf.get_collection(tf.GraphKeys.ASSET_FILEPATHS), + assets_callback=CopyAssets) + export.export(export_path, global_step_tensor, sess) + + +def main(_): + Export() + + +if __name__ == "__main__": + tf.app.run() diff --git a/tensorflow/contrib/session_bundle/exporter.py b/tensorflow/contrib/session_bundle/exporter.py new file mode 100644 index 0000000000..fc7458db62 --- /dev/null +++ b/tensorflow/contrib/session_bundle/exporter.py @@ -0,0 +1,311 @@ +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Export a TensorFlow model. + +See: go/tf-exporter +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import re +import six + +import tensorflow as tf +from google.protobuf.any_pb2 import Any + +from tensorflow.contrib.session_bundle import gc +from tensorflow.contrib.session_bundle import manifest_pb2 +from tensorflow.python.framework import ops +from tensorflow.python.platform import gfile +from tensorflow.python.training import training_util +from tensorflow.python.util import compat + +# See: go/tf-exporter for these constants and directory structure. +VERSION_FORMAT_SPECIFIER = "%08d" +ASSETS_DIRECTORY = "assets" +EXPORT_BASE_NAME = "export" +EXPORT_SUFFIX_NAME = "meta" +META_GRAPH_DEF_FILENAME = EXPORT_BASE_NAME + "." + EXPORT_SUFFIX_NAME +VARIABLES_FILENAME = EXPORT_BASE_NAME +VARIABLES_FILENAME_PATTERN = VARIABLES_FILENAME + "-?????-of-?????" +INIT_OP_KEY = "serving_init_op" +SIGNATURES_KEY = "serving_signatures" +ASSETS_KEY = "serving_assets" +GRAPH_KEY = "serving_graph" + + +def gfile_copy_callback(files_to_copy, export_dir_path): + """Callback to copy files using `gfile.Copy` to an export directory. + + This method is used as the default `assets_callback` in `Exporter.init` to + copy assets from the `assets_collection`. It can also be invoked directly to + copy additional supplementary files into the export directory (in which case + it is not a callback). 
+ + Args: + files_to_copy: A dictionary that maps original file paths to desired + basename in the export directory. + export_dir_path: Directory to copy the files to. + """ + tf.logging.info("Write assest into: %s using gfile_copy.", export_dir_path) + gfile.MakeDirs(export_dir_path) + for source_filepath, basename in files_to_copy.items(): + new_path = os.path.join( + compat.as_bytes(export_dir_path), compat.as_bytes(basename)) + tf.logging.info("Copying asset %s to path %s.", source_filepath, new_path) + + if gfile.Exists(new_path): + # Guard against being restarted while copying assets, and the file + # existing and being in an unknown state. + # TODO(b/28676216): Do some file checks before deleting. + tf.logging.info("Removing file %s.", new_path) + gfile.Remove(new_path) + tf.gfile.Copy(source_filepath, new_path) + + +def regression_signature(input_tensor, output_tensor): + """Creates a regression signature. + + Args: + input_tensor: Tensor specifying the input to a graph. + output_tensor: Tensor specifying the output of a graph. + + Returns: + A Signature message. + """ + signature = manifest_pb2.Signature() + signature.regression_signature.input.tensor_name = input_tensor.name + signature.regression_signature.output.tensor_name = output_tensor.name + return signature + + +def classification_signature(input_tensor, + classes_tensor=None, + scores_tensor=None): + """Creates a classification signature. + + Args: + input_tensor: Tensor specifying the input to a graph. + classes_tensor: Tensor specifying the output classes of a graph. + scores_tensor: Tensor specifying the scores of the output classes. + + Returns: + A Signature message. 
+ """ + signature = manifest_pb2.Signature() + signature.classification_signature.input.tensor_name = input_tensor.name + if classes_tensor is not None: + signature.classification_signature.classes.tensor_name = classes_tensor.name + if scores_tensor is not None: + signature.classification_signature.scores.tensor_name = scores_tensor.name + return signature + + +def generic_signature(name_tensor_map): + """Creates a generic signature of name to Tensor name. + + Args: + name_tensor_map: Map from logical name to Tensor. + + Returns: + A Signature message. + """ + signature = manifest_pb2.Signature() + for name, tensor in six.iteritems(name_tensor_map): + signature.generic_signature.map[name].tensor_name = tensor.name + return signature + + +class Exporter(object): + """Exporter helps package a TensorFlow model for serving. + + Args: + saver: Saver object. + """ + + def __init__(self, saver): + self._saver = saver + self._has_init = False + self._assets_to_copy = {} + + def init(self, + graph_def=None, + init_op=None, + clear_devices=False, + default_graph_signature=None, + named_graph_signatures=None, + assets_collection=None, + assets_callback=gfile_copy_callback): + """Initialization. + + Args: + graph_def: A GraphDef message of the graph to be used in inference. + GraphDef of default graph is used when None. + init_op: Op to be used in initialization. + clear_devices: If device info of the graph should be cleared upon export. + default_graph_signature: Default signature of the graph. + named_graph_signatures: Map of named input/output signatures of the graph. + assets_collection: A collection of constant asset filepath tensors. If set + the assets will be exported into the asset directory. + assets_callback: callback with two argument called during export with the + list of files to copy and the asset path. + Raises: + RuntimeError: if init is called more than once. + TypeError: if init_op is not an Operation or None. 
+ ValueError: if asset file path tensors are not non-empty constant string + scalar tensors. + """ + # Avoid Dangerous default value [] + if named_graph_signatures is None: + named_graph_signatures = {} + assets = [] + if assets_collection: + for asset_tensor in assets_collection: + asset_filepath = self._file_path_value(asset_tensor) + if not asset_filepath: + raise ValueError("invalid asset filepath tensor %s" % asset_tensor) + basename = os.path.basename(asset_filepath) + assets.append((basename, asset_tensor)) + self._assets_to_copy[asset_filepath] = basename + + if self._has_init: + raise RuntimeError("init should be called only once") + self._has_init = True + + if graph_def or clear_devices: + copy = tf.GraphDef() + if graph_def: + copy.CopyFrom(graph_def) + else: + copy.CopyFrom(tf.get_default_graph().as_graph_def()) + if clear_devices: + for node in copy.node: + node.device = "" + graph_any_buf = Any() + graph_any_buf.Pack(copy) + tf.add_to_collection(GRAPH_KEY, graph_any_buf) + + if init_op: + if not isinstance(init_op, ops.Operation): + raise TypeError("init_op needs to be an Operation: %s" % init_op) + tf.add_to_collection(INIT_OP_KEY, init_op) + + signatures_proto = manifest_pb2.Signatures() + if default_graph_signature: + signatures_proto.default_signature.CopyFrom(default_graph_signature) + for signature_name, signature in six.iteritems(named_graph_signatures): + signatures_proto.named_signatures[signature_name].CopyFrom(signature) + signatures_any_buf = Any() + signatures_any_buf.Pack(signatures_proto) + tf.add_to_collection(SIGNATURES_KEY, signatures_any_buf) + + for filename, tensor in assets: + asset = manifest_pb2.AssetFile() + asset.filename = filename + asset.tensor_binding.tensor_name = tensor.name + asset_any_buf = Any() + asset_any_buf.Pack(asset) + tf.add_to_collection(ASSETS_KEY, asset_any_buf) + + self._assets_callback = assets_callback + + def export(self, + export_dir_base, + global_step_tensor, + sess=None, + exports_to_keep=None): + 
"""Exports the model. + + Args: + export_dir_base: A string path to the base export dir. + global_step_tensor: An Tensor or tensor name providing the + global step counter to append to the export directory path and set + in the manifest version. + sess: A Session to use to save the parameters. + exports_to_keep: a gc.Path filter function used to determine the set of + exports to keep. If set to None, all versions will be kept. + + Returns: + The string path to the exported directory. + + Raises: + RuntimeError: if init is not called. + RuntimeError: if the export would overwrite an existing directory. + """ + if not self._has_init: + raise RuntimeError("init must be called first") + + global_step = training_util.global_step(sess, global_step_tensor) + export_dir = os.path.join( + compat.as_bytes(export_dir_base), + compat.as_bytes(VERSION_FORMAT_SPECIFIER % global_step)) + + # Prevent overwriting on existing exports which could lead to bad/corrupt + # storage and loading of models. This is an important check that must be + # done before any output files or directories are created. + if gfile.Exists(export_dir): + raise RuntimeError("Overwriting exports can cause corruption and are " + "not allowed. Duplicate export dir: %s" % export_dir) + + # Output to a temporary directory which is atomically renamed to the final + # directory when complete. + tmp_export_dir = compat.as_text(export_dir) + "-tmp" + gfile.MakeDirs(tmp_export_dir) + + self._saver.save(sess, + os.path.join( + compat.as_text(tmp_export_dir), + compat.as_text(EXPORT_BASE_NAME)), + meta_graph_suffix=EXPORT_SUFFIX_NAME) + + # Run the asset callback. + if self._assets_callback and self._assets_to_copy: + assets_dir = os.path.join( + compat.as_bytes(tmp_export_dir), compat.as_bytes(ASSETS_DIRECTORY)) + gfile.MakeDirs(assets_dir) + self._assets_callback(self._assets_to_copy, assets_dir) + + # TODO(b/27794910): Delete *checkpoint* file before rename. 
+ gfile.Rename(tmp_export_dir, export_dir) + + if exports_to_keep: + # create a simple parser that pulls the export_version from the directory. + def parser(path): + match = re.match("^" + export_dir_base + "/(\\d{8})$", path.path) + if not match: + return None + return path._replace(export_version=int(match.group(1))) + + paths_to_delete = gc.negation(exports_to_keep) + for p in paths_to_delete(gc.get_paths(export_dir_base, parser=parser)): + gfile.DeleteRecursively(p.path) + + return export_dir + + def _file_path_value(self, path_tensor): + """Returns the filepath value stored in constant `path_tensor`.""" + if not isinstance(path_tensor, tf.Tensor): + raise TypeError("tensor is not a Tensor") + if path_tensor.op.type != "Const": + raise TypeError("Only constants tensor are supported") + if path_tensor.dtype != tf.string: + raise TypeError("File paths should be string") + str_value = path_tensor.op.get_attr("value").string_val + if len(str_value) != 1: + raise TypeError("Only scalar tensors are supported") + return str_value[0] diff --git a/tensorflow/contrib/session_bundle/exporter_test.py b/tensorflow/contrib/session_bundle/exporter_test.py new file mode 100644 index 0000000000..8b54933fc9 --- /dev/null +++ b/tensorflow/contrib/session_bundle/exporter_test.py @@ -0,0 +1,218 @@ +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""Tests for exporter.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os.path + + +import tensorflow as tf + +from tensorflow.contrib.session_bundle import exporter +from tensorflow.contrib.session_bundle import gc +from tensorflow.contrib.session_bundle import manifest_pb2 +from tensorflow.core.protobuf import config_pb2 +from tensorflow.python.platform import flags +from tensorflow.python.platform import gfile + + +FLAGS = flags.FLAGS + +GLOBAL_STEP = 222 + + +def tearDownModule(): + gfile.DeleteRecursively(tf.test.get_temp_dir()) + + +class SaveRestoreShardedTest(tf.test.TestCase): + + def doBasicsOneExportPath(self, + export_path, + clear_devices=False, + global_step=GLOBAL_STEP, + sharded=True): + # Build a graph with 2 parameter nodes on different devices. + tf.reset_default_graph() + with tf.Session( + target="", + config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess: + # v2 is an unsaved variable derived from v0 and v1. It is used to + # exercise the ability to run an init op when restoring a graph. 
+ with sess.graph.device("/cpu:0"): + v0 = tf.Variable(10, name="v0") + with sess.graph.device("/cpu:1"): + v1 = tf.Variable(20, name="v1") + v2 = tf.Variable(1, name="v2", trainable=False, collections=[]) + assign_v2 = tf.assign(v2, tf.add(v0, v1)) + init_op = tf.group(assign_v2, name="init_op") + + tf.add_to_collection("v", v0) + tf.add_to_collection("v", v1) + tf.add_to_collection("v", v2) + + global_step_tensor = tf.Variable(global_step, name="global_step") + named_tensor_bindings = {"logical_input_A": v0, "logical_input_B": v1} + signatures = { + "foo": exporter.regression_signature(input_tensor=v0, + output_tensor=v1), + "generic": exporter.generic_signature(named_tensor_bindings) + } + + asset_filepath_orig = os.path.join(tf.test.get_temp_dir(), "hello42.txt") + asset_file = tf.constant(asset_filepath_orig, name="filename42") + tf.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, asset_file) + + with gfile.FastGFile(asset_filepath_orig, "w") as f: + f.write("your data here") + assets_collection = tf.get_collection(tf.GraphKeys.ASSET_FILEPATHS) + + ignored_asset = os.path.join(tf.test.get_temp_dir(), "ignored.txt") + with gfile.FastGFile(ignored_asset, "w") as f: + f.write("additional data here") + + tf.initialize_all_variables().run() + + # Run an export. + save = tf.train.Saver({"v0": v0, + "v1": v1}, + restore_sequentially=True, + sharded=sharded) + export = exporter.Exporter(save) + export.init(sess.graph.as_graph_def(), + init_op=init_op, + clear_devices=clear_devices, + default_graph_signature=exporter.classification_signature( + input_tensor=v0), + named_graph_signatures=signatures, + assets_collection=assets_collection) + export.export(export_path, + global_step_tensor, + sess, + exports_to_keep=gc.largest_export_versions(2)) + + # Restore graph. 
+ compare_def = tf.get_default_graph().as_graph_def() + tf.reset_default_graph() + with tf.Session( + target="", + config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess: + save = tf.train.import_meta_graph( + os.path.join(export_path, exporter.VERSION_FORMAT_SPECIFIER % + global_step, exporter.META_GRAPH_DEF_FILENAME)) + self.assertIsNotNone(save) + meta_graph_def = save.export_meta_graph() + collection_def = meta_graph_def.collection_def + + # Validate custom graph_def. + graph_def_any = collection_def[exporter.GRAPH_KEY].any_list.value + self.assertEquals(len(graph_def_any), 1) + graph_def = tf.GraphDef() + graph_def_any[0].Unpack(graph_def) + if clear_devices: + for node in compare_def.node: + node.device = "" + self.assertProtoEquals(compare_def, graph_def) + + # Validate init_op. + init_ops = collection_def[exporter.INIT_OP_KEY].node_list.value + self.assertEquals(len(init_ops), 1) + self.assertEquals(init_ops[0], "init_op") + + # Validate signatures. + signatures_any = collection_def[exporter.SIGNATURES_KEY].any_list.value + self.assertEquals(len(signatures_any), 1) + signatures = manifest_pb2.Signatures() + signatures_any[0].Unpack(signatures) + default_signature = signatures.default_signature + self.assertEqual( + default_signature.classification_signature.input.tensor_name, "v0:0") + bindings = signatures.named_signatures["generic"].generic_signature.map + self.assertEquals(bindings["logical_input_A"].tensor_name, "v0:0") + self.assertEquals(bindings["logical_input_B"].tensor_name, "v1:0") + read_foo_signature = ( + signatures.named_signatures["foo"].regression_signature) + self.assertEquals(read_foo_signature.input.tensor_name, "v0:0") + self.assertEquals(read_foo_signature.output.tensor_name, "v1:0") + + # Validate the assets. 
+ assets_any = collection_def[exporter.ASSETS_KEY].any_list.value + self.assertEquals(len(assets_any), 1) + asset = manifest_pb2.AssetFile() + assets_any[0].Unpack(asset) + assets_path = os.path.join(export_path, + exporter.VERSION_FORMAT_SPECIFIER % + global_step, exporter.ASSETS_DIRECTORY, + "hello42.txt") + asset_contents = gfile.GFile(assets_path).read() + self.assertEqual(asset_contents, "your data here") + self.assertEquals("hello42.txt", asset.filename) + self.assertEquals("filename42:0", asset.tensor_binding.tensor_name) + ignored_asset_path = os.path.join(export_path, + exporter.VERSION_FORMAT_SPECIFIER % + global_step, exporter.ASSETS_DIRECTORY, + "ignored.txt") + self.assertFalse(gfile.Exists(ignored_asset_path)) + + # Validate graph restoration. + if sharded: + save.restore(sess, + os.path.join( + export_path, exporter.VERSION_FORMAT_SPECIFIER % + global_step, exporter.VARIABLES_FILENAME_PATTERN)) + else: + save.restore(sess, + os.path.join( + export_path, exporter.VERSION_FORMAT_SPECIFIER % + global_step, exporter.VARIABLES_FILENAME)) + self.assertEqual(10, tf.get_collection("v")[0].eval()) + self.assertEqual(20, tf.get_collection("v")[1].eval()) + tf.get_collection(exporter.INIT_OP_KEY)[0].run() + self.assertEqual(30, tf.get_collection("v")[2].eval()) + + def testDuplicateExportRaisesError(self): + export_path = os.path.join(tf.test.get_temp_dir(), "export_duplicates") + self.doBasicsOneExportPath(export_path) + self.assertRaises(RuntimeError, self.doBasicsOneExportPath, export_path) + + def testBasics(self): + export_path = os.path.join(tf.test.get_temp_dir(), "export") + self.doBasicsOneExportPath(export_path) + + def testBasicsNoShard(self): + export_path = os.path.join(tf.test.get_temp_dir(), "export_no_shard") + self.doBasicsOneExportPath(export_path, sharded=False) + + def testClearDevice(self): + export_path = os.path.join(tf.test.get_temp_dir(), "export_clear_device") + self.doBasicsOneExportPath(export_path, clear_devices=True) + + def 
testGC(self): + export_path = os.path.join(tf.test.get_temp_dir(), "gc") + self.doBasicsOneExportPath(export_path, global_step=100) + self.assertEquals(gfile.ListDirectory(export_path), ["00000100"]) + self.doBasicsOneExportPath(export_path, global_step=101) + self.assertEquals( + sorted(gfile.ListDirectory(export_path)), ["00000100", "00000101"]) + self.doBasicsOneExportPath(export_path, global_step=102) + self.assertEquals( + sorted(gfile.ListDirectory(export_path)), ["00000101", "00000102"]) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow/contrib/session_bundle/gc.py b/tensorflow/contrib/session_bundle/gc.py new file mode 100644 index 0000000000..ad7389d96f --- /dev/null +++ b/tensorflow/contrib/session_bundle/gc.py @@ -0,0 +1,204 @@ +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""System for specifying garbage collection (GC) of path based data. + +This framework allows for GC of data specified by path names, for example files +on disk. gc.Path objects each represent a single item stored at a path and may +be a base directory, + /tmp/exports/0/... + /tmp/exports/1/... + ... +or a fully qualified file, + /tmp/train-1.ckpt + /tmp/train-2.ckpt + ... + +A gc filter function takes and returns a list of gc.Path items. Filter +functions are responsible for selecting Path items for preservation or deletion. 
+Note that functions should always return a sorted list. + +For example, + base_dir = "/tmp" + # create the directories + for e in xrange(10): + os.mkdir("%s/%d" % (base_dir, e), 0o755) + + # create a simple parser that pulls the export_version from the directory + def parser(path): + match = re.match("^" + base_dir + "/(\\d+)$", path.path) + if not match: + return None + return path._replace(export_version=int(match.group(1))) + + path_list = gc.get_paths("/tmp", parser) # contains all ten Paths + + every_fifth = gc.mod_export_version(5) + print every_fifth(path_list) # shows ["/tmp/0", "/tmp/5"] + + largest_three = gc.largest_export_versions(3) + print largest_three(path_list) # shows ["/tmp/7", "/tmp/8", "/tmp/9"] + + both = gc.union(every_fifth, largest_three) + print both(path_list) # shows ["/tmp/0", "/tmp/5", + # "/tmp/7", "/tmp/8", "/tmp/9"] + # delete everything not in 'both' + to_delete = gc.negation(both) + for p in to_delete(path_list): + gfile.DeleteRecursively(p.path) # deletes: "/tmp/1", "/tmp/2", + # "/tmp/3", "/tmp/4", "/tmp/6", +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import heapq +import math +import os + +from six.moves import xrange # pylint: disable=redefined-builtin +from tensorflow.python.platform import gfile + +Path = collections.namedtuple('Path', 'path export_version') + + +def largest_export_versions(n): + """Creates a filter that keeps the largest n export versions. + + Args: + n: number of versions to keep. + + Returns: + A filter function that keeps the n largest paths. + """ + def keep(paths): + heap = [] + for idx, path in enumerate(paths): + if path.export_version: + heapq.heappush(heap, (path.export_version, idx)) + keepers = [paths[i] for _, i in heapq.nlargest(n, heap)] + return sorted(keepers) + + return keep + + +def one_of_every_n_export_versions(n): + """Creates a filter that keeps one of every n export versions.
+ + Args: + n: interval size. + + Returns: + A filter function that keeps exactly one path from each interval + [0, n], (n, 2n], (2n, 3n], etc... If more than one path exists in an + interval the largest is kept. + """ + def keep(paths): + keeper_map = {} # map from interval to largest path seen in that interval + for p in paths: + if p.export_version is None: + # Skip missing export_versions. + continue + # Find the interval (with a special case to map export_version = 0 to + # interval 0). + interval = math.floor( + (p.export_version - 1) / n) if p.export_version else 0 + existing = keeper_map.get(interval, None) + if (not existing) or (existing.export_version < p.export_version): + keeper_map[interval] = p + return sorted(keeper_map.values()) + + return keep + + +def mod_export_version(n): + """Creates a filter that keeps every export that is a multiple of n. + + Args: + n: step size. + + Returns: + A filter function that keeps paths where export_version % n == 0. + """ + def keep(paths): + keepers = [] + for p in paths: + if p.export_version % n == 0: + keepers.append(p) + return sorted(keepers) + return keep + + +def union(lf, rf): + """Creates a filter that keeps the union of two filters. + + Args: + lf: first filter + rf: second filter + + Returns: + A filter function that keeps every path kept by either lf or rf. + """ + def keep(paths): + l = set(lf(paths)) + r = set(rf(paths)) + return sorted(list(l|r)) + return keep + + +def negation(f): + """Negate a filter. + + Args: + f: filter function to invert + + Returns: + A filter function that returns the negation of f. + """ + def keep(paths): + l = set(paths) + r = set(f(paths)) + return sorted(list(l-r)) + return keep + + +def get_paths(base_dir, parser): + """Gets a list of Paths in a given directory. + + Args: + base_dir: directory. + parser: a function which gets the raw Path and can augment it with + information such as the export_version, or ignore the path by returning + None.
An example parser may extract the export version from a path + such as "/tmp/exports/100" and another may extract from a full file + name such as "/tmp/checkpoint-99.out". + + Returns: + A list of Paths contained in the base directory with the parsing function + applied. + By default the following fields are populated, + - Path.path + The parsing function is responsible for populating, + - Path.export_version + """ + raw_paths = gfile.ListDirectory(base_dir) + paths = [] + for r in raw_paths: + p = parser(Path(os.path.join(base_dir, r), None)) + if p: + paths.append(p) + return sorted(paths) diff --git a/tensorflow/contrib/session_bundle/gc_test.py b/tensorflow/contrib/session_bundle/gc_test.py new file mode 100644 index 0000000000..c71cbb9e5f --- /dev/null +++ b/tensorflow/contrib/session_bundle/gc_test.py @@ -0,0 +1,115 @@ +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+# ============================================================================== + +"""Tests for session_bundle.gc.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os +import re + + +from six.moves import xrange # pylint: disable=redefined-builtin + +import tensorflow as tf + +from tensorflow.contrib.session_bundle import gc +from tensorflow.python.framework import test_util +from tensorflow.python.platform import gfile + + +def tearDownModule(): + gfile.DeleteRecursively(tf.test.get_temp_dir()) + + +class GcTest(test_util.TensorFlowTestCase): + + def testLargestExportVersions(self): + paths = [gc.Path("/foo", 8), gc.Path("/foo", 9), gc.Path("/foo", 10)] + newest = gc.largest_export_versions(2) + n = newest(paths) + self.assertEquals(n, [gc.Path("/foo", 9), gc.Path("/foo", 10)]) + + def testModExportVersion(self): + paths = [gc.Path("/foo", 4), gc.Path("/foo", 5), gc.Path("/foo", 6), + gc.Path("/foo", 9)] + mod = gc.mod_export_version(2) + self.assertEquals(mod(paths), [gc.Path("/foo", 4), gc.Path("/foo", 6)]) + mod = gc.mod_export_version(3) + self.assertEquals(mod(paths), [gc.Path("/foo", 6), gc.Path("/foo", 9)]) + + def testOneOfEveryNExportVersions(self): + paths = [gc.Path("/foo", 0), gc.Path("/foo", 1), gc.Path("/foo", 3), + gc.Path("/foo", 5), gc.Path("/foo", 6), gc.Path("/foo", 7), + gc.Path("/foo", 8), gc.Path("/foo", 33)] + one_of = gc.one_of_every_n_export_versions(3) + self.assertEquals(one_of(paths), + [gc.Path("/foo", 3), gc.Path("/foo", 6), + gc.Path("/foo", 8), gc.Path("/foo", 33)]) + + def testOneOfEveryNExportVersionsZero(self): + # Zero is a special case since it gets rolled into the first interval. + # Test that here. 
+ paths = [gc.Path("/foo", 0), gc.Path("/foo", 4), gc.Path("/foo", 5)] + one_of = gc.one_of_every_n_export_versions(3) + self.assertEquals(one_of(paths), + [gc.Path("/foo", 0), gc.Path("/foo", 5)]) + + def testUnion(self): + paths = [] + for i in xrange(10): + paths.append(gc.Path("/foo", i)) + f = gc.union(gc.largest_export_versions(3), gc.mod_export_version(3)) + self.assertEquals( + f(paths), [gc.Path("/foo", 0), gc.Path("/foo", 3), + gc.Path("/foo", 6), gc.Path("/foo", 7), + gc.Path("/foo", 8), gc.Path("/foo", 9)]) + + def testNegation(self): + paths = [gc.Path("/foo", 4), gc.Path("/foo", 5), gc.Path("/foo", 6), + gc.Path("/foo", 9)] + mod = gc.negation(gc.mod_export_version(2)) + self.assertEquals( + mod(paths), [gc.Path("/foo", 5), gc.Path("/foo", 9)]) + mod = gc.negation(gc.mod_export_version(3)) + self.assertEquals( + mod(paths), [gc.Path("/foo", 4), gc.Path("/foo", 5)]) + + def testPathsWithParse(self): + base_dir = os.path.join(tf.test.get_temp_dir(), "paths_parse") + self.assertFalse(gfile.Exists(base_dir)) + for p in xrange(3): + gfile.MakeDirs(os.path.join(base_dir, "%d" % p)) + # add a base_directory to ignore + gfile.MakeDirs(os.path.join(base_dir, "ignore")) + + # create a simple parser that pulls the export_version from the directory. 
+ def parser(path): + match = re.match("^" + base_dir + "/(\\d+)$", path.path) + if not match: + return None + return path._replace(export_version=int(match.group(1))) + + self.assertEquals( + gc.get_paths(base_dir, parser=parser), + [gc.Path(os.path.join(base_dir, "0"), 0), + gc.Path(os.path.join(base_dir, "1"), 1), + gc.Path(os.path.join(base_dir, "2"), 2)]) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow/contrib/session_bundle/manifest.proto b/tensorflow/contrib/session_bundle/manifest.proto new file mode 100644 index 0000000000..499c1bcfd8 --- /dev/null +++ b/tensorflow/contrib/session_bundle/manifest.proto @@ -0,0 +1,70 @@ +syntax = "proto3"; + +package tensorflow.contrib; + +// Signatures of model export. +message Signatures { + // Default signature of the graph. + // WARNING(break-tutorial-inline-code): The following code snippet is + // in-lined in tutorials, please update tutorial documents accordingly + // whenever code changes. + Signature default_signature = 1; + + // Named signatures of the graph. + map named_signatures = 2; +}; + +// A binding to a tensor including the name and, possibly in the future, type +// or other metadata. For example, this may specify whether a tensor supports +// batch vs single inference. +message TensorBinding { + // The name of the tensor to bind to. + string tensor_name = 1; +}; + +// An asset file or set of sharded files with the same name that will be bound +// to a tensor at init / session_bundle load time. +message AssetFile { + // The tensor to bind the asset filename to. + TensorBinding tensor_binding = 1; + // The filename within the assets directory. Note: does not include the base + // path or asset directory prefix. Base paths can and will change when models + // are deployed for serving. + string filename = 2; +} + +// A Signature specifies the inputs and outputs of commonly used graphs. 
+message Signature { + oneof type { + RegressionSignature regression_signature = 1; + ClassificationSignature classification_signature = 2; + GenericSignature generic_signature = 3; + } +}; + +// RegressionSignature specifies a graph that takes an input and returns an +// output. +message RegressionSignature { + TensorBinding input = 1; + TensorBinding output = 2; +}; + +// ClassificationSignature specifies a graph that takes an input and returns +// classes and their scores. +// WARNING(break-tutorial-inline-code): The following code snippet is +// in-lined in tutorials, please update tutorial documents accordingly +// whenever code changes. +message ClassificationSignature { + TensorBinding input = 1; + TensorBinding classes = 2; + TensorBinding scores = 3; +}; + +// GenericSignature specifies a map from logical name to Tensor name. +// Typical application of GenericSignature is to use a single GenericSignature +// that includes all of the Tensor nodes and target names that may be useful at +// serving, analysis or debugging time. The recommended name for this signature +// in the ModelManifest is "generic_bindings". +message GenericSignature { + map map = 1; +}; diff --git a/tensorflow/contrib/session_bundle/session_bundle.cc b/tensorflow/contrib/session_bundle/session_bundle.cc new file mode 100644 index 0000000000..8715492af4 --- /dev/null +++ b/tensorflow/contrib/session_bundle/session_bundle.cc @@ -0,0 +1,181 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/session_bundle/session_bundle.h" + +#include +#include +#include + +#include "google/protobuf/any.pb.h" +#include "tensorflow/contrib/session_bundle/manifest.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_types.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/env.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" +#include "tensorflow/core/protobuf/saver.pb.h" +#include "tensorflow/core/public/session_options.h" + +namespace tensorflow { +namespace contrib { +namespace { + +// Create a session using the given options and load the graph. +Status CreateSessionFromGraphDef( + const tensorflow::SessionOptions& options, const GraphDef& graph, + std::unique_ptr* session) { + session->reset(NewSession(options)); + return (*session)->Create(graph); +} + +Status GetMetaGraphDefFromExport(const StringPiece export_dir, + tensorflow::MetaGraphDef* meta_graph_def) { + const string meta_graph_def_path = + tensorflow::io::JoinPath(export_dir, kMetaGraphDefFilename); + return ReadBinaryProto(Env::Default(), meta_graph_def_path, meta_graph_def); +} + +// Creates a string tensor. +Tensor CreateStringTensor(const string& value) { + Tensor tensor(DT_STRING, TensorShape({})); + tensor.scalar()() = value; + return tensor; +} + +// Adds Assets related tensors (assets_dir and asset files) to the inputs. 
+void AddAssetsTensorsToInputs(const StringPiece export_dir, + const std::vector& asset_files, + std::vector>* inputs) { + if (!asset_files.empty()) { + for (auto& asset : asset_files) { + Tensor assets_file_tensor = CreateStringTensor(tensorflow::io::JoinPath( + tensorflow::io::JoinPath(export_dir, kAssetsDirectory), + asset.filename())); + inputs->push_back( + {asset.tensor_binding().tensor_name(), assets_file_tensor}); + } + } +} + +// Historically, model exporter(exporter.py) takes only saver with +// sharded=True, and therefore always exports checkpoint in pattern file names. +// In practice, instead of training from scratch and export directly, we +// usually want to restore from existing checkpoints and then export directly. +// To support such case, model exporter now supports reusing saver object +// restored from existing checkpoint, that may have sharded=False - it will +// then export checkpoint file in plain file name. +// This method is to support models exported by both types of saver object. +// The change is backward-compatible, therefore no changes are needed for +// existing model exports. 
+string GetVariablesFilename(const StringPiece export_dir) { + const char kVariablesFilename[] = "export"; + const char kVariablesFilenamePattern[] = "export-\?\?\?\?\?-of-\?\?\?\?\?"; + if (Env::Default()->FileExists( + tensorflow::io::JoinPath(export_dir, kVariablesFilename))) { + return tensorflow::io::JoinPath(export_dir, kVariablesFilename); + } else { + return tensorflow::io::JoinPath(export_dir, kVariablesFilenamePattern); + } +} + +Status RunRestoreOp(const StringPiece export_dir, + const std::vector& asset_files, + const StringPiece restore_op_name, + const StringPiece variables_filename_const_op_name, + tensorflow::Session* session) { + LOG(INFO) << "Running restore op for SessionBundle"; + Tensor variables_tensor = CreateStringTensor( + GetVariablesFilename(export_dir)); + std::vector> inputs = { + {variables_filename_const_op_name.ToString(), variables_tensor}}; + AddAssetsTensorsToInputs(export_dir, asset_files, &inputs); + return session->Run(inputs, {}, {restore_op_name.ToString()}, nullptr); +} + +Status RunInitOp(const StringPiece export_dir, + const std::vector& asset_files, + const StringPiece init_op_name, tensorflow::Session* session) { + LOG(INFO) << "Running init op for SessionBundle"; + std::vector> inputs; + AddAssetsTensorsToInputs(export_dir, asset_files, &inputs); + return session->Run(inputs, {}, {init_op_name.ToString()}, nullptr); +} + +} // namespace + +tensorflow::Status LoadSessionBundleFromPath( + const tensorflow::SessionOptions& options, const StringPiece export_dir, + SessionBundle* bundle) { + LOG(INFO) << "Attempting to load a SessionBundle from: " << export_dir; + TF_RETURN_IF_ERROR( + GetMetaGraphDefFromExport(export_dir, &(bundle->meta_graph_def))); + + auto collection_def = bundle->meta_graph_def.collection_def(); + if (collection_def.find(kGraphKey) != collection_def.end()) { + // Use serving graph_def in MetaGraphDef collection_def. 
+ if (collection_def[kGraphKey].any_list().value_size() != 1) { + return errors::FailedPrecondition( + strings::StrCat("Expected exactly one serving GraphDef in : ", + bundle->meta_graph_def.DebugString())); + } + tensorflow::GraphDef graph_def; + collection_def[kGraphKey].any_list().value(0).UnpackTo(&graph_def); + TF_RETURN_IF_ERROR( + CreateSessionFromGraphDef(options, graph_def, &bundle->session)); + } else { + // Fallback to use the graph_def in the MetaGraphDef. + const tensorflow::GraphDef& graph_def = bundle->meta_graph_def.graph_def(); + TF_RETURN_IF_ERROR( + CreateSessionFromGraphDef(options, graph_def, &bundle->session)); + } + + std::vector asset_files; + auto any_assets = collection_def[kAssetsKey].any_list().value(); + for (const auto any_asset : any_assets) { + AssetFile asset_file; + any_asset.UnpackTo(&asset_file); + asset_files.push_back(asset_file); + } + + TF_RETURN_IF_ERROR( + RunRestoreOp(export_dir, asset_files, + bundle->meta_graph_def.saver_def().restore_op_name(), + bundle->meta_graph_def.saver_def().filename_tensor_name(), + bundle->session.get())); + + if (collection_def.find(kInitOpKey) != collection_def.end()) { + if (collection_def[kInitOpKey].node_list().value_size() != 1) { + return errors::FailedPrecondition( + strings::StrCat("Expected exactly one serving init op in : ", + bundle->meta_graph_def.DebugString())); + } + return RunInitOp(export_dir, asset_files, + collection_def[kInitOpKey].node_list().value(0), + bundle->session.get()); + } + + LOG(INFO) << "Done loading SessionBundle"; + return Status::OK(); +} + +} // namespace contrib +} // namespace tensorflow diff --git a/tensorflow/contrib/session_bundle/session_bundle.h b/tensorflow/contrib/session_bundle/session_bundle.h new file mode 100644 index 0000000000..92ed0132b8 --- /dev/null +++ b/tensorflow/contrib/session_bundle/session_bundle.h @@ -0,0 +1,59 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Low-level functionality for setting up an inference Session. + +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_SESSION_BUNDLE_SESSION_BUNDLE_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_SESSION_BUNDLE_SESSION_BUNDLE_H_ + +#include + +#include "tensorflow/contrib/session_bundle/manifest.pb.h" +#include "tensorflow/contrib/session_bundle/signature.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" +#include "tensorflow/core/protobuf/saver.pb.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/core/public/session_options.h" + +namespace tensorflow { +namespace contrib { + +const char kMetaGraphDefFilename[] = "export.meta"; +const char kAssetsDirectory[] = "assets"; +const char kInitOpKey[] = "serving_init_op"; +const char kAssetsKey[] = "serving_assets"; +const char kGraphKey[] = "serving_graph"; + +// Data and objects loaded from a python Exporter export. +// WARNING(break-tutorial-inline-code): The following code snippet is +// in-lined in tutorials, please update tutorial documents accordingly +// whenever code changes. +struct SessionBundle { + std::unique_ptr session; + tensorflow::MetaGraphDef meta_graph_def; +}; + +// Loads a manifest and initialized session using the output of an Exporter +// using the format defined at go/tf-exporter.
+tensorflow::Status LoadSessionBundleFromPath( + const tensorflow::SessionOptions& options, + const tensorflow::StringPiece export_dir, SessionBundle* bundle); + +} // namespace contrib +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_SESSION_BUNDLE_SESSION_BUNDLE_H_ diff --git a/tensorflow/contrib/session_bundle/session_bundle_test.cc b/tensorflow/contrib/session_bundle/session_bundle_test.cc new file mode 100644 index 0000000000..2bf09ad438 --- /dev/null +++ b/tensorflow/contrib/session_bundle/session_bundle_test.cc @@ -0,0 +1,102 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/contrib/session_bundle/session_bundle.h" + +#include +#include +#include + +#include "google/protobuf/any.pb.h" +#include "tensorflow/contrib/session_bundle/test_util.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/io/path.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/public/session.h" +#include "tensorflow/core/public/session_options.h" + +namespace tensorflow { +namespace contrib { +namespace { + +TEST(LoadSessionBundleFromPath, Basic) { + const string export_path = test_util::TestSrcDirPath( + "session_bundle/example/half_plus_two/00000123"); + tensorflow::SessionOptions options; + SessionBundle bundle; + TF_ASSERT_OK(LoadSessionBundleFromPath(options, export_path, &bundle)); + + const string asset_path = + tensorflow::io::JoinPath(export_path, kAssetsDirectory); + // Validate the assets behavior. + std::vector path_outputs; + TF_ASSERT_OK(bundle.session->Run({}, {"filename1:0", "filename2:0"}, {}, + &path_outputs)); + ASSERT_EQ(2, path_outputs.size()); + // Validate the two asset file tensors are set by the init_op and include the + // base_path and asset directory. + test::ExpectTensorEqual( + test::AsTensor( + {tensorflow::io::JoinPath(asset_path, "hello1.txt")}, + TensorShape({})), + path_outputs[0]); + test::ExpectTensorEqual( + test::AsTensor( + {tensorflow::io::JoinPath(asset_path, "hello2.txt")}, + TensorShape({})), + path_outputs[1]); + + // Validate the half plus two behavior. + Tensor input = test::AsTensor({0, 1, 2, 3}, TensorShape({4, 1})); + + // Recover the Tensor names of our inputs and outputs. 
+ auto collection_def = bundle.meta_graph_def.collection_def(); + Signatures signatures; + ASSERT_EQ(1, collection_def[kSignaturesKey].any_list().value_size()); + collection_def[kSignaturesKey].any_list().value(0).UnpackTo(&signatures); + ASSERT_TRUE(signatures.default_signature().has_regression_signature()); + const tensorflow::contrib::RegressionSignature regression_signature = + signatures.default_signature().regression_signature(); + + const string input_name = regression_signature.input().tensor_name(); + const string output_name = regression_signature.output().tensor_name(); + + std::vector outputs; + TF_ASSERT_OK( + bundle.session->Run({{input_name, input}}, {output_name}, {}, &outputs)); + ASSERT_EQ(outputs.size(), 1); + test::ExpectTensorEqual( + outputs[0], test::AsTensor({2, 2.5, 3, 3.5}, TensorShape({4, 1}))); +} + +TEST(LoadSessionBundleFromPath, BadExportPath) { + const string export_path = test_util::TestSrcDirPath("/tmp/bigfoot"); + tensorflow::SessionOptions options; + options.target = "local"; + SessionBundle bundle; + const auto status = LoadSessionBundleFromPath(options, export_path, &bundle); + ASSERT_FALSE(status.ok()); + const string msg = status.ToString(); + EXPECT_TRUE(msg.find("Not found") != std::string::npos) << msg; +} + +} // namespace +} // namespace contrib +} // namespace tensorflow diff --git a/tensorflow/contrib/session_bundle/signature.cc b/tensorflow/contrib/session_bundle/signature.cc new file mode 100644 index 0000000000..3550a7d10d --- /dev/null +++ b/tensorflow/contrib/session_bundle/signature.cc @@ -0,0 +1,270 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/contrib/session_bundle/signature.h" + +#include +#include +#include + +#include "google/protobuf/any.pb.h" +#include "tensorflow/contrib/session_bundle/manifest.pb.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/strings/strcat.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" +#include "tensorflow/core/protobuf/saver.pb.h" +#include "tensorflow/core/public/session.h" + +namespace tensorflow { +namespace contrib { +namespace { + +// Returns OK if the input and output batch sizes match. +Status BatchSizesMatch(const Tensor& input, const Tensor& output) { + // Ensure the number of outputs match the number of inputs. + if (input.dim_size(0) != output.dim_size(0)) { + return errors::Internal( + strings::StrCat("Input batch size did not match output batch size: ", + input.dim_size(0), " vs. 
", output.dim_size(0))); + } + return Status::OK(); +} +} // namespace + +Status GetSignatures(const tensorflow::MetaGraphDef& meta_graph_def, + Signatures* signatures) { + auto collection_def = meta_graph_def.collection_def(); + auto any_list = collection_def[kSignaturesKey].any_list(); + if (any_list.value_size() != 1) { + return errors::FailedPrecondition( + strings::StrCat("Expected exactly one signatures proto in : ", + meta_graph_def.DebugString())); + } + any_list.value(0).UnpackTo(signatures); + return Status::OK(); +} + +Status SetSignatures(const Signatures& signatures, + tensorflow::MetaGraphDef* meta_graph_def) { + auto& collection_def = *(meta_graph_def->mutable_collection_def()); + auto* any_list = collection_def[kSignaturesKey].mutable_any_list(); + any_list->mutable_value()->Clear(); + any_list->mutable_value()->Add()->PackFrom(signatures); + return Status::OK(); +} + +Status GetClassificationSignature( + const tensorflow::MetaGraphDef& meta_graph_def, + ClassificationSignature* signature) { + Signatures signatures; + TF_RETURN_IF_ERROR(GetSignatures(meta_graph_def, &signatures)); + if (!signatures.has_default_signature()) { + return errors::FailedPrecondition(strings::StrCat( + "Expected a default signature in: ", signatures.DebugString())); + } + if (!signatures.default_signature().has_classification_signature()) { + return errors::FailedPrecondition( + strings::StrCat("Expected a classification signature in: ", + signatures.default_signature().DebugString())); + } + *signature = signatures.default_signature().classification_signature(); + return Status::OK(); +} + +Status GetNamedClassificationSignature( + const string& name, const tensorflow::MetaGraphDef& meta_graph_def, + ClassificationSignature* signature) { + Signatures signatures; + TF_RETURN_IF_ERROR(GetSignatures(meta_graph_def, &signatures)); + const auto& it = signatures.named_signatures().find(name); + if (it == signatures.named_signatures().end()) { + return 
errors::NotFound(strings::StrCat("Missing signature named \"", name, + "\" in: ", + signatures.DebugString())); + } + if (!it->second.has_classification_signature()) { + return errors::FailedPrecondition( + strings::StrCat("Expected a classification signature for name \"", name, + "\" in: ", it->second.DebugString())); + } + *signature = it->second.classification_signature(); + return Status::OK(); +} + +Status RunClassification(const ClassificationSignature& signature, + const Tensor& input, Session* session, Tensor* classes, + Tensor* scores) { + std::vector output_tensor_names; + if (classes) { + output_tensor_names.push_back(signature.classes().tensor_name()); + } + if (scores) { + output_tensor_names.push_back(signature.scores().tensor_name()); + } + // Run the graph with our inputs and outputs. + std::vector outputs; + const Status run_status = + session->Run({{signature.input().tensor_name(), input}}, + output_tensor_names, {}, &outputs); + if (!run_status.ok()) { + return run_status; + } + // Ensure the output is shaped how we expect. + // There should be one string Tensor of shape, + // [batch_size, num_recommendations]. + if (outputs.size() != output_tensor_names.size()) { + return errors::Internal( + strings::StrCat("Expected ", output_tensor_names.size(), + " output tensor(s). Got: ", outputs.size())); + } + if (classes) { + *classes = outputs[0]; + TF_RETURN_IF_ERROR(BatchSizesMatch(input, *classes)); + } + if (scores) { + *scores = outputs[classes ? 
1 : 0]; + TF_RETURN_IF_ERROR(BatchSizesMatch(input, *scores)); + } + return Status::OK(); +} + +Status GetRegressionSignature(const tensorflow::MetaGraphDef& meta_graph_def, + RegressionSignature* signature) { + Signatures signatures; + TF_RETURN_IF_ERROR(GetSignatures(meta_graph_def, &signatures)); + if (!signatures.has_default_signature()) { + return errors::FailedPrecondition(strings::StrCat( + "Expected a default signature in: ", signatures.DebugString())); + } + if (!signatures.default_signature().has_regression_signature()) { + return errors::FailedPrecondition( + strings::StrCat("Expected a regression signature in: ", + signatures.default_signature().DebugString())); + } + *signature = signatures.default_signature().regression_signature(); + return Status::OK(); +} + +Status RunRegression(const RegressionSignature& signature, + const Tensor& regression_input, Session* session, + Tensor* regression_output) { + std::vector output_tensor_names; + if (regression_output) { + output_tensor_names.push_back(signature.output().tensor_name()); + } + // Run the graph with our inputs and outputs. + std::vector outputs; + const Status run_status = + session->Run({{signature.input().tensor_name(), regression_input}}, + output_tensor_names, {}, &outputs); + if (!run_status.ok()) { + return run_status; + } + // Ensure the regression score output is shaped how we expect. + // There should be one float Tensor of shape, + // [batch_size, num_recommendations]. + if (outputs.size() != output_tensor_names.size()) { + return errors::Internal( + strings::StrCat("Expected ", output_tensor_names.size(), + " output tensor(s). 
Got: ", outputs.size())); + } + if (regression_output) { + *regression_output = outputs[0]; + TF_RETURN_IF_ERROR(BatchSizesMatch(regression_input, *regression_output)); + } + return Status::OK(); +} + +Status GetGenericSignature(const string& name, + const tensorflow::MetaGraphDef& meta_graph_def, + GenericSignature* signature) { + Signatures signatures; + TF_RETURN_IF_ERROR(GetSignatures(meta_graph_def, &signatures)); + const auto& it = signatures.named_signatures().find(name); + if (it == signatures.named_signatures().end()) { + return errors::InvalidArgument( + strings::StrCat("Missing generic signature named \"", name, "\" in ", + signatures.DebugString())); + } + if (!it->second.has_generic_signature()) { + return errors::InvalidArgument(strings::StrCat( + "Expected a generic signature: ", it->second.DebugString())); + } + *signature = it->second.generic_signature(); + return Status::OK(); +} + +Status GetDefaultSignature(const tensorflow::MetaGraphDef& meta_graph_def, + Signature* default_signature) { + Signatures signatures; + TF_RETURN_IF_ERROR(GetSignatures(meta_graph_def, &signatures)); + *default_signature = signatures.default_signature(); + return Status::OK(); +} + +Status GetNamedSignature(const string& name, + const tensorflow::MetaGraphDef& meta_graph_def, + Signature* signature) { + Signatures signatures; + TF_RETURN_IF_ERROR(GetSignatures(meta_graph_def, &signatures)); + const auto& it = signatures.named_signatures().find(name); + if (it == signatures.named_signatures().end()) { + return errors::NotFound(strings::StrCat("Missing signature named \"", name, + "\" in: ", + signatures.DebugString())); + } + *signature = it->second; + return Status::OK(); +} + +Status BindGenericInputs(const GenericSignature& signature, + const std::vector>& inputs, + std::vector>* bound_inputs) { + const protobuf::Map& bindings = + signature.map(); + + for (const auto& entry : inputs) { + const auto mapped = bindings.find(entry.first); + if (mapped == bindings.end()) 
{ + return errors::NotFound( + strings::StrCat("Could not find generic binding for: ", entry.first)); + } + bound_inputs->push_back({mapped->second.tensor_name(), entry.second}); + } + return Status::OK(); +} + +Status BindGenericNames(const GenericSignature& signature, + const std::vector& input_names, + std::vector* bound_names) { + const protobuf::Map& bindings = + signature.map(); + + for (const string& entry : input_names) { + const auto mapped = bindings.find(entry); + if (mapped == bindings.end()) { + return errors::NotFound( + strings::StrCat("Could not find generic binding for: ", entry)); + } + bound_names->push_back(mapped->second.tensor_name()); + } + return Status::OK(); +} + +} // namespace contrib +} // namespace tensorflow diff --git a/tensorflow/contrib/session_bundle/signature.h b/tensorflow/contrib/session_bundle/signature.h new file mode 100644 index 0000000000..98d6e56601 --- /dev/null +++ b/tensorflow/contrib/session_bundle/signature.h @@ -0,0 +1,123 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Helpers for working with TensorFlow exports and their signatures. 
+ +#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_SESSION_BUNDLE_SIGNATURE_H_ +#define THIRD_PARTY_TENSORFLOW_CONTRIB_SESSION_BUNDLE_SIGNATURE_H_ + +#include +#include +#include + +#include "tensorflow/contrib/session_bundle/manifest.pb.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" +#include "tensorflow/core/protobuf/saver.pb.h" +#include "tensorflow/core/public/session.h" + +namespace tensorflow { +namespace contrib { + +const char kSignaturesKey[] = "serving_signatures"; + +// Get Signatures from a MetaGraphDef. +Status GetSignatures(const tensorflow::MetaGraphDef& meta_graph_def, + Signatures* signatures); + +// (Re)set Signatures in a MetaGraphDef. +Status SetSignatures(const Signatures& signatures, + tensorflow::MetaGraphDef* meta_graph_def); + +// Gets a ClassificationSignature from a MetaGraphDef's default signature. +// Returns an error if the default signature is not a ClassificationSignature, +// or does not exist. +Status GetClassificationSignature( + const tensorflow::MetaGraphDef& meta_graph_def, + ClassificationSignature* signature); + +// Gets a named ClassificationSignature from a MetaGraphDef. +// Returns an error if a ClassificationSignature with the given name does +// not exist. +Status GetNamedClassificationSignature( + const string& name, const tensorflow::MetaGraphDef& meta_graph_def, + ClassificationSignature* signature); + +// Gets a RegressionSignature from a MetaGraphDef's default signature. +// Returns an error if the default signature is not a RegressionSignature, +// or does not exist. +Status GetRegressionSignature(const tensorflow::MetaGraphDef& meta_graph_def, + RegressionSignature* signature); + +// Runs a classification using the provided signature and initialized Session. 
+// input: input batch of items to classify +// classes: output batch of classes; may be null if not needed +// scores: output batch of scores; may be null if not needed +// Validates sizes of the inputs and outputs are consistent (e.g., input +// batch size equals output batch sizes). +// Does not do any type validation. +Status RunClassification(const ClassificationSignature& signature, + const Tensor& input, Session* session, Tensor* classes, + Tensor* scores); + +// Runs regression using the provided signature and initialized Session. +// input: input batch of items to run the regression model against +// output: output targets +// Validates sizes of the inputs and outputs are consistent (e.g., input +// batch size equals output batch sizes). +// Does not do any type validation. +Status RunRegression(const RegressionSignature& signature, const Tensor& input, + Session* session, Tensor* output); + +// Gets the named GenericSignature from a MetaGraphDef. +// Returns an error if a GenericSignature with the given name does not exist. +Status GetGenericSignature(const string& name, + const tensorflow::MetaGraphDef& meta_graph_def, + GenericSignature* signature); + +// Gets the default signature from a MetaGraphDef. +Status GetDefaultSignature(const tensorflow::MetaGraphDef& meta_graph_def, + Signature* default_signature); + +// Gets a named Signature from a MetaGraphDef. +// Returns an error if a Signature with the given name does not exist. +Status GetNamedSignature(const string& name, + const tensorflow::MetaGraphDef& meta_graph_def, + Signature* default_signature); + +// Binds TensorFlow inputs specified by the caller using the logical names +// specified at Graph export time, to the actual Graph names. +// Returns an error if any of the inputs do not have a binding in the export's +// MetaGraphDef. 
+Status BindGenericInputs(const GenericSignature& signature, + const std::vector>& inputs, + std::vector>* bound_inputs); + +// Binds the input names specified by the caller using the logical names +// specified at Graph export time, to the actual Graph names. This is useful +// for binding names of both the TensorFlow output tensors and target nodes, +// with the latter (target nodes) being optional and rarely used (if ever) at +// serving time. +// Returns an error if any of the input names do not have a binding in the +// export's MetaGraphDef. +Status BindGenericNames(const GenericSignature& signature, + const std::vector& input_names, + std::vector* bound_names); +} // namespace contrib +} // namespace tensorflow + +#endif // THIRD_PARTY_TENSORFLOW_CONTRIB_SESSION_BUNDLE_SIGNATURE_H_ diff --git a/tensorflow/contrib/session_bundle/signature_test.cc b/tensorflow/contrib/session_bundle/signature_test.cc new file mode 100644 index 0000000000..abeaf23c0e --- /dev/null +++ b/tensorflow/contrib/session_bundle/signature_test.cc @@ -0,0 +1,602 @@ +/* Copyright 2016 Google Inc. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/contrib/session_bundle/signature.h" + +#include + +#include "google/protobuf/any.pb.h" +#include "tensorflow/contrib/session_bundle/manifest.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/framework/tensor_testutil.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/lib/core/status.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/lib/core/stringpiece.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/public/session.h" + +namespace tensorflow { +namespace contrib { +namespace { + +static bool HasSubstr(const string& base, const string& substr) { + bool ok = StringPiece(base).contains(substr); + EXPECT_TRUE(ok) << base << ", expected substring " << substr; + return ok; +} + +TEST(GetClassificationSignature, Basic) { + tensorflow::MetaGraphDef meta_graph_def; + Signatures signatures; + ClassificationSignature* input_signature = + signatures.mutable_default_signature() + ->mutable_classification_signature(); + input_signature->mutable_input()->set_tensor_name("flow"); + (*meta_graph_def.mutable_collection_def())[kSignaturesKey] + .mutable_any_list() + ->add_value() + ->PackFrom(signatures); + + ClassificationSignature signature; + const Status status = GetClassificationSignature(meta_graph_def, &signature); + TF_ASSERT_OK(status); + EXPECT_EQ(signature.input().tensor_name(), "flow"); +} + +TEST(GetClassificationSignature, MissingSignature) { + tensorflow::MetaGraphDef meta_graph_def; + Signatures signatures; + signatures.mutable_default_signature(); + (*meta_graph_def.mutable_collection_def())[kSignaturesKey] + .mutable_any_list() + ->add_value() + ->PackFrom(signatures); + + ClassificationSignature signature; + const Status status = GetClassificationSignature(meta_graph_def, &signature); + 
ASSERT_FALSE(status.ok()); + EXPECT_TRUE(StringPiece(status.error_message()) + .contains("Expected a classification signature")) + << status.error_message(); +} + +TEST(GetClassificationSignature, WrongSignatureType) { + tensorflow::MetaGraphDef meta_graph_def; + Signatures signatures; + signatures.mutable_default_signature()->mutable_regression_signature(); + (*meta_graph_def.mutable_collection_def())[kSignaturesKey] + .mutable_any_list() + ->add_value() + ->PackFrom(signatures); + + ClassificationSignature signature; + const Status status = GetClassificationSignature(meta_graph_def, &signature); + ASSERT_FALSE(status.ok()); + EXPECT_TRUE(StringPiece(status.error_message()) + .contains("Expected a classification signature")) + << status.error_message(); +} + +TEST(GetNamedClassificationSignature, Basic) { + tensorflow::MetaGraphDef meta_graph_def; + Signatures signatures; + ClassificationSignature* input_signature = + (*signatures.mutable_named_signatures())["foo"] + .mutable_classification_signature(); + input_signature->mutable_input()->set_tensor_name("flow"); + (*meta_graph_def.mutable_collection_def())[kSignaturesKey] + .mutable_any_list() + ->add_value() + ->PackFrom(signatures); + + ClassificationSignature signature; + const Status status = + GetNamedClassificationSignature("foo", meta_graph_def, &signature); + TF_ASSERT_OK(status); + EXPECT_EQ(signature.input().tensor_name(), "flow"); +} + +TEST(GetNamedClassificationSignature, MissingSignature) { + tensorflow::MetaGraphDef meta_graph_def; + Signatures signatures; + (*meta_graph_def.mutable_collection_def())[kSignaturesKey] + .mutable_any_list() + ->add_value() + ->PackFrom(signatures); + + ClassificationSignature signature; + const Status status = + GetNamedClassificationSignature("foo", meta_graph_def, &signature); + ASSERT_FALSE(status.ok()); + EXPECT_TRUE(StringPiece(status.error_message()) + .contains("Missing signature named \"foo\"")) + << status.error_message(); +} + 
+TEST(GetNamedClassificationSignature, WrongSignatureType) { + tensorflow::MetaGraphDef meta_graph_def; + Signatures signatures; + (*signatures.mutable_named_signatures())["foo"] + .mutable_regression_signature(); + (*meta_graph_def.mutable_collection_def())[kSignaturesKey] + .mutable_any_list() + ->add_value() + ->PackFrom(signatures); + + ClassificationSignature signature; + const Status status = + GetNamedClassificationSignature("foo", meta_graph_def, &signature); + ASSERT_FALSE(status.ok()); + EXPECT_TRUE( + StringPiece(status.error_message()) + .contains("Expected a classification signature for name \"foo\"")) + << status.error_message(); +} + +TEST(GetRegressionSignature, Basic) { + tensorflow::MetaGraphDef meta_graph_def; + Signatures signatures; + RegressionSignature* input_signature = + signatures.mutable_default_signature()->mutable_regression_signature(); + input_signature->mutable_input()->set_tensor_name("flow"); + (*meta_graph_def.mutable_collection_def())[kSignaturesKey] + .mutable_any_list() + ->add_value() + ->PackFrom(signatures); + + RegressionSignature signature; + const Status status = GetRegressionSignature(meta_graph_def, &signature); + TF_ASSERT_OK(status); + EXPECT_EQ(signature.input().tensor_name(), "flow"); +} + +TEST(GetRegressionSignature, MissingSignature) { + tensorflow::MetaGraphDef meta_graph_def; + Signatures signatures; + signatures.mutable_default_signature(); + (*meta_graph_def.mutable_collection_def())[kSignaturesKey] + .mutable_any_list() + ->add_value() + ->PackFrom(signatures); + + RegressionSignature signature; + const Status status = GetRegressionSignature(meta_graph_def, &signature); + ASSERT_FALSE(status.ok()); + EXPECT_TRUE(StringPiece(status.error_message()) + .contains("Expected a regression signature")) + << status.error_message(); +} + +TEST(GetRegressionSignature, WrongSignatureType) { + tensorflow::MetaGraphDef meta_graph_def; + Signatures signatures; + 
signatures.mutable_default_signature()->mutable_classification_signature(); + (*meta_graph_def.mutable_collection_def())[kSignaturesKey] + .mutable_any_list() + ->add_value() + ->PackFrom(signatures); + + RegressionSignature signature; + const Status status = GetRegressionSignature(meta_graph_def, &signature); + ASSERT_FALSE(status.ok()); + EXPECT_TRUE(StringPiece(status.error_message()) + .contains("Expected a regression signature")) + << status.error_message(); +} + +TEST(GetNamedSignature, Basic) { + tensorflow::MetaGraphDef meta_graph_def; + Signatures signatures; + ClassificationSignature* input_signature = + (*signatures.mutable_named_signatures())["foo"] + .mutable_classification_signature(); + input_signature->mutable_input()->set_tensor_name("flow"); + (*meta_graph_def.mutable_collection_def())[kSignaturesKey] + .mutable_any_list() + ->add_value() + ->PackFrom(signatures); + + Signature signature; + const Status status = GetNamedSignature("foo", meta_graph_def, &signature); + TF_ASSERT_OK(status); + EXPECT_EQ(signature.classification_signature().input().tensor_name(), "flow"); +} + +TEST(GetNamedSignature, MissingSignature) { + tensorflow::MetaGraphDef meta_graph_def; + Signatures signatures; + (*meta_graph_def.mutable_collection_def())[kSignaturesKey] + .mutable_any_list() + ->add_value() + ->PackFrom(signatures); + + Signature signature; + const Status status = GetNamedSignature("foo", meta_graph_def, &signature); + ASSERT_FALSE(status.ok()); + EXPECT_TRUE(StringPiece(status.error_message()) + .contains("Missing signature named \"foo\"")) + << status.error_message(); +} + +// MockSession used to test input and output interactions with a +// tensorflow::Session. 
+struct MockSession : public tensorflow::Session { + ~MockSession() override = default; + + Status Create(const GraphDef& graph) override { + return errors::Unimplemented("Not implemented for mock."); + } + + Status Extend(const GraphDef& graph) override { + return errors::Unimplemented("Not implemented for mock."); + } + + // Sets the input and output arguments. + Status Run(const std::vector>& inputs_arg, + const std::vector& output_tensor_names_arg, + const std::vector& target_node_names_arg, + std::vector* outputs_arg) override { + inputs = inputs_arg; + output_tensor_names = output_tensor_names_arg; + target_node_names = target_node_names_arg; + *outputs_arg = outputs; + return status; + } + + Status Close() override { + return errors::Unimplemented("Not implemented for mock."); + } + + // Arguments stored on a Run call. + std::vector> inputs; + std::vector output_tensor_names; + std::vector target_node_names; + + // Output argument set by Run; should be set before calling. + std::vector outputs; + + // Return value for Run; should be set before calling. + Status status; +}; + +constexpr char kInputName[] = "in:0"; +constexpr char kClassesName[] = "classes:0"; +constexpr char kScoresName[] = "scores:0"; + +class RunClassificationTest : public ::testing::Test { + public: + void SetUp() override { + signature_.mutable_input()->set_tensor_name(kInputName); + signature_.mutable_classes()->set_tensor_name(kClassesName); + signature_.mutable_scores()->set_tensor_name(kScoresName); + } + + protected: + ClassificationSignature signature_; + Tensor input_tensor_; + Tensor classes_tensor_; + Tensor scores_tensor_; + MockSession session_; +}; + +TEST_F(RunClassificationTest, Basic) { + input_tensor_ = test::AsTensor({99}); + session_.outputs = {test::AsTensor({3}), test::AsTensor({2})}; + const Status status = RunClassification(signature_, input_tensor_, &session_, + &classes_tensor_, &scores_tensor_); + + // Validate outputs. 
+ TF_ASSERT_OK(status); + test::ExpectTensorEqual(test::AsTensor({3}), classes_tensor_); + test::ExpectTensorEqual(test::AsTensor({2}), scores_tensor_); + + // Validate inputs. + ASSERT_EQ(1, session_.inputs.size()); + EXPECT_EQ(kInputName, session_.inputs[0].first); + test::ExpectTensorEqual(test::AsTensor({99}), + session_.inputs[0].second); + + ASSERT_EQ(2, session_.output_tensor_names.size()); + EXPECT_EQ(kClassesName, session_.output_tensor_names[0]); + EXPECT_EQ(kScoresName, session_.output_tensor_names[1]); +} + +TEST_F(RunClassificationTest, ClassesOnly) { + input_tensor_ = test::AsTensor({99}); + session_.outputs = {test::AsTensor({3})}; + const Status status = RunClassification(signature_, input_tensor_, &session_, + &classes_tensor_, nullptr); + + // Validate outputs. + TF_ASSERT_OK(status); + test::ExpectTensorEqual(test::AsTensor({3}), classes_tensor_); + + // Validate inputs. + ASSERT_EQ(1, session_.inputs.size()); + EXPECT_EQ(kInputName, session_.inputs[0].first); + test::ExpectTensorEqual(test::AsTensor({99}), + session_.inputs[0].second); + + ASSERT_EQ(1, session_.output_tensor_names.size()); + EXPECT_EQ(kClassesName, session_.output_tensor_names[0]); +} + +TEST_F(RunClassificationTest, ScoresOnly) { + input_tensor_ = test::AsTensor({99}); + session_.outputs = {test::AsTensor({2})}; + const Status status = RunClassification(signature_, input_tensor_, &session_, + nullptr, &scores_tensor_); + + // Validate outputs. + TF_ASSERT_OK(status); + test::ExpectTensorEqual(test::AsTensor({2}), scores_tensor_); + + // Validate inputs. 
+ ASSERT_EQ(1, session_.inputs.size()); + EXPECT_EQ(kInputName, session_.inputs[0].first); + test::ExpectTensorEqual(test::AsTensor({99}), + session_.inputs[0].second); + + ASSERT_EQ(1, session_.output_tensor_names.size()); + EXPECT_EQ(kScoresName, session_.output_tensor_names[0]); +} + +TEST(RunClassification, RunNotOk) { + ClassificationSignature signature; + signature.mutable_input()->set_tensor_name("in:0"); + signature.mutable_classes()->set_tensor_name("classes:0"); + Tensor input_tensor = test::AsTensor({99}); + MockSession session; + session.status = errors::DataLoss("Data is gone"); + Tensor classes_tensor; + const Status status = RunClassification(signature, input_tensor, &session, + &classes_tensor, nullptr); + ASSERT_FALSE(status.ok()); + EXPECT_TRUE(StringPiece(status.error_message()).contains("Data is gone")) + << status.error_message(); +} + +TEST(RunClassification, TooManyOutputs) { + ClassificationSignature signature; + signature.mutable_input()->set_tensor_name("in:0"); + signature.mutable_classes()->set_tensor_name("classes:0"); + Tensor input_tensor = test::AsTensor({99}); + MockSession session; + session.outputs = {test::AsTensor({3}), test::AsTensor({4})}; + + Tensor classes_tensor; + const Status status = RunClassification(signature, input_tensor, &session, + &classes_tensor, nullptr); + ASSERT_FALSE(status.ok()); + EXPECT_TRUE(StringPiece(status.error_message()).contains("Expected 1 output")) + << status.error_message(); +} + +TEST(RunClassification, WrongBatchOutputs) { + ClassificationSignature signature; + signature.mutable_input()->set_tensor_name("in:0"); + signature.mutable_classes()->set_tensor_name("classes:0"); + Tensor input_tensor = test::AsTensor({99, 100}); + MockSession session; + session.outputs = {test::AsTensor({3})}; + + Tensor classes_tensor; + const Status status = RunClassification(signature, input_tensor, &session, + &classes_tensor, nullptr); + ASSERT_FALSE(status.ok()); + 
EXPECT_TRUE(StringPiece(status.error_message()) + .contains("Input batch size did not match output batch size")) + << status.error_message(); +} + +constexpr char kRegressionsName[] = "regressions:0"; + +class RunRegressionTest : public ::testing::Test { + public: + void SetUp() override { + signature_.mutable_input()->set_tensor_name(kInputName); + signature_.mutable_output()->set_tensor_name(kRegressionsName); + } + + protected: + RegressionSignature signature_; + Tensor input_tensor_; + Tensor output_tensor_; + MockSession session_; +}; + +TEST_F(RunRegressionTest, Basic) { + input_tensor_ = test::AsTensor({99, 100}); + session_.outputs = {test::AsTensor({1, 2})}; + const Status status = + RunRegression(signature_, input_tensor_, &session_, &output_tensor_); + + // Validate outputs. + TF_ASSERT_OK(status); + test::ExpectTensorEqual(test::AsTensor({1, 2}), output_tensor_); + + // Validate inputs. + ASSERT_EQ(1, session_.inputs.size()); + EXPECT_EQ(kInputName, session_.inputs[0].first); + test::ExpectTensorEqual(test::AsTensor({99, 100}), + session_.inputs[0].second); + + ASSERT_EQ(1, session_.output_tensor_names.size()); + EXPECT_EQ(kRegressionsName, session_.output_tensor_names[0]); +} + +TEST_F(RunRegressionTest, RunNotOk) { + input_tensor_ = test::AsTensor({99}); + session_.status = errors::DataLoss("Data is gone"); + const Status status = + RunRegression(signature_, input_tensor_, &session_, &output_tensor_); + ASSERT_FALSE(status.ok()); + EXPECT_TRUE(StringPiece(status.error_message()).contains("Data is gone")) + << status.error_message(); +} + +TEST_F(RunRegressionTest, MismatchedSizeForBatchInputAndOutput) { + input_tensor_ = test::AsTensor({99, 100}); + session_.outputs = {test::AsTensor({3})}; + + const Status status = + RunRegression(signature_, input_tensor_, &session_, &output_tensor_); + ASSERT_FALSE(status.ok()); + EXPECT_TRUE(StringPiece(status.error_message()) + .contains("Input batch size did not match output batch size")) + << 
status.error_message(); +} + +TEST(SetAndGetSignatures, RoundTrip) { + tensorflow::MetaGraphDef meta_graph_def; + Signatures signatures; + signatures.mutable_default_signature() + ->mutable_classification_signature() + ->mutable_input() + ->set_tensor_name("in:0"); + TF_ASSERT_OK(SetSignatures(signatures, &meta_graph_def)); + Signatures read_signatures; + TF_ASSERT_OK(GetSignatures(meta_graph_def, &read_signatures)); + + EXPECT_EQ("in:0", read_signatures.default_signature() + .classification_signature() + .input() + .tensor_name()); +} + +// GenericSignature test fixture that contains a signature initialized with two +// bound Tensors. +class GenericSignatureTest : public ::testing::Test { + protected: + GenericSignatureTest() { + TensorBinding binding; + binding.set_tensor_name("graph_A"); + signature_.mutable_map()->insert({"logical_A", binding}); + + binding.set_tensor_name("graph_B"); + signature_.mutable_map()->insert({"logical_B", binding}); + } + + // GenericSignature that contains two bound Tensors. + GenericSignature signature_; +}; + +// GenericSignature tests. 
+
+// Packs the fixture's generic signature into the MetaGraphDef signature
+// collection under the name "generic_bindings" and fetches it back by name.
+TEST_F(GenericSignatureTest, GetGenericSignatureBasic) {
+  Signatures packed;
+  (*packed.mutable_named_signatures())["generic_bindings"]
+      .mutable_generic_signature()
+      ->MergeFrom(signature_);
+
+  tensorflow::MetaGraphDef graph;
+  auto* any_list =
+      (*graph.mutable_collection_def())[kSignaturesKey].mutable_any_list();
+  any_list->add_value()->PackFrom(packed);
+
+  GenericSignature fetched;
+  TF_ASSERT_OK(GetGenericSignature("generic_bindings", graph, &fetched));
+  const auto& bindings = fetched.map();
+  ASSERT_EQ("graph_A", bindings.at("logical_A").tensor_name());
+  ASSERT_EQ("graph_B", bindings.at("logical_B").tensor_name());
+}
+
+// Fetching a name that was never added to the named signatures must fail.
+TEST(GetGenericSignature, MissingSignature) {
+  tensorflow::MetaGraphDef graph;
+  auto* any_list =
+      (*graph.mutable_collection_def())[kSignaturesKey].mutable_any_list();
+  any_list->add_value()->PackFrom(Signatures());
+
+  GenericSignature fetched;
+  const Status lookup =
+      GetGenericSignature("generic_bindings", graph, &fetched);
+  ASSERT_FALSE(lookup.ok());
+  EXPECT_TRUE(
+      HasSubstr(lookup.error_message(),
+                "Missing generic signature named \"generic_bindings\""))
+      << lookup.error_message();
+}
+
+// A named signature that holds a regression signature must not be returned as
+// a generic one.
+TEST(GetGenericSignature, WrongSignatureType) {
+  Signatures packed;
+  Signature& named_entry =
+      (*packed.mutable_named_signatures())["generic_bindings"];
+  named_entry.mutable_regression_signature();
+
+  tensorflow::MetaGraphDef graph;
+  auto* any_list =
+      (*graph.mutable_collection_def())[kSignaturesKey].mutable_any_list();
+  any_list->add_value()->PackFrom(packed);
+
+  GenericSignature fetched;
+  const Status lookup =
+      GetGenericSignature("generic_bindings", graph, &fetched);
+  ASSERT_FALSE(lookup.ok());
+  const StringPiece message(lookup.error_message());
+  EXPECT_TRUE(message.contains("Expected a generic signature:"))
+      << lookup.error_message();
+}
+
+// BindGeneric Tests.
+
+// Binds two logical inputs through the fixture's signature and checks that
+// each pair comes back under its graph tensor name with the same tensor.
+TEST_F(GenericSignatureTest, BindGenericInputsBasic) {
+  const std::vector<std::pair<string, Tensor>> inputs = {
+      {"logical_A", test::AsTensor<float>({-1.0})},
+      {"logical_B", test::AsTensor<float>({-2.0})}};
+
+  std::vector<std::pair<string, Tensor>> bound_inputs;
+  TF_ASSERT_OK(BindGenericInputs(signature_, inputs, &bound_inputs));
+
+  // Fatal check: the element accesses below would be out of bounds if fewer
+  // than two pairs were produced.
+  ASSERT_EQ(2, bound_inputs.size());
+  EXPECT_EQ("graph_A", bound_inputs[0].first);
+  EXPECT_EQ("graph_B", bound_inputs[1].first);
+  test::ExpectTensorEqual<float>(test::AsTensor<float>({-1.0}),
+                                 bound_inputs[0].second);
+  test::ExpectTensorEqual<float>(test::AsTensor<float>({-2.0}),
+                                 bound_inputs[1].second);
+}
+
+// A logical input name with no entry in the signature map must be rejected.
+TEST_F(GenericSignatureTest, BindGenericInputsMissingBinding) {
+  const std::vector<std::pair<string, Tensor>> inputs = {
+      {"logical_A", test::AsTensor<float>({-42.0})},
+      {"logical_MISSING", test::AsTensor<float>({-43.0})}};
+
+  std::vector<std::pair<string, Tensor>> bound_inputs;
+  const Status status = BindGenericInputs(signature_, inputs, &bound_inputs);
+  ASSERT_FALSE(status.ok());
+}
+
+// Binds logical names only; the output must follow the order of the request
+// (B before A here).
+TEST_F(GenericSignatureTest, BindGenericNamesBasic) {
+  const std::vector<string> input_names = {"logical_B", "logical_A"};
+  std::vector<string> bound_names;
+  TF_ASSERT_OK(BindGenericNames(signature_, input_names, &bound_names));
+
+  // Fatal check: guards the element accesses below.
+  ASSERT_EQ(2, bound_names.size());
+  EXPECT_EQ("graph_B", bound_names[0]);
+  EXPECT_EQ("graph_A", bound_names[1]);
+}
+
+// A logical name with no entry in the signature map must be rejected.
+TEST_F(GenericSignatureTest, BindGenericNamesMissingBinding) {
+  const std::vector<string> input_names = {"logical_B", "logical_MISSING"};
+  std::vector<string> bound_names;
+  const Status status = BindGenericNames(signature_, input_names, &bound_names);
+  ASSERT_FALSE(status.ok());
+}
+
+}  // namespace
+}  // namespace contrib
+}  // namespace tensorflow
diff --git a/tensorflow/contrib/session_bundle/test_util.cc b/tensorflow/contrib/session_bundle/test_util.cc
new file mode 100644
index 0000000000..90870a69a0
--- /dev/null
+++ b/tensorflow/contrib/session_bundle/test_util.cc
@@ -0,0 +1,35 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/session_bundle/test_util.h"
+
+#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+namespace contrib {
+namespace test_util {
+
+// Joins the TensorFlow source root, "/contrib" and `relative_path`, in order.
+string TestSrcDirPath(const string& relative_path) {
+  const string base_path = tensorflow::testing::TensorFlowSrcRoot();
+  const string contrib_path = tensorflow::io::JoinPath(base_path, "/contrib");
+  return tensorflow::io::JoinPath(contrib_path, relative_path);
+}
+
+}  // namespace test_util
+}  // namespace contrib
+}  // namespace tensorflow
diff --git a/tensorflow/contrib/session_bundle/test_util.h b/tensorflow/contrib/session_bundle/test_util.h
new file mode 100644
index 0000000000..87ddf28d60
--- /dev/null
+++ b/tensorflow/contrib/session_bundle/test_util.h
@@ -0,0 +1,38 @@
+/* Copyright 2016 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CONTRIB_SESSION_BUNDLE_TEST_UTIL_H_
+#define THIRD_PARTY_TENSORFLOW_CONTRIB_SESSION_BUNDLE_TEST_UTIL_H_
+
+#include <string>
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace contrib {
+namespace test_util {
+
+// Returns the test srcdir path (in the linked-in runfiles) for a path given
+// relative to the tensorflow contrib/ directory,
+// e.g. relative_path = "session_bundle/example".
+string TestSrcDirPath(const string& relative_path);
+
+}  // namespace test_util
+}  // namespace contrib
+}  // namespace tensorflow
+
+#endif  // THIRD_PARTY_TENSORFLOW_CONTRIB_SESSION_BUNDLE_TEST_UTIL_H_
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index a14bd88f69..a581923acf 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -30,6 +30,9 @@ sh_binary(
         ":other_headers",
         ":simple_console",
         "//tensorflow:tensorflow_py",
+        "//tensorflow/contrib/session_bundle:all_files",
+        "//tensorflow/contrib/session_bundle:manifest_proto_py",
+        "//tensorflow/contrib/session_bundle/example:half_plus_two",
         "//tensorflow/contrib/slim:all_files",
         "//tensorflow/core:framework_headers",
         "//tensorflow/examples/tutorials/mnist:package",
diff --git a/tools/bazel.rc.template b/tools/bazel.rc.template
index d3e70e7a4f..d4dddb5211 100644
--- a/tools/bazel.rc.template
+++ b/tools/bazel.rc.template
@@ -6,6 +6,12 @@ build --python$PYTHON_MAJOR_VERSION_path=$PYTHON_BINARY
 build --define=use_fast_cpp_protos=true
 build --define=allow_oversize_protos=true
 
+build --define PYTHON_BIN_PATH=$PYTHON_BINARY
+test --define PYTHON_BIN_PATH=$PYTHON_BINARY
+test --force_python=py$PYTHON_MAJOR_VERSION
+test --host_force_python=py$PYTHON_MAJOR_VERSION
+run --define PYTHON_BIN_PATH=$PYTHON_BINARY
+
 build
--spawn_strategy=standalone test --spawn_strategy=standalone run --spawn_strategy=standalone -- cgit v1.2.3