aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/session_bundle/exporter.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/session_bundle/exporter.py')
-rw-r--r--tensorflow/contrib/session_bundle/exporter.py311
1 files changed, 311 insertions, 0 deletions
diff --git a/tensorflow/contrib/session_bundle/exporter.py b/tensorflow/contrib/session_bundle/exporter.py
new file mode 100644
index 0000000000..fc7458db62
--- /dev/null
+++ b/tensorflow/contrib/session_bundle/exporter.py
@@ -0,0 +1,311 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Export a TensorFlow model.
+
+See: go/tf-exporter
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import re
+import six
+
+import tensorflow as tf
+from google.protobuf.any_pb2 import Any
+
+from tensorflow.contrib.session_bundle import gc
+from tensorflow.contrib.session_bundle import manifest_pb2
+from tensorflow.python.framework import ops
+from tensorflow.python.platform import gfile
+from tensorflow.python.training import training_util
+from tensorflow.python.util import compat
+
+# See: go/tf-exporter for these constants and directory structure.
+VERSION_FORMAT_SPECIFIER = "%08d"
+ASSETS_DIRECTORY = "assets"
+EXPORT_BASE_NAME = "export"
+EXPORT_SUFFIX_NAME = "meta"
+META_GRAPH_DEF_FILENAME = EXPORT_BASE_NAME + "." + EXPORT_SUFFIX_NAME
+VARIABLES_FILENAME = EXPORT_BASE_NAME
+VARIABLES_FILENAME_PATTERN = VARIABLES_FILENAME + "-?????-of-?????"
+INIT_OP_KEY = "serving_init_op"
+SIGNATURES_KEY = "serving_signatures"
+ASSETS_KEY = "serving_assets"
+GRAPH_KEY = "serving_graph"
+
+
+def gfile_copy_callback(files_to_copy, export_dir_path):
+ """Callback to copy files using `gfile.Copy` to an export directory.
+
+ This method is used as the default `assets_callback` in `Exporter.init` to
+ copy assets from the `assets_collection`. It can also be invoked directly to
+ copy additional supplementary files into the export directory (in which case
+ it is not a callback).
+
+ Args:
+ files_to_copy: A dictionary that maps original file paths to desired
+ basename in the export directory.
+ export_dir_path: Directory to copy the files to.
+ """
+ tf.logging.info("Write assest into: %s using gfile_copy.", export_dir_path)
+ gfile.MakeDirs(export_dir_path)
+ for source_filepath, basename in files_to_copy.items():
+ new_path = os.path.join(
+ compat.as_bytes(export_dir_path), compat.as_bytes(basename))
+ tf.logging.info("Copying asset %s to path %s.", source_filepath, new_path)
+
+ if gfile.Exists(new_path):
+ # Guard against being restarted while copying assets, and the file
+ # existing and being in an unknown state.
+ # TODO(b/28676216): Do some file checks before deleting.
+ tf.logging.info("Removing file %s.", new_path)
+ gfile.Remove(new_path)
+ tf.gfile.Copy(source_filepath, new_path)
+
+
+def regression_signature(input_tensor, output_tensor):
+ """Creates a regression signature.
+
+ Args:
+ input_tensor: Tensor specifying the input to a graph.
+ output_tensor: Tensor specifying the output of a graph.
+
+ Returns:
+ A Signature message.
+ """
+ signature = manifest_pb2.Signature()
+ signature.regression_signature.input.tensor_name = input_tensor.name
+ signature.regression_signature.output.tensor_name = output_tensor.name
+ return signature
+
+
+def classification_signature(input_tensor,
+ classes_tensor=None,
+ scores_tensor=None):
+ """Creates a classification signature.
+
+ Args:
+ input_tensor: Tensor specifying the input to a graph.
+ classes_tensor: Tensor specifying the output classes of a graph.
+ scores_tensor: Tensor specifying the scores of the output classes.
+
+ Returns:
+ A Signature message.
+ """
+ signature = manifest_pb2.Signature()
+ signature.classification_signature.input.tensor_name = input_tensor.name
+ if classes_tensor is not None:
+ signature.classification_signature.classes.tensor_name = classes_tensor.name
+ if scores_tensor is not None:
+ signature.classification_signature.scores.tensor_name = scores_tensor.name
+ return signature
+
+
+def generic_signature(name_tensor_map):
+ """Creates a generic signature of name to Tensor name.
+
+ Args:
+ name_tensor_map: Map from logical name to Tensor.
+
+ Returns:
+ A Signature message.
+ """
+ signature = manifest_pb2.Signature()
+ for name, tensor in six.iteritems(name_tensor_map):
+ signature.generic_signature.map[name].tensor_name = tensor.name
+ return signature
+
+
+class Exporter(object):
+ """Exporter helps package a TensorFlow model for serving.
+
+ Args:
+ saver: Saver object.
+ """
+
+ def __init__(self, saver):
+ self._saver = saver
+ self._has_init = False
+ self._assets_to_copy = {}
+
+ def init(self,
+ graph_def=None,
+ init_op=None,
+ clear_devices=False,
+ default_graph_signature=None,
+ named_graph_signatures=None,
+ assets_collection=None,
+ assets_callback=gfile_copy_callback):
+ """Initialization.
+
+ Args:
+ graph_def: A GraphDef message of the graph to be used in inference.
+ GraphDef of default graph is used when None.
+ init_op: Op to be used in initialization.
+ clear_devices: If device info of the graph should be cleared upon export.
+ default_graph_signature: Default signature of the graph.
+ named_graph_signatures: Map of named input/output signatures of the graph.
+ assets_collection: A collection of constant asset filepath tensors. If set
+ the assets will be exported into the asset directory.
+ assets_callback: callback with two argument called during export with the
+ list of files to copy and the asset path.
+ Raises:
+ RuntimeError: if init is called more than once.
+ TypeError: if init_op is not an Operation or None.
+ ValueError: if asset file path tensors are not non-empty constant string
+ scalar tensors.
+ """
+ # Avoid Dangerous default value []
+ if named_graph_signatures is None:
+ named_graph_signatures = {}
+ assets = []
+ if assets_collection:
+ for asset_tensor in assets_collection:
+ asset_filepath = self._file_path_value(asset_tensor)
+ if not asset_filepath:
+ raise ValueError("invalid asset filepath tensor %s" % asset_tensor)
+ basename = os.path.basename(asset_filepath)
+ assets.append((basename, asset_tensor))
+ self._assets_to_copy[asset_filepath] = basename
+
+ if self._has_init:
+ raise RuntimeError("init should be called only once")
+ self._has_init = True
+
+ if graph_def or clear_devices:
+ copy = tf.GraphDef()
+ if graph_def:
+ copy.CopyFrom(graph_def)
+ else:
+ copy.CopyFrom(tf.get_default_graph().as_graph_def())
+ if clear_devices:
+ for node in copy.node:
+ node.device = ""
+ graph_any_buf = Any()
+ graph_any_buf.Pack(copy)
+ tf.add_to_collection(GRAPH_KEY, graph_any_buf)
+
+ if init_op:
+ if not isinstance(init_op, ops.Operation):
+ raise TypeError("init_op needs to be an Operation: %s" % init_op)
+ tf.add_to_collection(INIT_OP_KEY, init_op)
+
+ signatures_proto = manifest_pb2.Signatures()
+ if default_graph_signature:
+ signatures_proto.default_signature.CopyFrom(default_graph_signature)
+ for signature_name, signature in six.iteritems(named_graph_signatures):
+ signatures_proto.named_signatures[signature_name].CopyFrom(signature)
+ signatures_any_buf = Any()
+ signatures_any_buf.Pack(signatures_proto)
+ tf.add_to_collection(SIGNATURES_KEY, signatures_any_buf)
+
+ for filename, tensor in assets:
+ asset = manifest_pb2.AssetFile()
+ asset.filename = filename
+ asset.tensor_binding.tensor_name = tensor.name
+ asset_any_buf = Any()
+ asset_any_buf.Pack(asset)
+ tf.add_to_collection(ASSETS_KEY, asset_any_buf)
+
+ self._assets_callback = assets_callback
+
+ def export(self,
+ export_dir_base,
+ global_step_tensor,
+ sess=None,
+ exports_to_keep=None):
+ """Exports the model.
+
+ Args:
+ export_dir_base: A string path to the base export dir.
+ global_step_tensor: An Tensor or tensor name providing the
+ global step counter to append to the export directory path and set
+ in the manifest version.
+ sess: A Session to use to save the parameters.
+ exports_to_keep: a gc.Path filter function used to determine the set of
+ exports to keep. If set to None, all versions will be kept.
+
+ Returns:
+ The string path to the exported directory.
+
+ Raises:
+ RuntimeError: if init is not called.
+ RuntimeError: if the export would overwrite an existing directory.
+ """
+ if not self._has_init:
+ raise RuntimeError("init must be called first")
+
+ global_step = training_util.global_step(sess, global_step_tensor)
+ export_dir = os.path.join(
+ compat.as_bytes(export_dir_base),
+ compat.as_bytes(VERSION_FORMAT_SPECIFIER % global_step))
+
+ # Prevent overwriting on existing exports which could lead to bad/corrupt
+ # storage and loading of models. This is an important check that must be
+ # done before any output files or directories are created.
+ if gfile.Exists(export_dir):
+ raise RuntimeError("Overwriting exports can cause corruption and are "
+ "not allowed. Duplicate export dir: %s" % export_dir)
+
+ # Output to a temporary directory which is atomically renamed to the final
+ # directory when complete.
+ tmp_export_dir = compat.as_text(export_dir) + "-tmp"
+ gfile.MakeDirs(tmp_export_dir)
+
+ self._saver.save(sess,
+ os.path.join(
+ compat.as_text(tmp_export_dir),
+ compat.as_text(EXPORT_BASE_NAME)),
+ meta_graph_suffix=EXPORT_SUFFIX_NAME)
+
+ # Run the asset callback.
+ if self._assets_callback and self._assets_to_copy:
+ assets_dir = os.path.join(
+ compat.as_bytes(tmp_export_dir), compat.as_bytes(ASSETS_DIRECTORY))
+ gfile.MakeDirs(assets_dir)
+ self._assets_callback(self._assets_to_copy, assets_dir)
+
+ # TODO(b/27794910): Delete *checkpoint* file before rename.
+ gfile.Rename(tmp_export_dir, export_dir)
+
+ if exports_to_keep:
+ # create a simple parser that pulls the export_version from the directory.
+ def parser(path):
+ match = re.match("^" + export_dir_base + "/(\\d{8})$", path.path)
+ if not match:
+ return None
+ return path._replace(export_version=int(match.group(1)))
+
+ paths_to_delete = gc.negation(exports_to_keep)
+ for p in paths_to_delete(gc.get_paths(export_dir_base, parser=parser)):
+ gfile.DeleteRecursively(p.path)
+
+ return export_dir
+
+ def _file_path_value(self, path_tensor):
+ """Returns the filepath value stored in constant `path_tensor`."""
+ if not isinstance(path_tensor, tf.Tensor):
+ raise TypeError("tensor is not a Tensor")
+ if path_tensor.op.type != "Const":
+ raise TypeError("Only constants tensor are supported")
+ if path_tensor.dtype != tf.string:
+ raise TypeError("File paths should be string")
+ str_value = path_tensor.op.get_attr("value").string_val
+ if len(str_value) != 1:
+ raise TypeError("Only scalar tensors are supported")
+ return str_value[0]