aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/saved_model/loader_impl.py
blob: ddfd6be6dae258d511d085c1bc0e4f1c99d01424 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Loader implementation for SavedModel with hermetic, language-neutral exports.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from google.protobuf import message
from google.protobuf import text_format

from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.python.framework import ops
from tensorflow.python.lib.io import file_io
from tensorflow.python.platform import tf_logging
from tensorflow.python.saved_model import constants
from tensorflow.python.training import saver as tf_saver
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export


def _parse_saved_model(export_dir):
  """Reads the binary or text-format file containing the `SavedModel` proto.

  The binary file (named by `constants.SAVED_MODEL_FILENAME_PB`) takes
  precedence; the text file (named by `constants.SAVED_MODEL_FILENAME_PBTXT`)
  is only consulted when the binary file does not exist.

  Args:
    export_dir: Directory containing the SavedModel file.

  Returns:
    A `SavedModel` protocol buffer.

  Raises:
    IOError: If the file does not exist, or cannot be successfully parsed.
  """
  # Build the path to the SavedModel in pbtxt format.
  path_to_pbtxt = os.path.join(
      compat.as_bytes(export_dir),
      compat.as_bytes(constants.SAVED_MODEL_FILENAME_PBTXT))
  # Build the path to the SavedModel in pb format.
  path_to_pb = os.path.join(
      compat.as_bytes(export_dir),
      compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))

  # Parse the SavedModel protocol buffer.
  saved_model = saved_model_pb2.SavedModel()
  if file_io.file_exists(path_to_pb):
    # Read through a context manager so the file handle is always closed,
    # even when the read itself raises.
    with file_io.FileIO(path_to_pb, "rb") as pb_file:
      file_content = pb_file.read()
    try:
      saved_model.ParseFromString(file_content)
    except message.DecodeError as e:
      raise IOError("Cannot parse file %s: %s." % (path_to_pb, str(e)))
    return saved_model
  elif file_io.file_exists(path_to_pbtxt):
    with file_io.FileIO(path_to_pbtxt, "rb") as pbtxt_file:
      file_content = pbtxt_file.read()
    try:
      # The file is read as bytes; the text-format parser is given unicode.
      text_format.Merge(file_content.decode("utf-8"), saved_model)
    except text_format.ParseError as e:
      raise IOError("Cannot parse file %s: %s." % (path_to_pbtxt, str(e)))
    return saved_model
  else:
    raise IOError("SavedModel file does not exist at: %s/{%s|%s}" %
                  (export_dir,
                   constants.SAVED_MODEL_FILENAME_PBTXT,
                   constants.SAVED_MODEL_FILENAME_PB))


def _get_asset_tensors(export_dir, meta_graph_def_to_load):
  """Maps asset tensor names to asset file paths for a meta graph def.

  Args:
    export_dir: Directory where the SavedModel is located.
    meta_graph_def_to_load: The meta graph def from the SavedModel to be loaded.

  Returns:
    A dictionary keyed by the name of each asset tensor. Each value is the path
    of the corresponding asset file, built by joining `export_dir`, the assets
    directory name and the asset filename. The dictionary is empty when the
    meta graph def has no assets collection.
  """
  collection_def = meta_graph_def_to_load.collection_def

  # Nothing to map when the meta graph def carries no assets collection.
  if constants.ASSETS_KEY not in collection_def:
    return {}

  # Directory under the export location that holds the asset files.
  assets_root = os.path.join(
      compat.as_bytes(export_dir),
      compat.as_bytes(constants.ASSETS_DIRECTORY))

  tensor_name_to_path = {}
  for packed_asset in collection_def[constants.ASSETS_KEY].any_list.value:
    # Each entry is an `Any` proto wrapping an `AssetFileDef`.
    asset_file_def = meta_graph_pb2.AssetFileDef()
    packed_asset.Unpack(asset_file_def)
    asset_path = os.path.join(
        compat.as_bytes(assets_root),
        compat.as_bytes(asset_file_def.filename))
    tensor_name_to_path[asset_file_def.tensor_info.name] = asset_path
  return tensor_name_to_path


def _get_main_op_tensor(meta_graph_def_to_load):
  """Returns the SavedModel main op, or `None` when none is recorded.

  Args:
    meta_graph_def_to_load: The meta graph def from the SavedModel to be loaded.

  Returns:
    The first entry of the main op collection obtained via
    `ops.get_collection`, or `None` if the meta graph def has no main op
    collection.

  Raises:
    RuntimeError: If the collection def corresponding to the main op key has
        other than exactly one tensor.
  """
  collection_def = meta_graph_def_to_load.collection_def
  if constants.MAIN_OP_KEY not in collection_def:
    return None

  # The collection def is only used to validate the count; the op itself is
  # looked up through `ops.get_collection`.
  recorded_main_ops = collection_def[constants.MAIN_OP_KEY].node_list.value
  if len(recorded_main_ops) != 1:
    raise RuntimeError("Expected exactly one SavedModel main op.")
  return ops.get_collection(constants.MAIN_OP_KEY)[0]


def _get_legacy_init_op_tensor(meta_graph_def_to_load):
  """Returns the legacy init op, or `None` when none is recorded.

  Args:
    meta_graph_def_to_load: The meta graph def from the SavedModel to be loaded.

  Returns:
    The first entry of the legacy init op collection obtained via
    `ops.get_collection`, or `None` if the meta graph def has no legacy init op
    collection.

  Raises:
    RuntimeError: If the collection def corresponding to the legacy init op key
        has other than exactly one tensor.
  """
  collection_def = meta_graph_def_to_load.collection_def
  if constants.LEGACY_INIT_OP_KEY not in collection_def:
    return None

  # The collection def is only used to validate the count; the op itself is
  # looked up through `ops.get_collection`.
  recorded_init_ops = (
      collection_def[constants.LEGACY_INIT_OP_KEY].node_list.value)
  if len(recorded_init_ops) != 1:
    raise RuntimeError("Expected exactly one legacy serving init op.")
  return ops.get_collection(constants.LEGACY_INIT_OP_KEY)[0]


@tf_export("saved_model.loader.maybe_saved_model_directory")
def maybe_saved_model_directory(export_dir):
  """Checks whether the provided export directory could contain a SavedModel.

  Only the existence of the SavedModel file is checked; nothing is loaded or
  parsed. A `False` result means neither the text nor the binary SavedModel
  file is present in the directory. A `True` result means one of them is
  present, with no guarantee that it can be loaded.

  Args:
    export_dir: Absolute string path to possible export location. For example,
                '/my/foo/model'.

  Returns:
    True if the export directory contains SavedModel files, False otherwise.
  """
  # Probe the text-format file first, then the binary one.
  candidate_paths = [
      os.path.join(export_dir, filename)
      for filename in (constants.SAVED_MODEL_FILENAME_PBTXT,
                       constants.SAVED_MODEL_FILENAME_PB)
  ]
  return any(file_io.file_exists(path) for path in candidate_paths)


@tf_export("saved_model.loader.load")
def load(sess, tags, export_dir, **saver_kwargs):
  """Loads the model from a SavedModel as specified by tags.

  The steps, all performed with `sess.graph` as the default graph, are:
  parse the SavedModel file, select the MetaGraphDef whose tag set equals
  `tags`, import it, restore variables if the import produced a saver, and
  finally run the main op (or, failing that, the legacy init op) with the
  asset tensors fed.

  Args:
    sess: The TensorFlow session to restore the variables.
    tags: Set of string tags to identify the required MetaGraphDef. These should
        correspond to the tags used when saving the variables using the
        SavedModel `save()` API.
    export_dir: Directory in which the SavedModel protocol buffer and variables
        to be loaded are located.
    **saver_kwargs: Optional keyword arguments passed through to Saver.

  Returns:
    The `MetaGraphDef` protocol buffer loaded in the provided session. This
    can be used to further extract signature-defs, collection-defs, etc.

  Raises:
    RuntimeError: MetaGraphDef associated with the tags cannot be found.
  """
  with sess.graph.as_default():
    # Parse the SavedModel and pick the first meta graph def whose tag set is
    # exactly the requested one.
    saved_model = _parse_saved_model(export_dir)
    for candidate in saved_model.meta_graphs:
      if set(candidate.meta_info_def.tags) == set(tags):
        meta_graph_def_to_load = candidate
        break
    else:
      # The loop finished without a `break`: no tag set matched.
      raise RuntimeError(
          "MetaGraphDef associated with tags " + str(tags).strip("[]") +
          " could not be found in SavedModel. To inspect available tag-sets in"
          " the SavedModel, please use the SavedModel CLI: `saved_model_cli`"
      )

    # Importing the meta graph def yields the saver used for restoration.
    saver = tf_saver.import_meta_graph(meta_graph_def_to_load, **saver_kwargs)

    if not saver:
      tf_logging.info("The specified SavedModel has no variables; no "
                      "checkpoints were restored.")
    else:
      # Checkpoint prefix under the export directory's variables folder.
      checkpoint_prefix = os.path.join(
          compat.as_bytes(export_dir),
          compat.as_bytes(constants.VARIABLES_DIRECTORY),
          compat.as_bytes(constants.VARIABLES_FILENAME))
      saver.restore(sess, checkpoint_prefix)

    # Asset tensors (if any) are fed when running the initialization op.
    asset_feed = _get_asset_tensors(export_dir, meta_graph_def_to_load)

    # Prefer the main op; fall back to the legacy init op only when there is
    # no main op.
    init_op = _get_main_op_tensor(meta_graph_def_to_load)
    if init_op is None:
      init_op = _get_legacy_init_op_tensor(meta_graph_def_to_load)
    if init_op is not None:
      sess.run(fetches=[init_op], feed_dict=asset_feed)

    return meta_graph_def_to_load