aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/session_bundle/exporter.py
blob: dcc7fbaa2d61a48eddb5d8495360cb5559c99f4c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Export a TensorFlow model.

See: go/tf-exporter
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import re
import six

from google.protobuf.any_pb2 import Any

from tensorflow.contrib.session_bundle import constants
from tensorflow.contrib.session_bundle import gc
from tensorflow.contrib.session_bundle import manifest_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import saver as tf_saver
from tensorflow.python.training import training_util
from tensorflow.python.util import compat
from tensorflow.python.util.deprecation import deprecated


@deprecated("2017-06-30",
            "No longer supported. Switch to SavedModel immediately.")
def gfile_copy_callback(files_to_copy, export_dir_path):
  """Copies a set of files into an export directory via `gfile.Copy`.

  `Exporter.init` uses this function as its default `assets_callback`, so that
  the files referenced by the `assets_collection` end up in the export. It may
  also be called directly (i.e. not as a callback) to place additional
  supplementary files into the export directory.

  The export directory is created if needed. A destination file that already
  exists is removed before the copy.

  Args:
    files_to_copy: A dictionary that maps original file paths to desired
      basename in the export directory.
    export_dir_path: Directory to copy the files to.
  """
  logging.info("Write assets into: %s using gfile_copy.", export_dir_path)
  gfile.MakeDirs(export_dir_path)

  for origin, target_name in files_to_copy.items():
    target_dir = compat.as_bytes(export_dir_path)
    destination = os.path.join(target_dir, compat.as_bytes(target_name))
    logging.info("Copying asset %s to path %s.", origin, destination)

    # A leftover destination may come from a run that was restarted while
    # copying assets; its state is unknown, so drop it and copy afresh.
    # TODO(b/28676216): Do some file checks before deleting.
    stale_copy_present = gfile.Exists(destination)
    if stale_copy_present:
      logging.info("Removing file %s.", destination)
      gfile.Remove(destination)

    gfile.Copy(origin, destination)


@deprecated("2017-06-30",
            "No longer supported. Switch to SavedModel immediately.")
def regression_signature(input_tensor, output_tensor):
  """Builds a Signature message describing a regression model.

  Only the `name` attribute of each tensor is read; it is stored as the
  tensor binding of the corresponding signature field.

  Args:
    input_tensor: Tensor specifying the input to a graph.
    output_tensor: Tensor specifying the output of a graph.

  Returns:
    A Signature message.
  """
  input_name = input_tensor.name
  output_name = output_tensor.name

  result = manifest_pb2.Signature()
  regression = result.regression_signature
  regression.input.tensor_name = input_name
  regression.output.tensor_name = output_name
  return result


@deprecated("2017-06-30",
            "No longer supported. Switch to SavedModel immediately.")
def classification_signature(input_tensor,
                             classes_tensor=None,
                             scores_tensor=None):
  """Builds a Signature message describing a classification model.

  The input binding is always filled in. The classes and scores bindings are
  each filled in only when the matching tensor is provided (not None).

  Args:
    input_tensor: Tensor specifying the input to a graph.
    classes_tensor: Tensor specifying the output classes of a graph.
    scores_tensor: Tensor specifying the scores of the output classes.

  Returns:
    A Signature message.
  """
  result = manifest_pb2.Signature()
  classification = result.classification_signature
  classification.input.tensor_name = input_tensor.name

  if classes_tensor is not None:
    classification.classes.tensor_name = classes_tensor.name

  if scores_tensor is not None:
    classification.scores.tensor_name = scores_tensor.name

  return result


@deprecated("2017-06-30",
            "No longer supported. Switch to SavedModel immediately.")
def generic_signature(name_tensor_map):
  """Builds a generic Signature message from logical names to tensor names.

  For every entry of `name_tensor_map`, the tensor's `name` is recorded under
  the entry's logical name in the signature's generic map.

  Args:
    name_tensor_map: Map from logical name to Tensor.

  Returns:
    A Signature message.
  """
  result = manifest_pb2.Signature()
  for logical_name, bound_tensor in six.iteritems(name_tensor_map):
    bound_name = bound_tensor.name
    result.generic_signature.map[logical_name].tensor_name = bound_name
  return result


class Exporter(object):
  """Exporter helps package a TensorFlow model for serving.

  `init` must be called exactly once to register the graph, signatures and
  assets before `export` is called to write a versioned export directory.

  Args:
    saver: Saver object.
  """

  def __init__(self, saver):
    # Makes a copy of the saver-def and disables garbage-collection, since the
    # exporter enforces garbage-collection independently. Specifically, since
    # the exporter performs atomic copies of the saver output, it is required
    # that garbage-collection via the underlying saver be disabled.
    saver_def = saver.as_saver_def()
    saver_def.ClearField("max_to_keep")
    self._saver = tf_saver.Saver(saver_def=saver_def)
    # Set to True by a successful call to `init`.
    self._has_init = False
    # Maps original asset file path -> basename inside the export's asset dir.
    self._assets_to_copy = {}

  @deprecated("2017-06-30",
              "No longer supported. Switch to SavedModel immediately.")
  def init(self,
           graph_def=None,
           init_op=None,
           clear_devices=False,
           default_graph_signature=None,
           named_graph_signatures=None,
           assets_collection=None,
           assets_callback=gfile_copy_callback):
    """Initialization.

    All arguments are validated before any state is changed, so a call that
    raises leaves the exporter untouched.

    Args:
      graph_def: A GraphDef message of the graph to be used in inference.
        GraphDef of default graph is used when None.
      init_op: Op to be used in initialization.
      clear_devices: If device info of the graph should be cleared upon export.
      default_graph_signature: Default signature of the graph.
      named_graph_signatures: Map of named input/output signatures of the graph.
      assets_collection: A collection of constant asset filepath tensors. If set
        the assets will be exported into the asset directory.
      assets_callback: callback with two argument called during export with the
        list of files to copy and the asset path.
    Raises:
      RuntimeError: if init is called more than once.
      TypeError: if init_op is not an Operation or None, or if an asset file
        path tensor is not a constant string scalar Tensor.
      ValueError: if an asset file path tensor holds an empty path.
    """
    # Reject repeated initialization before touching any state; otherwise a
    # rejected second call would still alter the set of assets to copy.
    if self._has_init:
      raise RuntimeError("init should be called only once")

    # Validate init_op up front so that a bad argument cannot leave the
    # exporter marked as initialized but only partially configured.
    if init_op and not isinstance(init_op, ops.Operation):
      raise TypeError("init_op needs to be an Operation: %s" % init_op)

    # Avoid Dangerous default value []
    if named_graph_signatures is None:
      named_graph_signatures = {}

    # Collect assets into locals first; they are committed to the instance
    # only once every asset tensor has been validated.
    assets = []
    assets_to_copy = {}
    if assets_collection:
      for asset_tensor in assets_collection:
        asset_filepath = self._file_path_value(asset_tensor)
        if not asset_filepath:
          raise ValueError("invalid asset filepath tensor %s" % asset_tensor)
        basename = os.path.basename(asset_filepath)
        assets.append((basename, asset_tensor))
        assets_to_copy[asset_filepath] = basename

    self._has_init = True
    self._assets_to_copy.update(assets_to_copy)

    if graph_def or clear_devices:
      # Work on a copy so the caller's GraphDef (or the default graph) is not
      # modified when device placements are cleared.
      copy = graph_pb2.GraphDef()
      if graph_def:
        copy.CopyFrom(graph_def)
      else:
        copy.CopyFrom(ops.get_default_graph().as_graph_def())
      if clear_devices:
        for node in copy.node:
          node.device = ""
      graph_any_buf = Any()
      graph_any_buf.Pack(copy)
      ops.add_to_collection(constants.GRAPH_KEY, graph_any_buf)

    if init_op:
      ops.add_to_collection(constants.INIT_OP_KEY, init_op)

    signatures_proto = manifest_pb2.Signatures()
    if default_graph_signature:
      signatures_proto.default_signature.CopyFrom(default_graph_signature)
    for signature_name, signature in six.iteritems(named_graph_signatures):
      signatures_proto.named_signatures[signature_name].CopyFrom(signature)
    signatures_any_buf = Any()
    signatures_any_buf.Pack(signatures_proto)
    ops.add_to_collection(constants.SIGNATURES_KEY, signatures_any_buf)

    # One AssetFile entry per asset, binding the exported basename to the
    # tensor that holds the original path.
    for filename, tensor in assets:
      asset = manifest_pb2.AssetFile()
      asset.filename = filename
      asset.tensor_binding.tensor_name = tensor.name
      asset_any_buf = Any()
      asset_any_buf.Pack(asset)
      ops.add_to_collection(constants.ASSETS_KEY, asset_any_buf)

    self._assets_callback = assets_callback

  @deprecated("2017-06-30",
              "No longer supported. Switch to SavedModel immediately.")
  def export(self,
             export_dir_base,
             global_step_tensor,
             sess=None,
             exports_to_keep=None):
    """Exports the model.

    The export is first written to a "-tmp" sibling directory which is then
    renamed to the final, versioned directory. When `exports_to_keep` is set,
    older exports under `export_dir_base` that the filter does not keep are
    deleted afterwards.

    Args:
      export_dir_base: A string path to the base export dir.
      global_step_tensor: A Tensor or tensor name providing the
        global step counter to append to the export directory path and set
        in the manifest version.
      sess: A Session to use to save the parameters.
      exports_to_keep: a gc.Path filter function used to determine the set of
        exports to keep. If set to None, all versions will be kept.

    Returns:
      The path to the exported directory, as produced by joining the
      `compat.as_bytes` forms of the base dir and the version component.

    Raises:
      RuntimeError: if init is not called.
      RuntimeError: if the export would overwrite an existing directory.
    """
    if not self._has_init:
      raise RuntimeError("init must be called first")

    # Export dir must not end with / or it will break exports to keep. Strip /.
    if export_dir_base.endswith("/"):
      export_dir_base = export_dir_base[:-1]

    global_step = training_util.global_step(sess, global_step_tensor)
    export_dir = os.path.join(
        compat.as_bytes(export_dir_base),
        compat.as_bytes(constants.VERSION_FORMAT_SPECIFIER % global_step))

    # Prevent overwriting on existing exports which could lead to bad/corrupt
    # storage and loading of models. This is an important check that must be
    # done before any output files or directories are created.
    if gfile.Exists(export_dir):
      raise RuntimeError("Overwriting exports can cause corruption and are "
                         "not allowed. Duplicate export dir: %s" % export_dir)

    # Output to a temporary directory which is atomically renamed to the final
    # directory when complete.
    tmp_export_dir = compat.as_text(export_dir) + "-tmp"
    gfile.MakeDirs(tmp_export_dir)

    self._saver.save(sess,
                     os.path.join(
                         compat.as_text(tmp_export_dir),
                         compat.as_text(constants.EXPORT_BASE_NAME)),
                     meta_graph_suffix=constants.EXPORT_SUFFIX_NAME)

    # Run the asset callback.
    if self._assets_callback and self._assets_to_copy:
      assets_dir = os.path.join(
          compat.as_bytes(tmp_export_dir),
          compat.as_bytes(constants.ASSETS_DIRECTORY))
      gfile.MakeDirs(assets_dir)
      self._assets_callback(self._assets_to_copy, assets_dir)

    # TODO(b/27794910): Delete *checkpoint* file before rename.
    gfile.Rename(tmp_export_dir, export_dir)

    if exports_to_keep:
      # The base dir is caller supplied and may contain regex metacharacters
      # (".", "+", "(", ...), so it must be escaped to be matched literally.
      # Only direct children whose name is exactly eight digits are treated
      # as export versions.
      version_dir_pattern = re.compile(
          "^" + re.escape(export_dir_base) + "/(\\d{8})$")

      # create a simple parser that pulls the export_version from the directory.
      def parser(path):
        match = version_dir_pattern.match(path.path)
        if not match:
          return None
        return path._replace(export_version=int(match.group(1)))

      paths_to_delete = gc.negation(exports_to_keep)
      for p in paths_to_delete(gc.get_paths(export_dir_base, parser=parser)):
        gfile.DeleteRecursively(p.path)

    return export_dir

  def _file_path_value(self, path_tensor):
    """Returns the filepath value stored in constant `path_tensor`.

    Args:
      path_tensor: The tensor to read the file path from.

    Returns:
      The single string value held by the constant.

    Raises:
      TypeError: if `path_tensor` is not a Tensor, is not produced by a
        "Const" op, is not of string dtype, or does not hold exactly one
        string value.
    """
    if not isinstance(path_tensor, ops.Tensor):
      raise TypeError("tensor is not a Tensor")
    if path_tensor.op.type != "Const":
      raise TypeError("Only constants tensor are supported")
    if path_tensor.dtype != dtypes.string:
      raise TypeError("File paths should be string")
    str_value = path_tensor.op.get_attr("value").string_val
    if len(str_value) != 1:
      raise TypeError("Only scalar tensors are supported")
    return str_value[0]