diff options
-rw-r--r-- | tensorflow/python/saved_model/loader.py | 35 |
1 files changed, 35 insertions, 0 deletions
diff --git a/tensorflow/python/saved_model/loader.py b/tensorflow/python/saved_model/loader.py index 4bf55dc163..594416287d 100644 --- a/tensorflow/python/saved_model/loader.py +++ b/tensorflow/python/saved_model/loader.py @@ -66,6 +66,7 @@ from __future__ import print_function import os from google.protobuf import text_format +from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.core.protobuf import saved_model_pb2 from tensorflow.python.lib.io import file_io from tensorflow.python.saved_model import constants @@ -118,6 +119,37 @@ def _parse_saved_model(export_dir): return saved_model +def _get_asset_tensors(export_dir, meta_graph_def_to_load): + """Gets the asset tensors, if defined in the meta graph def to load. + + Args: + export_dir: Directory where the SavedModel is located. + meta_graph_def_to_load: The meta graph def from the SavedModel to be loaded. + + Returns: + A dictionary of asset tensors, keyed by the name of the asset tensor. The + value in the map corresponds to the absolute path of the asset file. + """ + # Collection-def that may contain the assets key. + collection_def = meta_graph_def_to_load.collection_def + + asset_tensor_dict = {} + if constants.ASSETS_KEY in collection_def: + # Location of the assets for SavedModel. + assets_directory = os.path.join( + compat.as_bytes(export_dir), + compat.as_bytes(constants.ASSETS_DIRECTORY)) + assets_any_proto = collection_def[constants.ASSETS_KEY].any_list.value + # Process each asset and add it to the asset tensor dictionary. + for asset_any_proto in assets_any_proto: + asset_proto = meta_graph_pb2.AssetFileDef() + asset_any_proto.Unpack(asset_proto) + asset_tensor_dict[asset_proto.tensor_info.name] = os.path.join( + compat.as_bytes(assets_directory), + compat.as_bytes(asset_proto.filename)) + return asset_tensor_dict + + def load(sess, tags, export_dir): """Loads the model from a SavedModel as specified by tags. @@ -161,5 +193,8 @@ def load(sess, tags, export_dir): # Restore the variables using the built saver in the provided session. saver.restore(sess, variables_path) + # Get asset tensors, if any. + _get_asset_tensors(export_dir, meta_graph_def_to_load) + # Return the meta graph def that was loaded into the session. return meta_graph_def_to_load |