diff options
author | 2018-05-23 20:53:15 -0700 | |
---|---|---|
committer | 2018-05-23 20:56:01 -0700 | |
commit | 81ef70a0bc22163d34f1e0425122d6a93bf02eac (patch) | |
tree | c54ca4015958c17fc45573f1a553ff8d7907f451 /tensorflow/python/saved_model | |
parent | 8f863f3d71542c47390f2d40348b72296ed5c4be (diff) |
Resolve name collisions with assets in SavedModels by deduplicating names that
point to distinct files.
PiperOrigin-RevId: 197835288
Diffstat (limited to 'tensorflow/python/saved_model')
-rw-r--r-- | tensorflow/python/saved_model/builder_impl.py | 81 | ||||
-rw-r--r-- | tensorflow/python/saved_model/saved_model_test.py | 147 |
2 files changed, 211 insertions, 17 deletions
diff --git a/tensorflow/python/saved_model/builder_impl.py b/tensorflow/python/saved_model/builder_impl.py index 071033b066..4b3982677f 100644 --- a/tensorflow/python/saved_model/builder_impl.py +++ b/tensorflow/python/saved_model/builder_impl.py @@ -104,10 +104,10 @@ class SavedModelBuilder(object): Args: assets_collection_to_add: The collection where the asset paths are setup. """ - asset_source_filepath_list = _maybe_save_assets(assets_collection_to_add) + asset_filename_map = _maybe_save_assets(assets_collection_to_add) # Return if there are no assets to write. - if len(asset_source_filepath_list) is 0: + if not asset_filename_map: tf_logging.info("No assets to write.") return @@ -119,12 +119,10 @@ class SavedModelBuilder(object): file_io.recursive_create_dir(assets_destination_dir) # Copy each asset from source path to destination path. - for asset_source_filepath in asset_source_filepath_list: - asset_source_filename = os.path.basename(asset_source_filepath) - + for asset_basename, asset_source_filepath in asset_filename_map.items(): asset_destination_filepath = os.path.join( compat.as_bytes(assets_destination_dir), - compat.as_bytes(asset_source_filename)) + compat.as_bytes(asset_basename)) # Only copy the asset file to the destination if it does not already # exist. This is to ensure that an asset with the same name defined as @@ -475,16 +473,17 @@ def _maybe_save_assets(assets_collection_to_add=None): assets_collection_to_add: The collection where the asset paths are setup. Returns: - The list of filepaths to the assets in the assets collection. + A dict of asset basenames for saving to the original full path to the asset. Raises: ValueError: Indicating an invalid filepath tensor. """ - asset_source_filepath_list = [] + # Map of target file names to original filenames + asset_filename_map = {} if assets_collection_to_add is None: tf_logging.info("No assets to save.") - return asset_source_filepath_list + return asset_filename_map # Iterate over the supplied asset collection, build the `AssetFile` proto # and add them to the collection with key `constants.ASSETS_KEY`, in the @@ -494,15 +493,71 @@ def _maybe_save_assets(assets_collection_to_add=None): if not asset_source_filepath: raise ValueError("Invalid asset filepath tensor %s" % asset_tensor) - asset_source_filename = os.path.basename(asset_source_filepath) + asset_filename = _get_asset_filename_to_add( + asset_source_filepath, asset_filename_map) # Build `AssetFile` proto and add it to the asset collection in the graph. - _add_asset_to_collection(asset_source_filename, asset_tensor) + # Note that this should be done even when the file is a duplicate of an + # already-added file, as the tensor reference should still exist. + _add_asset_to_collection(asset_filename, asset_tensor) - asset_source_filepath_list.append(asset_source_filepath) + # In the cases where we are adding a duplicate, this will result in the + # last of the filepaths being the one used for copying the file to the + # SavedModel. Since the files in question are the same, it doesn't matter + # either way. + asset_filename_map[asset_filename] = asset_source_filepath tf_logging.info("Assets added to graph.") - return asset_source_filepath_list + return asset_filename_map + + +def _get_asset_filename_to_add(asset_filepath, asset_filename_map): + """Get a unique basename to add to the SavedModel if this file is unseen. + + Assets come from users as full paths, and we save them out to the + SavedModel as basenames. In some cases, the basenames collide. Here, + we dedupe asset basenames by first checking if the file is the same, + and, if different, generate and return an index-suffixed basename + that can be used to add the asset to the SavedModel. + + Args: + asset_filepath: the full path to the asset that is being saved + asset_filename_map: a dict of filenames used for saving the asset in + the SavedModel to full paths from which the filenames were derived. + + Returns: + Uniquified filename string if the file is not a duplicate, or the original + filename if the file has already been seen and saved. + """ + asset_filename = os.path.basename(asset_filepath) + + if asset_filename not in asset_filename_map: + # This is an unseen asset. Safe to add. + return asset_filename + + other_asset_filepath = asset_filename_map[asset_filename] + if other_asset_filepath == asset_filepath: + # This is the same file, stored twice in the collection list. No need + # to make unique. + return asset_filename + + # Else, asset_filename is in the map, and the filepath is different. Dedupe. + if not file_io.filecmp(asset_filepath, other_asset_filepath): + # Files are different; dedupe filenames. + return _get_unique_asset_filename(asset_filename, asset_filename_map) + + # Files are the same; don't make unique. + return asset_filename + + +def _get_unique_asset_filename(asset_filename, asset_filename_map): + i = 1 + unique_filename = asset_filename + while unique_filename in asset_filename_map: + unique_filename = compat.as_bytes("_").join( + [compat.as_bytes(asset_filename), compat.as_bytes(str(i))]) + i += 1 + return unique_filename def _asset_path_from_tensor(path_tensor): diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py index 1b83d60df9..7302c77ad5 100644 --- a/tensorflow/python/saved_model/saved_model_test.py +++ b/tensorflow/python/saved_model/saved_model_test.py @@ -64,9 +64,12 @@ class SavedModelTest(test.TestCase): self.assertEqual(variable_value, v.eval()) def _build_asset_collection(self, asset_file_name, asset_file_contents, - asset_file_tensor_name): + asset_file_tensor_name, asset_subdir=""): + parent_dir = os.path.join( + compat.as_bytes(test.get_temp_dir()), compat.as_bytes(asset_subdir)) + file_io.recursive_create_dir(parent_dir) asset_filepath = os.path.join( - compat.as_bytes(test.get_temp_dir()), compat.as_bytes(asset_file_name)) + compat.as_bytes(parent_dir), compat.as_bytes(asset_file_name)) file_io.write_string_to_file(asset_filepath, asset_file_contents) asset_file_tensor = constant_op.constant( asset_filepath, name=asset_file_tensor_name) @@ -77,10 +80,11 @@ class SavedModelTest(test.TestCase): def _validate_asset_collection(self, export_dir, graph_collection_def, expected_asset_file_name, expected_asset_file_contents, - expected_asset_tensor_name): + expected_asset_tensor_name, + asset_id=0): assets_any = graph_collection_def[constants.ASSETS_KEY].any_list.value asset = meta_graph_pb2.AssetFileDef() - assets_any[0].Unpack(asset) + assets_any[asset_id].Unpack(asset) assets_path = os.path.join( compat.as_bytes(export_dir), compat.as_bytes(constants.ASSETS_DIRECTORY), @@ -634,6 +638,141 @@ class SavedModelTest(test.TestCase): compat.as_bytes("ignored.txt")) self.assertFalse(file_io.file_exists(ignored_asset_path)) + def testAssetsNameCollisionDiffFile(self): + export_dir = self._get_export_dir("test_assets_name_collision_diff_file") + builder = saved_model_builder.SavedModelBuilder(export_dir) + + with self.test_session(graph=ops.Graph()) as sess: + self._init_and_validate_variable(sess, "v", 42) + + asset_collection = self._build_asset_collection( + "hello42.txt", "foo bar bak", "asset_file_tensor", + asset_subdir="1") + + asset_collection = self._build_asset_collection( + "hello42.txt", "foo bar baz", "asset_file_tensor_1", + asset_subdir="2") + + builder.add_meta_graph_and_variables( + sess, ["foo"], assets_collection=asset_collection) + + # Save the SavedModel to disk. + builder.save() + + with self.test_session(graph=ops.Graph()) as sess: + foo_graph = loader.load(sess, ["foo"], export_dir) + self._validate_asset_collection(export_dir, foo_graph.collection_def, + "hello42.txt", "foo bar bak", + "asset_file_tensor:0") + self._validate_asset_collection(export_dir, foo_graph.collection_def, + "hello42.txt_1", "foo bar baz", + "asset_file_tensor_1:0", + asset_id=1) + + def testAssetsNameCollisionSameFilepath(self): + export_dir = self._get_export_dir("test_assets_name_collision_same_path") + builder = saved_model_builder.SavedModelBuilder(export_dir) + + with self.test_session(graph=ops.Graph()) as sess: + self._init_and_validate_variable(sess, "v", 42) + + asset_collection = self._build_asset_collection( + "hello42.txt", "foo bar baz", "asset_file_tensor") + + asset_collection = self._build_asset_collection( + "hello42.txt", "foo bar baz", "asset_file_tensor_1") + + builder.add_meta_graph_and_variables( + sess, ["foo"], assets_collection=asset_collection) + + # Save the SavedModel to disk. + builder.save() + + with self.test_session(graph=ops.Graph()) as sess: + foo_graph = loader.load(sess, ["foo"], export_dir) + self._validate_asset_collection(export_dir, foo_graph.collection_def, + "hello42.txt", "foo bar baz", + "asset_file_tensor:0") + # The second tensor should be recorded, but the same. + self._validate_asset_collection(export_dir, foo_graph.collection_def, + "hello42.txt", "foo bar baz", + "asset_file_tensor_1:0", + asset_id=1) + ignored_asset_path = os.path.join( + compat.as_bytes(export_dir), + compat.as_bytes(constants.ASSETS_DIRECTORY), + compat.as_bytes("hello42.txt_1")) + self.assertFalse(file_io.file_exists(ignored_asset_path)) + + def testAssetsNameCollisionSameFile(self): + export_dir = self._get_export_dir("test_assets_name_collision_same_file") + builder = saved_model_builder.SavedModelBuilder(export_dir) + + with self.test_session(graph=ops.Graph()) as sess: + self._init_and_validate_variable(sess, "v", 42) + + asset_collection = self._build_asset_collection( + "hello42.txt", "foo bar baz", "asset_file_tensor", + asset_subdir="1") + + asset_collection = self._build_asset_collection( + "hello42.txt", "foo bar baz", "asset_file_tensor_1", + asset_subdir="2") + + builder.add_meta_graph_and_variables( + sess, ["foo"], assets_collection=asset_collection) + + # Save the SavedModel to disk. + builder.save() + + with self.test_session(graph=ops.Graph()) as sess: + foo_graph = loader.load(sess, ["foo"], export_dir) + self._validate_asset_collection(export_dir, foo_graph.collection_def, + "hello42.txt", "foo bar baz", + "asset_file_tensor:0") + # The second tensor should be recorded, but the same. + self._validate_asset_collection(export_dir, foo_graph.collection_def, + "hello42.txt", "foo bar baz", + "asset_file_tensor_1:0", + asset_id=1) + ignored_asset_path = os.path.join( + compat.as_bytes(export_dir), + compat.as_bytes(constants.ASSETS_DIRECTORY), + compat.as_bytes("hello42.txt_1")) + self.assertFalse(file_io.file_exists(ignored_asset_path)) + + def testAssetsNameCollisionManyFiles(self): + export_dir = self._get_export_dir("test_assets_name_collision_many_files") + builder = saved_model_builder.SavedModelBuilder(export_dir) + + with self.test_session(graph=ops.Graph()) as sess: + self._init_and_validate_variable(sess, "v", 42) + + for i in range(5): + idx = str(i) + asset_collection = self._build_asset_collection( + "hello42.txt", "foo bar baz " + idx, "asset_file_tensor_" + idx, + asset_subdir=idx) + + builder.add_meta_graph_and_variables( + sess, ["foo"], assets_collection=asset_collection) + + # Save the SavedModel to disk. + builder.save() + + with self.test_session(graph=ops.Graph()) as sess: + foo_graph = loader.load(sess, ["foo"], export_dir) + for i in range(1, 5): + idx = str(i) + self._validate_asset_collection( + export_dir, foo_graph.collection_def, "hello42.txt_" + idx, + "foo bar baz " + idx, "asset_file_tensor_{}:0".format(idx), + asset_id=i) + + self._validate_asset_collection(export_dir, foo_graph.collection_def, + "hello42.txt", "foo bar baz 0", + "asset_file_tensor_0:0") + def testCustomMainOp(self): export_dir = self._get_export_dir("test_main_op") builder = saved_model_builder.SavedModelBuilder(export_dir) |