aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/saved_model
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-12 11:26:38 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-12 11:28:58 -0700
commitdc7821ccf42ada3f85ca1c6e8228f0a42e61b93c (patch)
tree411f30d2cd5037352f593e551a2205a1629abee6 /tensorflow/python/saved_model
parentba9422a8adba18fc97cc1923002b7db8ca63dcfe (diff)
Apply import_scope to asset and variable tensors during tf.saved_model.loader.load
This change explicitly declares import_scope as a kwarg for tf.saved_model.loader.load. Previously, tf.saved_model.loader.load implicitly accepted import_scope and passed it through to import_meta_graph through **saver_kwargs. PiperOrigin-RevId: 200249417
Diffstat (limited to 'tensorflow/python/saved_model')
-rw-r--r--tensorflow/python/saved_model/loader_impl.py22
-rw-r--r--tensorflow/python/saved_model/saved_model_test.py53
2 files changed, 69 insertions, 6 deletions
diff --git a/tensorflow/python/saved_model/loader_impl.py b/tensorflow/python/saved_model/loader_impl.py
index bebf1d5e0d..d1bd8d47ae 100644
--- a/tensorflow/python/saved_model/loader_impl.py
+++ b/tensorflow/python/saved_model/loader_impl.py
@@ -79,12 +79,14 @@ def _parse_saved_model(export_dir):
constants.SAVED_MODEL_FILENAME_PB))
-def _get_asset_tensors(export_dir, meta_graph_def_to_load):
+def _get_asset_tensors(export_dir, meta_graph_def_to_load, import_scope=None):
"""Gets the asset tensors, if defined in the meta graph def to load.
Args:
export_dir: Directory where the SavedModel is located.
meta_graph_def_to_load: The meta graph def from the SavedModel to be loaded.
+ import_scope: Optional `string` -- if specified, prepend this followed by
+ '/' to all returned asset tensor names.
Returns:
A dictionary of asset tensors, keyed by the name of the asset tensor. The
@@ -104,7 +106,10 @@ def _get_asset_tensors(export_dir, meta_graph_def_to_load):
for asset_any_proto in assets_any_proto:
asset_proto = meta_graph_pb2.AssetFileDef()
asset_any_proto.Unpack(asset_proto)
- asset_tensor_dict[asset_proto.tensor_info.name] = os.path.join(
+ tensor_name = asset_proto.tensor_info.name
+ if import_scope:
+ tensor_name = "%s/%s" % (import_scope, tensor_name)
+ asset_tensor_dict[tensor_name] = os.path.join(
compat.as_bytes(assets_directory),
compat.as_bytes(asset_proto.filename))
return asset_tensor_dict
@@ -179,7 +184,7 @@ def maybe_saved_model_directory(export_dir):
@tf_export("saved_model.loader.load")
-def load(sess, tags, export_dir, **saver_kwargs):
+def load(sess, tags, export_dir, import_scope=None, **saver_kwargs):
"""Loads the model from a SavedModel as specified by tags.
Args:
@@ -189,6 +194,10 @@ def load(sess, tags, export_dir, **saver_kwargs):
SavedModel `save()` API.
export_dir: Directory in which the SavedModel protocol buffer and variables
to be loaded are located.
+ import_scope: Optional `string` -- if specified, prepend this string
+ followed by '/' to all loaded tensor names. This scope is applied to
+ tensor instances loaded into the passed session, but it is *not* written
+ through to the static `MetaGraphDef` protocol buffer that is returned.
**saver_kwargs: Optional keyword arguments passed through to Saver.
Returns:
@@ -216,7 +225,8 @@ def load(sess, tags, export_dir, **saver_kwargs):
)
# Build a saver by importing the meta graph def to load.
- saver = tf_saver.import_meta_graph(meta_graph_def_to_load, **saver_kwargs)
+ saver = tf_saver.import_meta_graph(
+ meta_graph_def_to_load, import_scope=import_scope, **saver_kwargs)
if saver:
# Build the checkpoint path where the variables are located.
@@ -232,8 +242,8 @@ def load(sess, tags, export_dir, **saver_kwargs):
"checkpoints were restored.")
# Get asset tensors, if any.
- asset_tensors_dictionary = _get_asset_tensors(export_dir,
- meta_graph_def_to_load)
+ asset_tensors_dictionary = _get_asset_tensors(
+ export_dir, meta_graph_def_to_load, import_scope=import_scope)
main_op_tensor = (
_get_main_op_tensor(meta_graph_def_to_load) or
diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py
index effb38283b..fb4732aca2 100644
--- a/tensorflow/python/saved_model/saved_model_test.py
+++ b/tensorflow/python/saved_model/saved_model_test.py
@@ -1197,6 +1197,59 @@ class SavedModelTest(test.TestCase):
_validate_custom_saver("tag_1", "save_1/restore_all")
_validate_custom_saver("tag_2", "save_2/restore_all")
+ def testImportScope(self):
+ export_dir = self._get_export_dir("test_scoped_assets")
+ builder = saved_model_builder.SavedModelBuilder(export_dir)
+
+ # Build a SavedModel with a variable, an asset, and a constant tensor.
+ with self.test_session(graph=ops.Graph()) as sess:
+ self._init_and_validate_variable(sess, "v", 42)
+ asset_collection = self._build_asset_collection("foo.txt", "content_foo",
+ "asset_file_tensor")
+ constant_op.constant("constant value", name="constant_tensor_name")
+ builder.add_meta_graph_and_variables(
+ sess, ["tag_name"], assets_collection=asset_collection)
+
+ # Save the asset file path for later comparison.
+ asset_file_path = asset_collection[0].eval()
+
+ # Save the SavedModel to disk.
+ builder.save()
+
+ with self.test_session(graph=ops.Graph()) as sess:
+ # Restore the SavedModel under an import_scope in a new graph/session.
+ graph_proto = loader.load(
+ sess, ["tag_name"], export_dir, import_scope="scope_name")
+
+ # The loaded variable tensor should be scoped, but its contents should be
+ # unchanged.
+ self.assertEqual(
+ "scope_name/v:0",
+ ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].name)
+ self.assertEqual(
+ 42,
+ ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
+
+ # The loaded asset tensor should be scoped, but the asset file path and
+ # contents should be unchanged.
+ asset_collection = ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS)
+ self.assertEqual(1, len(asset_collection))
+ self.assertEqual(asset_file_path, asset_collection[0].eval())
+ self.assertEqual("scope_name/asset_file_tensor:0",
+ asset_collection[0].name)
+ # The static asset data inside graph_proto.collection_def should not be
+ # scoped.
+ self._validate_asset_collection(export_dir, graph_proto.collection_def,
+ "foo.txt", "content_foo",
+ "asset_file_tensor:0")
+
+ # The constant tensor should be scoped, but its contents should be
+ # unchanged.
+ self.assertEqual(
+ compat.as_bytes("constant value"),
+ ops.get_default_graph().get_tensor_by_name(
+ "scope_name/constant_tensor_name:0").eval())
+
def testClearDevices(self):
export_dir = self._get_export_dir("test_clear_devices")
builder = saved_model_builder.SavedModelBuilder(export_dir)