aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/session_bundle/session_bundle.py
blob: 66f2e32f58ea5c17a1225e0c77a6d7db6d22edd4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Importer for an exported TensorFlow model.

This module provides a function to create a SessionBundle containing both the
Session and MetaGraph.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

from tensorflow.contrib.session_bundle import constants
from tensorflow.contrib.session_bundle import manifest_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import ops
from tensorflow.python.lib.io import file_io
from tensorflow.python.training import saver as saver_lib
from tensorflow.python.util.deprecation import deprecated


@deprecated("2017-06-30",
            "No longer supported. Switch to SavedModel immediately.")
def maybe_session_bundle_dir(export_dir):
  """Reports whether `export_dir` looks like a session bundle export.

  A directory is treated as a session bundle when a file named
  `constants.META_GRAPH_DEF_FILENAME` exists directly inside it. No other
  bundle files are inspected.

  Args:
    export_dir: string path to model checkpoint, for example 'model/00000123'

  Returns:
    The result of `file_io.file_exists` for the meta graph file path, i.e.
    true if the session bundle meta graph file is present.
  """
  return file_io.file_exists(
      os.path.join(export_dir, constants.META_GRAPH_DEF_FILENAME))


@deprecated("2017-06-30",
            "No longer supported. Switch to SavedModel immediately.")
def load_session_bundle_from_path(export_dir,
                                  target="",
                                  config=None,
                                  meta_graph_def=None):
  """Load session bundle from the given path.

  The function reads input from export_dir, resets the default graph, imports
  the bundle's graph into it, creates a session on that graph and restores
  the saved parameters into the session. If the bundle declares a serving init
  op, that op is run with the asset file paths fed in.

  Args:
    export_dir: the directory that contains files exported by exporter.
    target: The execution engine to connect to. See target in tf.Session()
    config: A ConfigProto proto with configuration options. See config in
      tf.Session()
    meta_graph_def: optional object of type MetaGraphDef. If it is not None,
      it is used instead of parsing MetaGraphDef from export_dir. Note that if
      it carries a serving GraphDef under constants.GRAPH_KEY, its graph_def
      field is overwritten in place with that serving GraphDef.

  Returns:
    session: a tensorflow session created from the variable files.
    meta_graph: a meta graph proto saved in the exporter directory (the same
      object as the meta_graph_def argument when one was supplied).

  Raises:
    RuntimeError: if the required files are missing or contain unrecognizable
      fields, i.e. the exported model is invalid. If any error occurs after
      the session has been created, the session is closed before the error
      propagates.
  """
  if meta_graph_def is None:
    meta_graph_filename = os.path.join(export_dir,
                                       constants.META_GRAPH_DEF_FILENAME)
    if not file_io.file_exists(meta_graph_filename):
      raise RuntimeError("Expected meta graph file missing %s" %
                         meta_graph_filename)
    # Reads meta graph file.
    meta_graph_def = meta_graph_pb2.MetaGraphDef()
    meta_graph_def.ParseFromString(
        file_io.read_file_to_string(meta_graph_filename, binary_mode=True))

  # Find matching checkpoint files. The presence of the V2 index file decides
  # which naming scheme is used.
  variables_index_filename = os.path.join(export_dir,
                                          constants.VARIABLES_INDEX_FILENAME_V2)
  checkpoint_v2 = file_io.file_exists(variables_index_filename)
  checkpoint_sharded = False
  if checkpoint_v2:
    # The checkpoint is in v2 format.
    variables_filename_list = file_io.get_matching_files(
        os.path.join(export_dir, constants.VARIABLES_FILENAME_PATTERN_V2))
    checkpoint_sharded = True
  else:
    variables_filename = os.path.join(export_dir, constants.VARIABLES_FILENAME)
    if file_io.file_exists(variables_filename):
      variables_filename_list = [variables_filename]
    else:
      # Fall back to sharded v1 files matched by a pattern.
      variables_filename_list = file_io.get_matching_files(
          os.path.join(export_dir, constants.VARIABLES_FILENAME_PATTERN))
      checkpoint_sharded = True

  # Prepare the files to restore a session.
  if not variables_filename_list:
    restore_files = ""
  elif checkpoint_v2 or not checkpoint_sharded:
    # For checkpoint v2 or v1 with non-sharded files, use "export" to restore
    # the session.
    restore_files = constants.VARIABLES_FILENAME
  else:
    restore_files = constants.VARIABLES_FILENAME_PATTERN

  assets_dir = os.path.join(export_dir, constants.ASSETS_DIRECTORY)

  collection_def = meta_graph_def.collection_def
  if constants.GRAPH_KEY in collection_def:
    # Use serving graph_def in MetaGraphDef collection_def if exists
    graph_def_any = collection_def[constants.GRAPH_KEY].any_list.value
    if len(graph_def_any) != 1:
      raise RuntimeError("Expected exactly one serving GraphDef in : %s" %
                         meta_graph_def)
    graph_def = graph_pb2.GraphDef()
    graph_def_any[0].Unpack(graph_def)
    # Replace the graph def in meta graph proto.
    meta_graph_def.graph_def.CopyFrom(graph_def)

  ops.reset_default_graph()
  # graph=None binds the session to the freshly reset default graph, which is
  # where import_meta_graph below imports the bundle's graph.
  sess = session.Session(target, graph=None, config=config)
  try:
    # Import the graph.
    saver = saver_lib.import_meta_graph(meta_graph_def)
    # Restore the session.
    if restore_files:
      # import_meta_graph returns None when the meta graph has nothing to
      # restore; checkpoint files without a Saver mean an invalid export.
      if saver is None:
        raise RuntimeError(
            "Found variable files to restore but the meta graph has no "
            "Saver: %s" % export_dir)
      saver.restore(sess, os.path.join(export_dir, restore_files))

    init_op_tensor = None
    if constants.INIT_OP_KEY in collection_def:
      init_ops = collection_def[constants.INIT_OP_KEY].node_list.value
      if len(init_ops) != 1:
        raise RuntimeError("Expected exactly one serving init op in : %s" %
                           meta_graph_def)
      init_op_tensor = ops.get_collection(constants.INIT_OP_KEY)[0]

    # Create asset input tensor list: maps each bound tensor name to the
    # asset file's path under the bundle's assets directory.
    asset_tensor_dict = {}
    if constants.ASSETS_KEY in collection_def:
      assets_any = collection_def[constants.ASSETS_KEY].any_list.value
      for asset in assets_any:
        asset_pb = manifest_pb2.AssetFile()
        asset.Unpack(asset_pb)
        asset_tensor_dict[asset_pb.tensor_binding.tensor_name] = os.path.join(
            assets_dir, asset_pb.filename)

    # Compare against None explicitly: the collection entry may be a
    # tf.Tensor, whose truth value cannot be taken in graph mode.
    if init_op_tensor is not None:
      # Run the init op.
      sess.run(fetches=[init_op_tensor], feed_dict=asset_tensor_dict)
  except BaseException:
    # Do not leak the session's resources when loading fails part-way.
    sess.close()
    raise

  return sess, meta_graph_def