about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/python/saved_model
diff options
context:
space:
mode:
author: Karmel Allison <karmel@google.com>, 2018-05-23 20:53:15 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org>, 2018-05-23 20:56:01 -0700
commit 81ef70a0bc22163d34f1e0425122d6a93bf02eac (patch)
tree c54ca4015958c17fc45573f1a553ff8d7907f451 /tensorflow/python/saved_model
parent 8f863f3d71542c47390f2d40348b72296ed5c4be (diff)
Resolve name collisions with assets in SavedModels by deduplicating names that
point to distinct files.

PiperOrigin-RevId: 197835288
Diffstat (limited to 'tensorflow/python/saved_model')
-rw-r--r-- tensorflow/python/saved_model/builder_impl.py 81
-rw-r--r-- tensorflow/python/saved_model/saved_model_test.py 147
2 files changed, 211 insertions, 17 deletions
diff --git a/tensorflow/python/saved_model/builder_impl.py b/tensorflow/python/saved_model/builder_impl.py
index 071033b066..4b3982677f 100644
--- a/tensorflow/python/saved_model/builder_impl.py
+++ b/tensorflow/python/saved_model/builder_impl.py
@@ -104,10 +104,10 @@ class SavedModelBuilder(object):
Args:
assets_collection_to_add: The collection where the asset paths are setup.
"""
- asset_source_filepath_list = _maybe_save_assets(assets_collection_to_add)
+ asset_filename_map = _maybe_save_assets(assets_collection_to_add)
# Return if there are no assets to write.
- if len(asset_source_filepath_list) is 0:
+ if not asset_filename_map:
tf_logging.info("No assets to write.")
return
@@ -119,12 +119,10 @@ class SavedModelBuilder(object):
file_io.recursive_create_dir(assets_destination_dir)
# Copy each asset from source path to destination path.
- for asset_source_filepath in asset_source_filepath_list:
- asset_source_filename = os.path.basename(asset_source_filepath)
-
+ for asset_basename, asset_source_filepath in asset_filename_map.items():
asset_destination_filepath = os.path.join(
compat.as_bytes(assets_destination_dir),
- compat.as_bytes(asset_source_filename))
+ compat.as_bytes(asset_basename))
# Only copy the asset file to the destination if it does not already
# exist. This is to ensure that an asset with the same name defined as
@@ -475,16 +473,17 @@ def _maybe_save_assets(assets_collection_to_add=None):
assets_collection_to_add: The collection where the asset paths are setup.
Returns:
- The list of filepaths to the assets in the assets collection.
+ A dict of asset basenames for saving to the original full path to the asset.
Raises:
ValueError: Indicating an invalid filepath tensor.
"""
- asset_source_filepath_list = []
+ # Map of target file names to original filenames
+ asset_filename_map = {}
if assets_collection_to_add is None:
tf_logging.info("No assets to save.")
- return asset_source_filepath_list
+ return asset_filename_map
# Iterate over the supplied asset collection, build the `AssetFile` proto
# and add them to the collection with key `constants.ASSETS_KEY`, in the
@@ -494,15 +493,71 @@ def _maybe_save_assets(assets_collection_to_add=None):
if not asset_source_filepath:
raise ValueError("Invalid asset filepath tensor %s" % asset_tensor)
- asset_source_filename = os.path.basename(asset_source_filepath)
+ asset_filename = _get_asset_filename_to_add(
+ asset_source_filepath, asset_filename_map)
# Build `AssetFile` proto and add it to the asset collection in the graph.
- _add_asset_to_collection(asset_source_filename, asset_tensor)
+ # Note that this should be done even when the file is a duplicate of an
+ # already-added file, as the tensor reference should still exist.
+ _add_asset_to_collection(asset_filename, asset_tensor)
- asset_source_filepath_list.append(asset_source_filepath)
+ # In the cases where we are adding a duplicate, this will result in the
+ # last of the filepaths being the one used for copying the file to the
+ # SavedModel. Since the files in question are the same, it doesn't matter
+ # either way.
+ asset_filename_map[asset_filename] = asset_source_filepath
tf_logging.info("Assets added to graph.")
- return asset_source_filepath_list
+ return asset_filename_map
+
+
+def _get_asset_filename_to_add(asset_filepath, asset_filename_map):
+ """Get a unique basename to add to the SavedModel if this file is unseen.
+
+ Assets come from users as full paths, and we save them out to the
+ SavedModel as basenames. In some cases, the basenames collide. Here,
+ we dedupe asset basenames by first checking if the file is the same,
+ and, if different, generate and return an index-suffixed basename
+ that can be used to add the asset to the SavedModel.
+
+ Args:
+ asset_filepath: the full path to the asset that is being saved
+ asset_filename_map: a dict of filenames used for saving the asset in
+ the SavedModel to full paths from which the filenames were derived.
+
+ Returns:
+ Uniquified filename string if the file is not a duplicate, or the original
+ filename if the file has already been seen and saved.
+ """
+ asset_filename = os.path.basename(asset_filepath)
+
+ if asset_filename not in asset_filename_map:
+ # This is an unseen asset. Safe to add.
+ return asset_filename
+
+ other_asset_filepath = asset_filename_map[asset_filename]
+ if other_asset_filepath == asset_filepath:
+ # This is the same file, stored twice in the collection list. No need
+ # to make unique.
+ return asset_filename
+
+ # Else, asset_filename is in the map, and the filepath is different. Dedupe.
+ if not file_io.filecmp(asset_filepath, other_asset_filepath):
+ # Files are different; dedupe filenames.
+ return _get_unique_asset_filename(asset_filename, asset_filename_map)
+
+ # Files are the same; don't make unique.
+ return asset_filename
+
+
+def _get_unique_asset_filename(asset_filename, asset_filename_map):
+ i = 1
+ unique_filename = asset_filename
+ while unique_filename in asset_filename_map:
+ unique_filename = compat.as_bytes("_").join(
+ [compat.as_bytes(asset_filename), compat.as_bytes(str(i))])
+ i += 1
+ return unique_filename
def _asset_path_from_tensor(path_tensor):
diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py
index 1b83d60df9..7302c77ad5 100644
--- a/tensorflow/python/saved_model/saved_model_test.py
+++ b/tensorflow/python/saved_model/saved_model_test.py
@@ -64,9 +64,12 @@ class SavedModelTest(test.TestCase):
self.assertEqual(variable_value, v.eval())
def _build_asset_collection(self, asset_file_name, asset_file_contents,
- asset_file_tensor_name):
+ asset_file_tensor_name, asset_subdir=""):
+ parent_dir = os.path.join(
+ compat.as_bytes(test.get_temp_dir()), compat.as_bytes(asset_subdir))
+ file_io.recursive_create_dir(parent_dir)
asset_filepath = os.path.join(
- compat.as_bytes(test.get_temp_dir()), compat.as_bytes(asset_file_name))
+ compat.as_bytes(parent_dir), compat.as_bytes(asset_file_name))
file_io.write_string_to_file(asset_filepath, asset_file_contents)
asset_file_tensor = constant_op.constant(
asset_filepath, name=asset_file_tensor_name)
@@ -77,10 +80,11 @@ class SavedModelTest(test.TestCase):
def _validate_asset_collection(self, export_dir, graph_collection_def,
expected_asset_file_name,
expected_asset_file_contents,
- expected_asset_tensor_name):
+ expected_asset_tensor_name,
+ asset_id=0):
assets_any = graph_collection_def[constants.ASSETS_KEY].any_list.value
asset = meta_graph_pb2.AssetFileDef()
- assets_any[0].Unpack(asset)
+ assets_any[asset_id].Unpack(asset)
assets_path = os.path.join(
compat.as_bytes(export_dir),
compat.as_bytes(constants.ASSETS_DIRECTORY),
@@ -634,6 +638,141 @@ class SavedModelTest(test.TestCase):
compat.as_bytes("ignored.txt"))
self.assertFalse(file_io.file_exists(ignored_asset_path))
+ def testAssetsNameCollisionDiffFile(self):
+ export_dir = self._get_export_dir("test_assets_name_collision_diff_file")
+ builder = saved_model_builder.SavedModelBuilder(export_dir)
+
+ with self.test_session(graph=ops.Graph()) as sess:
+ self._init_and_validate_variable(sess, "v", 42)
+
+ asset_collection = self._build_asset_collection(
+ "hello42.txt", "foo bar bak", "asset_file_tensor",
+ asset_subdir="1")
+
+ asset_collection = self._build_asset_collection(
+ "hello42.txt", "foo bar baz", "asset_file_tensor_1",
+ asset_subdir="2")
+
+ builder.add_meta_graph_and_variables(
+ sess, ["foo"], assets_collection=asset_collection)
+
+ # Save the SavedModel to disk.
+ builder.save()
+
+ with self.test_session(graph=ops.Graph()) as sess:
+ foo_graph = loader.load(sess, ["foo"], export_dir)
+ self._validate_asset_collection(export_dir, foo_graph.collection_def,
+ "hello42.txt", "foo bar bak",
+ "asset_file_tensor:0")
+ self._validate_asset_collection(export_dir, foo_graph.collection_def,
+ "hello42.txt_1", "foo bar baz",
+ "asset_file_tensor_1:0",
+ asset_id=1)
+
+ def testAssetsNameCollisionSameFilepath(self):
+ export_dir = self._get_export_dir("test_assets_name_collision_same_path")
+ builder = saved_model_builder.SavedModelBuilder(export_dir)
+
+ with self.test_session(graph=ops.Graph()) as sess:
+ self._init_and_validate_variable(sess, "v", 42)
+
+ asset_collection = self._build_asset_collection(
+ "hello42.txt", "foo bar baz", "asset_file_tensor")
+
+ asset_collection = self._build_asset_collection(
+ "hello42.txt", "foo bar baz", "asset_file_tensor_1")
+
+ builder.add_meta_graph_and_variables(
+ sess, ["foo"], assets_collection=asset_collection)
+
+ # Save the SavedModel to disk.
+ builder.save()
+
+ with self.test_session(graph=ops.Graph()) as sess:
+ foo_graph = loader.load(sess, ["foo"], export_dir)
+ self._validate_asset_collection(export_dir, foo_graph.collection_def,
+ "hello42.txt", "foo bar baz",
+ "asset_file_tensor:0")
+ # The second tensor should be recorded, but the same.
+ self._validate_asset_collection(export_dir, foo_graph.collection_def,
+ "hello42.txt", "foo bar baz",
+ "asset_file_tensor_1:0",
+ asset_id=1)
+ ignored_asset_path = os.path.join(
+ compat.as_bytes(export_dir),
+ compat.as_bytes(constants.ASSETS_DIRECTORY),
+ compat.as_bytes("hello42.txt_1"))
+ self.assertFalse(file_io.file_exists(ignored_asset_path))
+
+ def testAssetsNameCollisionSameFile(self):
+ export_dir = self._get_export_dir("test_assets_name_collision_same_file")
+ builder = saved_model_builder.SavedModelBuilder(export_dir)
+
+ with self.test_session(graph=ops.Graph()) as sess:
+ self._init_and_validate_variable(sess, "v", 42)
+
+ asset_collection = self._build_asset_collection(
+ "hello42.txt", "foo bar baz", "asset_file_tensor",
+ asset_subdir="1")
+
+ asset_collection = self._build_asset_collection(
+ "hello42.txt", "foo bar baz", "asset_file_tensor_1",
+ asset_subdir="2")
+
+ builder.add_meta_graph_and_variables(
+ sess, ["foo"], assets_collection=asset_collection)
+
+ # Save the SavedModel to disk.
+ builder.save()
+
+ with self.test_session(graph=ops.Graph()) as sess:
+ foo_graph = loader.load(sess, ["foo"], export_dir)
+ self._validate_asset_collection(export_dir, foo_graph.collection_def,
+ "hello42.txt", "foo bar baz",
+ "asset_file_tensor:0")
+ # The second tensor should be recorded, but the same.
+ self._validate_asset_collection(export_dir, foo_graph.collection_def,
+ "hello42.txt", "foo bar baz",
+ "asset_file_tensor_1:0",
+ asset_id=1)
+ ignored_asset_path = os.path.join(
+ compat.as_bytes(export_dir),
+ compat.as_bytes(constants.ASSETS_DIRECTORY),
+ compat.as_bytes("hello42.txt_1"))
+ self.assertFalse(file_io.file_exists(ignored_asset_path))
+
+ def testAssetsNameCollisionManyFiles(self):
+ export_dir = self._get_export_dir("test_assets_name_collision_many_files")
+ builder = saved_model_builder.SavedModelBuilder(export_dir)
+
+ with self.test_session(graph=ops.Graph()) as sess:
+ self._init_and_validate_variable(sess, "v", 42)
+
+ for i in range(5):
+ idx = str(i)
+ asset_collection = self._build_asset_collection(
+ "hello42.txt", "foo bar baz " + idx, "asset_file_tensor_" + idx,
+ asset_subdir=idx)
+
+ builder.add_meta_graph_and_variables(
+ sess, ["foo"], assets_collection=asset_collection)
+
+ # Save the SavedModel to disk.
+ builder.save()
+
+ with self.test_session(graph=ops.Graph()) as sess:
+ foo_graph = loader.load(sess, ["foo"], export_dir)
+ for i in range(1, 5):
+ idx = str(i)
+ self._validate_asset_collection(
+ export_dir, foo_graph.collection_def, "hello42.txt_" + idx,
+ "foo bar baz " + idx, "asset_file_tensor_{}:0".format(idx),
+ asset_id=i)
+
+ self._validate_asset_collection(export_dir, foo_graph.collection_def,
+ "hello42.txt", "foo bar baz 0",
+ "asset_file_tensor_0:0")
+
def testCustomMainOp(self):
export_dir = self._get_export_dir("test_main_op")
builder = saved_model_builder.SavedModelBuilder(export_dir)