# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""SavedModel utils."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.saved_model.python.saved_model import reader


def get_meta_graph_def(saved_model_dir, tag_set):
  """Gets MetaGraphDef from SavedModel.

  Returns the MetaGraphDef for the given tag-set and SavedModel directory.
  A MetaGraphDef matches only when its set of tags is exactly equal to the
  requested set; a subset or superset is not a match. The first matching
  MetaGraphDef, in the order stored in the SavedModel, is returned.

  Args:
    saved_model_dir: Directory containing the SavedModel to inspect or execute.
    tag_set: Group of tag(s) of the MetaGraphDef to load, in string format,
      separated by ','. For a tag-set that contains multiple tags, all tags
      must be passed in. Empty entries (an empty string, or those produced by
      leading, trailing or doubled commas) are ignored, so '' selects a
      MetaGraphDef that has no tags.

  Returns:
    A MetaGraphDef corresponding to the tag-set.

  Raises:
    RuntimeError: An error when the given tag-set does not exist in the
      SavedModel.
  """
  saved_model = reader.read_saved_model(saved_model_dir)

  # ''.split(',') yields [''], which would put an empty-string tag in the set
  # and make it impossible to match a MetaGraphDef without tags, so empty
  # pieces are dropped before comparing.
  set_of_tags = {tag for tag in tag_set.split(',') if tag}

  for meta_graph_def in saved_model.meta_graphs:
    # Exact set equality: tag order and duplicates are irrelevant.
    if set(meta_graph_def.meta_info_def.tags) == set_of_tags:
      return meta_graph_def

  raise RuntimeError('MetaGraphDef associated with tag-set ' + tag_set +
                     ' could not be found in SavedModel')