aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Sukriti Ramesh <sukritiramesh@google.com>2016-10-14 13:06:14 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-10-14 14:20:08 -0700
commit3373d232d904d44c2192deeda5e20b1f8921bfc2 (patch)
treef85c72e8884e913a61e965e2531907dcfeedcc35
parent3ba17a9a65c0091f48c3c36562dbb59f4733a074 (diff)
Add asset support in SavedModel loader py.
Change: 136195816
-rw-r--r--tensorflow/python/saved_model/loader.py35
1 files changed, 35 insertions, 0 deletions
diff --git a/tensorflow/python/saved_model/loader.py b/tensorflow/python/saved_model/loader.py
index 4bf55dc163..594416287d 100644
--- a/tensorflow/python/saved_model/loader.py
+++ b/tensorflow/python/saved_model/loader.py
@@ -66,6 +66,7 @@ from __future__ import print_function
import os
from google.protobuf import text_format
+from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import saved_model_pb2
from tensorflow.python.lib.io import file_io
from tensorflow.python.saved_model import constants
@@ -118,6 +119,37 @@ def _parse_saved_model(export_dir):
return saved_model
+def _get_asset_tensors(export_dir, meta_graph_def_to_load):
+ """Gets the asset tensors, if defined in the meta graph def to load.
+
+ Args:
+ export_dir: Directory where the SavedModel is located.
+ meta_graph_def_to_load: The meta graph def from the SavedModel to be loaded.
+
+ Returns:
+ A dictionary of asset tensors, keyed by the name of the asset tensor. The
+ value in the map corresponds to the absolute path of the asset file.
+ """
+ # Collection-def that may contain the assets key.
+ collection_def = meta_graph_def_to_load.collection_def
+
+ asset_tensor_dict = {}
+ if constants.ASSETS_KEY in collection_def:
+ # Location of the assets for SavedModel.
+ assets_directory = os.path.join(
+ compat.as_bytes(export_dir),
+ compat.as_bytes(constants.ASSETS_DIRECTORY))
+ assets_any_proto = collection_def[constants.ASSETS_KEY].any_list.value
+ # Process each asset and add it to the asset tensor dictionary.
+ for asset_any_proto in assets_any_proto:
+ asset_proto = meta_graph_pb2.AssetFileDef()
+ asset_any_proto.Unpack(asset_proto)
+ asset_tensor_dict[asset_proto.tensor_info.name] = os.path.join(
+ compat.as_bytes(assets_directory),
+ compat.as_bytes(asset_proto.filename))
+ return asset_tensor_dict
+
+
def load(sess, tags, export_dir):
"""Loads the model from a SavedModel as specified by tags.
@@ -161,5 +193,8 @@ def load(sess, tags, export_dir):
# Restore the variables using the built saver in the provided session.
saver.restore(sess, variables_path)
+ # Get asset tensors, if any.
+ _get_asset_tensors(export_dir, meta_graph_def_to_load)
+
# Return the meta graph def that was loaded into the session.
return meta_graph_def_to_load