diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-06-12 11:26:38 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-12 11:28:58 -0700 |
commit | dc7821ccf42ada3f85ca1c6e8228f0a42e61b93c (patch) | |
tree | 411f30d2cd5037352f593e551a2205a1629abee6 /tensorflow/python/saved_model | |
parent | ba9422a8adba18fc97cc1923002b7db8ca63dcfe (diff) |
Apply import_scope to asset and variable tensors during tf.saved_model.loader.load
This change explicitly declares import_scope as a kwarg for tf.saved_model.loader.load. Previously, tf.saved_model.loader.load implicitly accepted import_scope and passed it through to import_meta_graph through **saver_kwargs.
PiperOrigin-RevId: 200249417
Diffstat (limited to 'tensorflow/python/saved_model')
-rw-r--r-- | tensorflow/python/saved_model/loader_impl.py | 22 | ||||
-rw-r--r-- | tensorflow/python/saved_model/saved_model_test.py | 53 |
2 files changed, 69 insertions, 6 deletions
diff --git a/tensorflow/python/saved_model/loader_impl.py b/tensorflow/python/saved_model/loader_impl.py index bebf1d5e0d..d1bd8d47ae 100644 --- a/tensorflow/python/saved_model/loader_impl.py +++ b/tensorflow/python/saved_model/loader_impl.py @@ -79,12 +79,14 @@ def _parse_saved_model(export_dir): constants.SAVED_MODEL_FILENAME_PB)) -def _get_asset_tensors(export_dir, meta_graph_def_to_load): +def _get_asset_tensors(export_dir, meta_graph_def_to_load, import_scope=None): """Gets the asset tensors, if defined in the meta graph def to load. Args: export_dir: Directory where the SavedModel is located. meta_graph_def_to_load: The meta graph def from the SavedModel to be loaded. + import_scope: Optional `string` -- if specified, prepend this followed by + '/' to all returned asset tensor names. Returns: A dictionary of asset tensors, keyed by the name of the asset tensor. The @@ -104,7 +106,10 @@ def _get_asset_tensors(export_dir, meta_graph_def_to_load): for asset_any_proto in assets_any_proto: asset_proto = meta_graph_pb2.AssetFileDef() asset_any_proto.Unpack(asset_proto) - asset_tensor_dict[asset_proto.tensor_info.name] = os.path.join( + tensor_name = asset_proto.tensor_info.name + if import_scope: + tensor_name = "%s/%s" % (import_scope, tensor_name) + asset_tensor_dict[tensor_name] = os.path.join( compat.as_bytes(assets_directory), compat.as_bytes(asset_proto.filename)) return asset_tensor_dict @@ -179,7 +184,7 @@ def maybe_saved_model_directory(export_dir): @tf_export("saved_model.loader.load") -def load(sess, tags, export_dir, **saver_kwargs): +def load(sess, tags, export_dir, import_scope=None, **saver_kwargs): """Loads the model from a SavedModel as specified by tags. Args: @@ -189,6 +194,10 @@ def load(sess, tags, export_dir, **saver_kwargs): SavedModel `save()` API. export_dir: Directory in which the SavedModel protocol buffer and variables to be loaded are located. + import_scope: Optional `string` -- if specified, prepend this string + followed by '/' to all loaded tensor names. This scope is applied to + tensor instances loaded into the passed session, but it is *not* written + through to the static `MetaGraphDef` protocol buffer that is returned. **saver_kwargs: Optional keyword arguments passed through to Saver. Returns: @@ -216,7 +225,8 @@ def load(sess, tags, export_dir, **saver_kwargs): ) # Build a saver by importing the meta graph def to load. - saver = tf_saver.import_meta_graph(meta_graph_def_to_load, **saver_kwargs) + saver = tf_saver.import_meta_graph( + meta_graph_def_to_load, import_scope=import_scope, **saver_kwargs) if saver: # Build the checkpoint path where the variables are located. @@ -232,8 +242,8 @@ def load(sess, tags, export_dir, **saver_kwargs): "checkpoints were restored.") # Get asset tensors, if any. - asset_tensors_dictionary = _get_asset_tensors(export_dir, - meta_graph_def_to_load) + asset_tensors_dictionary = _get_asset_tensors( + export_dir, meta_graph_def_to_load, import_scope=import_scope) main_op_tensor = ( _get_main_op_tensor(meta_graph_def_to_load) or diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py index effb38283b..fb4732aca2 100644 --- a/tensorflow/python/saved_model/saved_model_test.py +++ b/tensorflow/python/saved_model/saved_model_test.py @@ -1197,6 +1197,59 @@ class SavedModelTest(test.TestCase): _validate_custom_saver("tag_1", "save_1/restore_all") _validate_custom_saver("tag_2", "save_2/restore_all") + def testImportScope(self): + export_dir = self._get_export_dir("test_scoped_assets") + builder = saved_model_builder.SavedModelBuilder(export_dir) + + # Build a SavedModel with a variable, an asset, and a constant tensor. + with self.test_session(graph=ops.Graph()) as sess: + self._init_and_validate_variable(sess, "v", 42) + asset_collection = self._build_asset_collection("foo.txt", "content_foo", + "asset_file_tensor") + constant_op.constant("constant value", name="constant_tensor_name") + builder.add_meta_graph_and_variables( + sess, ["tag_name"], assets_collection=asset_collection) + + # Save the asset file path for later comparison. + asset_file_path = asset_collection[0].eval() + + # Save the SavedModel to disk. + builder.save() + + with self.test_session(graph=ops.Graph()) as sess: + # Restore the SavedModel under an import_scope in a new graph/session. + graph_proto = loader.load( + sess, ["tag_name"], export_dir, import_scope="scope_name") + + # The loaded variable tensor should be scoped, but its contents should be + # unchanged. + self.assertEqual( + "scope_name/v:0", + ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].name) + self.assertEqual( + 42, + ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval()) + + # The loaded asset tensor should be scoped, but the asset file path and + # contents should be unchanged. + asset_collection = ops.get_collection(ops.GraphKeys.ASSET_FILEPATHS) + self.assertEqual(1, len(asset_collection)) + self.assertEqual(asset_file_path, asset_collection[0].eval()) + self.assertEqual("scope_name/asset_file_tensor:0", + asset_collection[0].name) + # The static asset data inside graph_proto.collection_def should not be + # scoped. + self._validate_asset_collection(export_dir, graph_proto.collection_def, + "foo.txt", "content_foo", + "asset_file_tensor:0") + + # The constant tensor should be scoped, but its contents should be + # unchanged. + self.assertEqual( + compat.as_bytes("constant value"), + ops.get_default_graph().get_tensor_by_name( + "scope_name/constant_tensor_name:0").eval()) + def testClearDevices(self): export_dir = self._get_export_dir("test_clear_devices") builder = saved_model_builder.SavedModelBuilder(export_dir) |