From 689fecc0712944da54230eb0aa7bc733e918afd1 Mon Sep 17 00:00:00 2001 From: Katherine Wu Date: Thu, 9 Aug 2018 12:39:40 -0700 Subject: Refactor out path concatenation in the SavedModelBuilder. PiperOrigin-RevId: 208094190 --- tensorflow/python/estimator/estimator.py | 22 ++++++------- tensorflow/python/saved_model/BUILD | 4 +++ tensorflow/python/saved_model/builder_impl.py | 21 +++--------- tensorflow/python/saved_model/loader_impl.py | 6 ++-- tensorflow/python/saved_model/utils_impl.py | 47 +++++++++++++++++++++++++++ 5 files changed, 69 insertions(+), 31 deletions(-) diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py index 16feda8c83..d6e407b958 100644 --- a/tensorflow/python/estimator/estimator.py +++ b/tensorflow/python/estimator/estimator.py @@ -50,7 +50,7 @@ from tensorflow.python.ops import variables from tensorflow.python.platform import gfile from tensorflow.python.platform import tf_logging as logging from tensorflow.python.saved_model import builder as saved_model_builder -from tensorflow.python.saved_model import constants +from tensorflow.python.saved_model import utils_impl as saved_model_utils from tensorflow.python.summary import summary from tensorflow.python.summary.writer import writer_cache from tensorflow.python.training import checkpoint_management @@ -2019,14 +2019,11 @@ class WarmStartSettings( def _get_saved_model_ckpt(saved_model_dir): """Return path to variables checkpoint in a SavedModel directory.""" if not gfile.Exists( - os.path.join(compat.as_bytes(saved_model_dir), - compat.as_bytes('variables/variables.index'))): + os.path.join(saved_model_utils.get_variables_dir(saved_model_dir), + compat.as_text('variables.index'))): raise ValueError('Directory provided has an invalid SavedModel format: %s' % saved_model_dir) - return os.path.join( - compat.as_bytes(saved_model_dir), - compat.as_bytes('{}/{}'.format(constants.VARIABLES_DIRECTORY, - constants.VARIABLES_FILENAME))) + return saved_model_utils.get_variables_path(saved_model_dir) def _get_default_warm_start_settings(warm_start_from): @@ -2048,12 +2045,15 @@ def _get_default_warm_start_settings(warm_start_from): if isinstance(warm_start_from, (six.string_types, six.binary_type)): # Infer that this is a SavedModel if export_path + # 'variables/variables.index' exists, and if so, construct the - # WarmStartSettings pointing to export_path + 'variables/variables'. - if gfile.Exists(os.path.join(compat.as_bytes(warm_start_from), - compat.as_bytes('variables/variables.index'))): + # WarmStartSettings pointing to the variables path + # (export_path + 'variables/variables'). + if gfile.Exists(os.path.join( + saved_model_utils.get_variables_dir(warm_start_from), + compat.as_text('variables.index'))): logging.info('Warm-starting from a SavedModel') return WarmStartSettings( - ckpt_to_initialize_from=_get_saved_model_ckpt(warm_start_from)) + ckpt_to_initialize_from=saved_model_utils.get_variables_path( + warm_start_from)) return WarmStartSettings(ckpt_to_initialize_from=warm_start_from) elif isinstance(warm_start_from, WarmStartSettings): return warm_start_from diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD index 076f2d8760..7a37eda5ea 100644 --- a/tensorflow/python/saved_model/BUILD +++ b/tensorflow/python/saved_model/BUILD @@ -62,6 +62,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":constants", + ":utils", "//tensorflow/core:protos_all_py", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:lib", @@ -81,6 +82,7 @@ py_library( srcs_version = "PY2AND3", deps = [ ":constants", + ":utils", "//tensorflow/core:protos_all_py", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:lib", @@ -187,8 +189,10 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + ":constants", "//tensorflow/core:protos_all_py", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:lib", "//tensorflow/python:sparse_tensor", "//tensorflow/python:util", ], diff --git a/tensorflow/python/saved_model/builder_impl.py b/tensorflow/python/saved_model/builder_impl.py index 8c985a7c2f..8e7f123a85 100644 --- a/tensorflow/python/saved_model/builder_impl.py +++ b/tensorflow/python/saved_model/builder_impl.py @@ -32,6 +32,7 @@ from tensorflow.python.lib.io import file_io from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging from tensorflow.python.saved_model import constants +from tensorflow.python.saved_model import utils_impl as saved_model_utils from tensorflow.python.training import saver as tf_saver from tensorflow.python.util import compat from tensorflow.python.util.deprecation import deprecated_args @@ -112,12 +113,8 @@ class SavedModelBuilder(object): tf_logging.info("No assets to write.") return - assets_destination_dir = os.path.join( - compat.as_bytes(self._export_dir), - compat.as_bytes(constants.ASSETS_DIRECTORY)) - - if not file_io.file_exists(assets_destination_dir): - file_io.recursive_create_dir(assets_destination_dir) + assets_destination_dir = saved_model_utils.get_or_create_assets_dir( + self._export_dir) # Copy each asset from source path to destination path. for asset_basename, asset_source_filepath in asset_filename_map.items(): @@ -409,16 +406,8 @@ class SavedModelBuilder(object): # Add assets and ops self._add_collections(assets_collection, main_op, None) - # Create the variables sub-directory, if it does not exist. - variables_dir = os.path.join( - compat.as_text(self._export_dir), - compat.as_text(constants.VARIABLES_DIRECTORY)) - if not file_io.file_exists(variables_dir): - file_io.recursive_create_dir(variables_dir) - - variables_path = os.path.join( - compat.as_text(variables_dir), - compat.as_text(constants.VARIABLES_FILENAME)) + saved_model_utils.get_or_create_variables_dir(self._export_dir) + variables_path = saved_model_utils.get_variables_path(self._export_dir) saver = self._maybe_create_saver(saver) diff --git a/tensorflow/python/saved_model/loader_impl.py b/tensorflow/python/saved_model/loader_impl.py index 16077f52fa..e8536108e8 100644 --- a/tensorflow/python/saved_model/loader_impl.py +++ b/tensorflow/python/saved_model/loader_impl.py @@ -31,6 +31,7 @@ from tensorflow.python.lib.io import file_io from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging from tensorflow.python.saved_model import constants +from tensorflow.python.saved_model import utils_impl as saved_model_utils from tensorflow.python.training import saver as tf_saver from tensorflow.python.util import compat from tensorflow.python.util.tf_export import tf_export @@ -203,10 +204,7 @@ class SavedModelLoader(object): variables to be loaded are located. """ self._export_dir = export_dir - self._variables_path = os.path.join( - compat.as_bytes(export_dir), - compat.as_bytes(constants.VARIABLES_DIRECTORY), - compat.as_bytes(constants.VARIABLES_FILENAME)) + self._variables_path = saved_model_utils.get_variables_path(export_dir) self._saved_model = _parse_saved_model(export_dir) @property diff --git a/tensorflow/python/saved_model/utils_impl.py b/tensorflow/python/saved_model/utils_impl.py index cddce29a08..20ff34fd8e 100644 --- a/tensorflow/python/saved_model/utils_impl.py +++ b/tensorflow/python/saved_model/utils_impl.py @@ -18,10 +18,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import os + from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.framework import sparse_tensor +from tensorflow.python.lib.io import file_io +from tensorflow.python.saved_model import constants +from tensorflow.python.util import compat from tensorflow.python.util.tf_export import tf_export @@ -84,3 +89,45 @@ def get_tensor_from_tensor_info(tensor_info, graph=None, import_scope=None): _get_tensor(tensor_info.coo_sparse.dense_shape_tensor_name)) else: raise ValueError("Invalid TensorInfo.encoding: %s" % encoding) + + +# Path helpers. + + +def get_or_create_variables_dir(export_dir): + """Return variables sub-directory, or create one if it doesn't exist.""" + variables_dir = get_variables_dir(export_dir) + if not file_io.file_exists(variables_dir): + file_io.recursive_create_dir(variables_dir) + return variables_dir + + +def get_variables_dir(export_dir): + """Return variables sub-directory in the SavedModel.""" + return os.path.join( + compat.as_text(export_dir), + compat.as_text(constants.VARIABLES_DIRECTORY)) + + +def get_variables_path(export_dir): + """Return the variables path, used as the prefix for checkpoint files.""" + return os.path.join( + compat.as_text(get_variables_dir(export_dir)), + compat.as_text(constants.VARIABLES_FILENAME)) + + +def get_or_create_assets_dir(export_dir): + """Return assets sub-directory, or create one if it doesn't exist.""" + assets_destination_dir = get_assets_dir(export_dir) + + if not file_io.file_exists(assets_destination_dir): + file_io.recursive_create_dir(assets_destination_dir) + + return assets_destination_dir + + +def get_assets_dir(export_dir): + """Return path to asset directory in the SavedModel.""" + return os.path.join( + compat.as_text(export_dir), + compat.as_text(constants.ASSETS_DIRECTORY)) -- cgit v1.2.3