# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Loader implementation for SavedModel with hermetic, language-neutral exports.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from google.protobuf import message
from google.protobuf import text_format

from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.python.framework import ops
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging
from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import utils_impl as saved_model_utils
from tensorflow.python.training import saver as tf_saver
from tensorflow.python.util import compat
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


def _parse_saved_model(export_dir):
  """Reads the savedmodel.pb or savedmodel.pbtxt file containing `SavedModel`.

  Args:
    export_dir: Directory containing the SavedModel file.

  Returns:
    A `SavedModel` protocol buffer.

  Raises:
    IOError: If the file does not exist, or cannot be successfully parsed.
  """
  # Build the path to the SavedModel in pbtxt format.
  path_to_pbtxt = os.path.join(
      compat.as_bytes(export_dir),
      compat.as_bytes(constants.SAVED_MODEL_FILENAME_PBTXT))
  # Build the path to the SavedModel in pb format.
  path_to_pb = os.path.join(
      compat.as_bytes(export_dir),
      compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))

  # Parse the SavedModel protocol buffer.
  saved_model = saved_model_pb2.SavedModel()
  if file_io.file_exists(path_to_pb):
    try:
      file_content = file_io.FileIO(path_to_pb, "rb").read()
      saved_model.ParseFromString(file_content)
      return saved_model
    except message.DecodeError as e:
      raise IOError("Cannot parse file %s: %s." % (path_to_pb, str(e)))
  elif file_io.file_exists(path_to_pbtxt):
    try:
      file_content = file_io.FileIO(path_to_pbtxt, "rb").read()
      text_format.Merge(file_content.decode("utf-8"), saved_model)
      return saved_model
    except text_format.ParseError as e:
      raise IOError("Cannot parse file %s: %s." % (path_to_pbtxt, str(e)))
  else:
    raise IOError("SavedModel file does not exist at: %s/{%s|%s}" %
                  (export_dir,
                   constants.SAVED_MODEL_FILENAME_PBTXT,
                   constants.SAVED_MODEL_FILENAME_PB))
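

# Illustrative sketch (not part of the loader API): `_parse_saved_model` can be
# used on its own to inspect which tag-sets a SavedModel on disk advertises
# before deciding what to load. The helper name below is arbitrary and the
# export directory it is called with is a placeholder.
def _example_print_tag_sets(export_dir):
  """Logs the tag-set of every MetaGraphDef recorded under `export_dir`."""
  saved_model_proto = _parse_saved_model(export_dir)
  for meta_graph_def in saved_model_proto.meta_graphs:
    tf_logging.info("Found tag-set: %s",
                    list(meta_graph_def.meta_info_def.tags))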


def _get_asset_tensors(export_dir, meta_graph_def_to_load, import_scope=None):
  """Gets the asset tensors, if defined in the meta graph def to load.

  Args:
    export_dir: Directory where the SavedModel is located.
    meta_graph_def_to_load: The meta graph def from the SavedModel to be loaded.
    import_scope: Optional `string` -- if specified, prepend this followed by
        '/' to all returned asset tensor names.

  Returns:
    A dictionary of asset tensors, keyed by the name of the asset tensor. The
    value in the map corresponds to the absolute path of the asset file.
  """
  # Collection-def that may contain the assets key.
  collection_def = meta_graph_def_to_load.collection_def

  asset_tensor_dict = {}
  if constants.ASSETS_KEY in collection_def:
    # Location of the assets for SavedModel.
    assets_directory = os.path.join(
        compat.as_bytes(export_dir),
        compat.as_bytes(constants.ASSETS_DIRECTORY))
    assets_any_proto = collection_def[constants.ASSETS_KEY].any_list.value
    # Process each asset and add it to the asset tensor dictionary.
    for asset_any_proto in assets_any_proto:
      asset_proto = meta_graph_pb2.AssetFileDef()
      asset_any_proto.Unpack(asset_proto)
      tensor_name = asset_proto.tensor_info.name
      if import_scope:
        tensor_name = "%s/%s" % (import_scope, tensor_name)
      asset_tensor_dict[tensor_name] = os.path.join(
          compat.as_bytes(assets_directory),
          compat.as_bytes(asset_proto.filename))
  return asset_tensor_dict


def _get_main_op_tensor(
    meta_graph_def_to_load, init_op_key=constants.MAIN_OP_KEY):
  """Gets the main op tensor, if one exists.

  Args:
    meta_graph_def_to_load: The meta graph def from the SavedModel to be loaded.
    init_op_key: Name of the collection to check; should be one of
      `MAIN_OP_KEY` or the deprecated `LEGACY_INIT_OP_KEY`.

  Returns:
    The main op tensor if it exists, and `None` otherwise.

  Raises:
    RuntimeError: If the collection def corresponding to the main op key
        contains anything other than exactly one tensor.
  """
  collection_def = meta_graph_def_to_load.collection_def
  main_op_tensor = None
  if init_op_key in collection_def:
    main_ops = collection_def[init_op_key].node_list.value
    if len(main_ops) != 1:
      raise RuntimeError("Expected exactly one SavedModel main op. "
                         "Found: {}".format(main_ops))
    main_op_tensor = ops.get_collection(init_op_key)[0]
  return main_op_tensor


@tf_export("saved_model.maybe_saved_model_directory",
           "saved_model.loader.maybe_saved_model_directory")
@deprecation.deprecated_endpoints(
    "saved_model.loader.maybe_saved_model_directory")
def maybe_saved_model_directory(export_dir):
  """Checks whether the provided export directory could contain a SavedModel.

  Note that the method does not load any data by itself. If the method returns
  `False`, the export directory definitely does not contain a SavedModel. If
  the method returns `True`, the export directory may contain a SavedModel,
  but there is no guarantee that it can be loaded.

  Args:
    export_dir: Absolute string path to possible export location. For example,
                '/my/foo/model'.

  Returns:
    True if the export directory contains SavedModel files, False otherwise.
  """
  txt_path = os.path.join(export_dir, constants.SAVED_MODEL_FILENAME_PBTXT)
  pb_path = os.path.join(export_dir, constants.SAVED_MODEL_FILENAME_PB)
  return file_io.file_exists(txt_path) or file_io.file_exists(pb_path)
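

# Illustrative sketch: `maybe_saved_model_directory` only tests for the
# presence of a saved_model.pb or saved_model.pbtxt file; it is a cheap
# pre-check, not a validation. The helper name below is arbitrary.
def _example_check_before_parsing(export_dir):
  """Returns the parsed `SavedModel` proto if `export_dir` looks loadable."""
  if not maybe_saved_model_directory(export_dir):
    raise IOError("No SavedModel found at: %s" % export_dir)
  return _parse_saved_model(export_dir)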


@tf_export("saved_model.load", "saved_model.loader.load")
def load(sess, tags, export_dir, import_scope=None, **saver_kwargs):
  """Loads the model from a SavedModel as specified by tags.

  Args:
    sess: The TensorFlow session to restore the variables.
    tags: Set of string tags to identify the required MetaGraphDef. These should
        correspond to the tags used when saving the variables using the
        SavedModel `save()` API.
    export_dir: Directory in which the SavedModel protocol buffer and variables
        to be loaded are located.
    import_scope: Optional `string` -- if specified, prepend this string
        followed by '/' to all loaded tensor names. This scope is applied to
        tensor instances loaded into the passed session, but it is *not* written
        through to the static `MetaGraphDef` protocol buffer that is returned.
    **saver_kwargs: Optional keyword arguments passed through to Saver.

  Returns:
    The `MetaGraphDef` protocol buffer loaded in the provided session. This
    can be used to further extract signature-defs, collection-defs, etc.

  Raises:
    RuntimeError: MetaGraphDef associated with the tags cannot be found.
  """
  loader = SavedModelLoader(export_dir)
  return loader.load(sess, tags, import_scope, **saver_kwargs)
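

# Illustrative sketch: typical use of the module-level `load` with the "serve"
# tag-set. The local session import, the tag and the export directory are
# assumptions; the tags must match those used when the model was exported.
def _example_load_serving_graph(export_dir):
  """Loads the "serve"-tagged MetaGraphDef into a fresh graph and session."""
  # Local import, for the example only.
  from tensorflow.python.client import session as session_lib
  with ops.Graph().as_default():
    with session_lib.Session() as sess:
      meta_graph_def = load(sess, ["serve"], export_dir)
      # The returned proto is static; e.g. its signature_def map names the
      # exported signatures.
      return meta_graph_def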


class SavedModelLoader(object):
  """Load graphs and restore variable values from a `SavedModel`."""

  def __init__(self, export_dir):
    """Creates a `SavedModelLoader`.

    Args:
      export_dir: Directory in which the SavedModel protocol buffer and
        variables to be loaded are located.
    """
    self._export_dir = export_dir
    self._variables_path = saved_model_utils.get_variables_path(export_dir)
    self._saved_model = _parse_saved_model(export_dir)

  @property
  def export_dir(self):
    """Directory containing the SavedModel."""
    return self._export_dir

  @property
  def variables_path(self):
    """Path to variable checkpoint files."""
    return self._variables_path

  @property
  def saved_model(self):
    """SavedModel object parsed from the export directory."""
    return self._saved_model

  def get_meta_graph_def_from_tags(self, tags):
    """Return MetaGraphDef with the exact specified tags.

    Args:
      tags: A list or set of string tags that identify the MetaGraphDef.

    Returns:
      The `MetaGraphDef` whose tag-set exactly matches `tags`.

    Raises:
      RuntimeError: if no MetaGraphDef with the specified tags is found.
    """
    found_match = False
    for meta_graph_def in self._saved_model.meta_graphs:
      if set(meta_graph_def.meta_info_def.tags) == set(tags):
        meta_graph_def_to_load = meta_graph_def
        found_match = True
        break

    if not found_match:
      raise RuntimeError(
          "MetaGraphDef associated with tags " + str(tags).strip("[]") +
          " could not be found in SavedModel. To inspect available tag-sets in"
          " the SavedModel, please use the SavedModel CLI: `saved_model_cli`"
      )
    return meta_graph_def_to_load

  def load_graph(self, graph, tags, import_scope=None, **saver_kwargs):
    """Load ops and nodes from SavedModel MetaGraph into graph.

    Args:
      graph: tf.Graph object.
      tags: a set of string tags identifying a MetaGraphDef.
      import_scope: Optional `string` -- if specified, prepend this string
        followed by '/' to all loaded tensor names. This scope is applied to
        tensor instances loaded into the passed graph, but it is *not* written
        through to the static `MetaGraphDef` protocol buffer that was parsed
        from disk.
      **saver_kwargs: keyword arguments to pass to tf.train.import_meta_graph.

    Returns:
      A tuple of
        * Saver defined by the MetaGraph, which can be used to restore the
          variable values.
        * List of `Operation`/`Tensor` objects returned from
          `tf.import_graph_def` (may be `None`).
    """
    meta_graph_def = self.get_meta_graph_def_from_tags(tags)
    with graph.as_default():
      return tf_saver._import_meta_graph_with_return_elements(  # pylint: disable=protected-access
          meta_graph_def, import_scope=import_scope, **saver_kwargs)

  def restore_variables(self, sess, saver, import_scope=None):
    """Restore SavedModel variable values into the session.

    Args:
      sess: tf.Session to restore variable values.
      saver: a tf.train.Saver object. Can be None if there are no variables in
        the graph. This may be the saver returned by the load_graph() function,
        or a default `tf.train.Saver()`.
      import_scope: Optional `string` -- if specified, prepend this string
        followed by '/' to all loaded tensor names. This scope is applied to
        tensor instances loaded into the passed session, but it is *not* written
        through to the static `MetaGraphDef` protocol buffer that was parsed
        from disk.

    Raises:
      ValueError: if no saver was passed to the saver argument, and there are
        variables in the graph.
    """
    with sess.graph.as_default():
      if (saver is None and
          not variables._all_saveable_objects(scope=import_scope)):  # pylint: disable=protected-access
        tf_logging.info("The specified SavedModel has no variables; no "
                        "checkpoints were restored.")
      elif isinstance(saver, tf_saver.Saver):
        saver.restore(sess, self._variables_path)
      else:
        raise ValueError(
            "No tf.train.Saver object was passed to the function "
            "SavedModelLoader.restore_variables. Since there are variables in "
            "the graph, a saver is required.")

  def run_init_ops(self, sess, tags, import_scope=None):
    """Run initialization ops defined in the `MetaGraphDef`.

    Args:
      sess: tf.Session to restore variable values.
      tags: a set of string tags identifying a MetaGraphDef.
      import_scope: Optional `string` -- if specified, prepend this string
        followed by '/' to all loaded tensor names. This scope is applied to
        tensor instances loaded into the passed session, but it is *not* written
        through to the static `MetaGraphDef` protocol buffer that was parsed
        from disk.
    """
    meta_graph_def = self.get_meta_graph_def_from_tags(tags)
    with sess.graph.as_default():
      # Get asset tensors, if any.
      asset_tensors_dictionary = _get_asset_tensors(
          self._export_dir, meta_graph_def, import_scope=import_scope)

      main_op_tensor = (
          _get_main_op_tensor(meta_graph_def, constants.MAIN_OP_KEY) or
          _get_main_op_tensor(meta_graph_def, constants.LEGACY_INIT_OP_KEY))
      if main_op_tensor is not None:
        sess.run(fetches=[main_op_tensor], feed_dict=asset_tensors_dictionary)

  def load(self, sess, tags, import_scope=None, **saver_kwargs):
    """Load the MetaGraphDef graph and restore variable values into the session.

    Args:
      sess: tf.Session to restore variable values.
      tags: a set of string tags identifying a MetaGraphDef.
      import_scope: Optional `string` -- if specified, prepend this string
        followed by '/' to all loaded tensor names. This scope is applied to
        tensor instances loaded into the passed session, but it is *not* written
        through to the static `MetaGraphDef` protocol buffer that is returned.
      **saver_kwargs: keyword arguments to pass to tf.train.import_meta_graph.

    Returns:
      `MetaGraphDef` proto of the graph that was loaded.
    """
    with sess.graph.as_default():
      saver, _ = self.load_graph(sess.graph, tags, import_scope,
                                 **saver_kwargs)
      self.restore_variables(sess, saver, import_scope)
      self.run_init_ops(sess, tags, import_scope)
    return self.get_meta_graph_def_from_tags(tags)
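

# Illustrative sketch: the step-by-step `SavedModelLoader` API lets a graph be
# imported once and its variables restored into more than one session. The
# helper name, export directory and tag-set are placeholders.
def _example_restore_into_two_sessions(export_dir, tags):
  """Imports the tagged graph once, then restores variables into two sessions."""
  # Local import, for the example only.
  from tensorflow.python.client import session as session_lib
  loader = SavedModelLoader(export_dir)
  graph = ops.Graph()
  # Step 1: import the ops/tensors of the tagged MetaGraphDef into `graph`.
  saver, _ = loader.load_graph(graph, tags)
  sessions = []
  for _ in range(2):
    sess = session_lib.Session(graph=graph)
    # Step 2: restore checkpointed variable values (`saver` may be None when
    # the SavedModel contains no variables).
    loader.restore_variables(sess, saver)
    # Step 3: run the main op (or legacy init op), feeding asset file paths.
    loader.run_init_ops(sess, tags)
    sessions.append(sess)
  return sessions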