diff options
author | Katherine Wu <kathywu@google.com> | 2018-08-09 12:39:40 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-09 12:46:57 -0700 |
commit | 689fecc0712944da54230eb0aa7bc733e918afd1 (patch) | |
tree | 37dbbcea7b8053309899b95fa433acbd58895559 /tensorflow/python/saved_model | |
parent | e3963846f14b40d82f0adc2c4c8a79d91863d2ed (diff) |
Refactor out path concatenation in the SavedModelBuilder.
PiperOrigin-RevId: 208094190
Diffstat (limited to 'tensorflow/python/saved_model')
-rw-r--r-- | tensorflow/python/saved_model/BUILD | 4 | ||||
-rw-r--r-- | tensorflow/python/saved_model/builder_impl.py | 21 | ||||
-rw-r--r-- | tensorflow/python/saved_model/loader_impl.py | 6 | ||||
-rw-r--r-- | tensorflow/python/saved_model/utils_impl.py | 47 |
4 files changed, 58 insertions, 20 deletions
diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD index 076f2d8760..7a37eda5ea 100644 --- a/tensorflow/python/saved_model/BUILD +++ b/tensorflow/python/saved_model/BUILD @@ -62,6 +62,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":constants", + ":utils", "//tensorflow/core:protos_all_py", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:lib", @@ -81,6 +82,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":constants", + ":utils", "//tensorflow/core:protos_all_py", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:lib", @@ -187,8 +189,10 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + ":constants", "//tensorflow/core:protos_all_py", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:lib", "//tensorflow/python:sparse_tensor", "//tensorflow/python:util", ], diff --git a/tensorflow/python/saved_model/builder_impl.py b/tensorflow/python/saved_model/builder_impl.py index 8c985a7c2f..8e7f123a85 100644 --- a/tensorflow/python/saved_model/builder_impl.py +++ b/tensorflow/python/saved_model/builder_impl.py @@ -32,6 +32,7 @@ from tensorflow.python.lib.io import file_io from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging from tensorflow.python.saved_model import constants +from tensorflow.python.saved_model import utils_impl as saved_model_utils from tensorflow.python.training import saver as tf_saver from tensorflow.python.util import compat from tensorflow.python.util.deprecation import deprecated_args @@ -112,12 +113,8 @@ class SavedModelBuilder(object): tf_logging.info("No assets to write.") return - assets_destination_dir = os.path.join( - compat.as_bytes(self._export_dir), - compat.as_bytes(constants.ASSETS_DIRECTORY)) - - if not file_io.file_exists(assets_destination_dir): - file_io.recursive_create_dir(assets_destination_dir) + assets_destination_dir = saved_model_utils.get_or_create_assets_dir( + self._export_dir) # Copy each asset from source path to destination path. for asset_basename, asset_source_filepath in asset_filename_map.items(): @@ -409,16 +406,8 @@ class SavedModelBuilder(object): # Add assets and ops self._add_collections(assets_collection, main_op, None) - # Create the variables sub-directory, if it does not exist. - variables_dir = os.path.join( - compat.as_text(self._export_dir), - compat.as_text(constants.VARIABLES_DIRECTORY)) - if not file_io.file_exists(variables_dir): - file_io.recursive_create_dir(variables_dir) - - variables_path = os.path.join( - compat.as_text(variables_dir), - compat.as_text(constants.VARIABLES_FILENAME)) + saved_model_utils.get_or_create_variables_dir(self._export_dir) + variables_path = saved_model_utils.get_variables_path(self._export_dir) saver = self._maybe_create_saver(saver) diff --git a/tensorflow/python/saved_model/loader_impl.py b/tensorflow/python/saved_model/loader_impl.py index 16077f52fa..e8536108e8 100644 --- a/tensorflow/python/saved_model/loader_impl.py +++ b/tensorflow/python/saved_model/loader_impl.py @@ -31,6 +31,7 @@ from tensorflow.python.lib.io import file_io from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging from tensorflow.python.saved_model import constants +from tensorflow.python.saved_model import utils_impl as saved_model_utils from tensorflow.python.training import saver as tf_saver from tensorflow.python.util import compat from tensorflow.python.util.tf_export import tf_export @@ -203,10 +204,7 @@ class SavedModelLoader(object): variables to be loaded are located. """ self._export_dir = export_dir - self._variables_path = os.path.join( - compat.as_bytes(export_dir), - compat.as_bytes(constants.VARIABLES_DIRECTORY), - compat.as_bytes(constants.VARIABLES_FILENAME)) + self._variables_path = saved_model_utils.get_variables_path(export_dir) self._saved_model = _parse_saved_model(export_dir) @property diff --git a/tensorflow/python/saved_model/utils_impl.py b/tensorflow/python/saved_model/utils_impl.py index cddce29a08..20ff34fd8e 100644 --- a/tensorflow/python/saved_model/utils_impl.py +++ b/tensorflow/python/saved_model/utils_impl.py @@ -18,10 +18,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os + from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.lib.io import file_io +from tensorflow.python.saved_model import constants +from tensorflow.python.util import compat from tensorflow.python.util.tf_export import tf_export @@ -84,3 +89,45 @@ def get_tensor_from_tensor_info(tensor_info, graph=None, import_scope=None): _get_tensor(tensor_info.coo_sparse.dense_shape_tensor_name)) else: raise ValueError("Invalid TensorInfo.encoding: %s" % encoding) + + +# Path helpers. + + +def get_or_create_variables_dir(export_dir): + """Return variables sub-directory, or create one if it doesn't exist.""" + variables_dir = get_variables_dir(export_dir) + if not file_io.file_exists(variables_dir): + file_io.recursive_create_dir(variables_dir) + return variables_dir + + +def get_variables_dir(export_dir): + """Return variables sub-directory in the SavedModel.""" + return os.path.join( + compat.as_text(export_dir), + compat.as_text(constants.VARIABLES_DIRECTORY)) + + +def get_variables_path(export_dir): + """Return the variables path, used as the prefix for checkpoint files.""" + return os.path.join( + compat.as_text(get_variables_dir(export_dir)), + compat.as_text(constants.VARIABLES_FILENAME)) + + +def get_or_create_assets_dir(export_dir): + """Return assets sub-directory, or create one if it doesn't exist.""" + assets_destination_dir = get_assets_dir(export_dir) + + if not file_io.file_exists(assets_destination_dir): + file_io.recursive_create_dir(assets_destination_dir) + + return assets_destination_dir + + +def get_assets_dir(export_dir): + """Return path to asset directory in the SavedModel.""" + return os.path.join( + compat.as_text(export_dir), + compat.as_text(constants.ASSETS_DIRECTORY)) |