aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Katherine Wu <kathywu@google.com>2018-08-09 12:39:40 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-09 12:46:57 -0700
commit689fecc0712944da54230eb0aa7bc733e918afd1 (patch)
tree37dbbcea7b8053309899b95fa433acbd58895559
parente3963846f14b40d82f0adc2c4c8a79d91863d2ed (diff)
Refactor out path concatenation in the SavedModelBuilder.
PiperOrigin-RevId: 208094190
-rw-r--r--tensorflow/python/estimator/estimator.py22
-rw-r--r--tensorflow/python/saved_model/BUILD4
-rw-r--r--tensorflow/python/saved_model/builder_impl.py21
-rw-r--r--tensorflow/python/saved_model/loader_impl.py6
-rw-r--r--tensorflow/python/saved_model/utils_impl.py47
5 files changed, 69 insertions, 31 deletions
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index 16feda8c83..d6e407b958 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -50,7 +50,7 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import builder as saved_model_builder
-from tensorflow.python.saved_model import constants
+from tensorflow.python.saved_model import utils_impl as saved_model_utils
from tensorflow.python.summary import summary
from tensorflow.python.summary.writer import writer_cache
from tensorflow.python.training import checkpoint_management
@@ -2019,14 +2019,11 @@ class WarmStartSettings(
def _get_saved_model_ckpt(saved_model_dir):
"""Return path to variables checkpoint in a SavedModel directory."""
if not gfile.Exists(
- os.path.join(compat.as_bytes(saved_model_dir),
- compat.as_bytes('variables/variables.index'))):
+ os.path.join(saved_model_utils.get_variables_dir(saved_model_dir),
+ compat.as_text('variables.index'))):
raise ValueError('Directory provided has an invalid SavedModel format: %s'
% saved_model_dir)
- return os.path.join(
- compat.as_bytes(saved_model_dir),
- compat.as_bytes('{}/{}'.format(constants.VARIABLES_DIRECTORY,
- constants.VARIABLES_FILENAME)))
+ return saved_model_utils.get_variables_path(saved_model_dir)
def _get_default_warm_start_settings(warm_start_from):
@@ -2048,12 +2045,15 @@ def _get_default_warm_start_settings(warm_start_from):
if isinstance(warm_start_from, (six.string_types, six.binary_type)):
# Infer that this is a SavedModel if export_path +
# 'variables/variables.index' exists, and if so, construct the
- # WarmStartSettings pointing to export_path + 'variables/variables'.
- if gfile.Exists(os.path.join(compat.as_bytes(warm_start_from),
- compat.as_bytes('variables/variables.index'))):
+ # WarmStartSettings pointing to the variables path
+ # (export_path + 'variables/variables').
+ if gfile.Exists(os.path.join(
+ saved_model_utils.get_variables_dir(warm_start_from),
+ compat.as_text('variables.index'))):
logging.info('Warm-starting from a SavedModel')
return WarmStartSettings(
- ckpt_to_initialize_from=_get_saved_model_ckpt(warm_start_from))
+ ckpt_to_initialize_from=saved_model_utils.get_variables_path(
+ warm_start_from))
return WarmStartSettings(ckpt_to_initialize_from=warm_start_from)
elif isinstance(warm_start_from, WarmStartSettings):
return warm_start_from
diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD
index 076f2d8760..7a37eda5ea 100644
--- a/tensorflow/python/saved_model/BUILD
+++ b/tensorflow/python/saved_model/BUILD
@@ -62,6 +62,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":constants",
+ ":utils",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:lib",
@@ -81,6 +82,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":constants",
+ ":utils",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:lib",
@@ -187,8 +189,10 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
+ ":constants",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:lib",
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:util",
],
diff --git a/tensorflow/python/saved_model/builder_impl.py b/tensorflow/python/saved_model/builder_impl.py
index 8c985a7c2f..8e7f123a85 100644
--- a/tensorflow/python/saved_model/builder_impl.py
+++ b/tensorflow/python/saved_model/builder_impl.py
@@ -32,6 +32,7 @@ from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging
from tensorflow.python.saved_model import constants
+from tensorflow.python.saved_model import utils_impl as saved_model_utils
from tensorflow.python.training import saver as tf_saver
from tensorflow.python.util import compat
from tensorflow.python.util.deprecation import deprecated_args
@@ -112,12 +113,8 @@ class SavedModelBuilder(object):
tf_logging.info("No assets to write.")
return
- assets_destination_dir = os.path.join(
- compat.as_bytes(self._export_dir),
- compat.as_bytes(constants.ASSETS_DIRECTORY))
-
- if not file_io.file_exists(assets_destination_dir):
- file_io.recursive_create_dir(assets_destination_dir)
+ assets_destination_dir = saved_model_utils.get_or_create_assets_dir(
+ self._export_dir)
# Copy each asset from source path to destination path.
for asset_basename, asset_source_filepath in asset_filename_map.items():
@@ -409,16 +406,8 @@ class SavedModelBuilder(object):
# Add assets and ops
self._add_collections(assets_collection, main_op, None)
- # Create the variables sub-directory, if it does not exist.
- variables_dir = os.path.join(
- compat.as_text(self._export_dir),
- compat.as_text(constants.VARIABLES_DIRECTORY))
- if not file_io.file_exists(variables_dir):
- file_io.recursive_create_dir(variables_dir)
-
- variables_path = os.path.join(
- compat.as_text(variables_dir),
- compat.as_text(constants.VARIABLES_FILENAME))
+ saved_model_utils.get_or_create_variables_dir(self._export_dir)
+ variables_path = saved_model_utils.get_variables_path(self._export_dir)
saver = self._maybe_create_saver(saver)
diff --git a/tensorflow/python/saved_model/loader_impl.py b/tensorflow/python/saved_model/loader_impl.py
index 16077f52fa..e8536108e8 100644
--- a/tensorflow/python/saved_model/loader_impl.py
+++ b/tensorflow/python/saved_model/loader_impl.py
@@ -31,6 +31,7 @@ from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging
from tensorflow.python.saved_model import constants
+from tensorflow.python.saved_model import utils_impl as saved_model_utils
from tensorflow.python.training import saver as tf_saver
from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export
@@ -203,10 +204,7 @@ class SavedModelLoader(object):
variables to be loaded are located.
"""
self._export_dir = export_dir
- self._variables_path = os.path.join(
- compat.as_bytes(export_dir),
- compat.as_bytes(constants.VARIABLES_DIRECTORY),
- compat.as_bytes(constants.VARIABLES_FILENAME))
+ self._variables_path = saved_model_utils.get_variables_path(export_dir)
self._saved_model = _parse_saved_model(export_dir)
@property
diff --git a/tensorflow/python/saved_model/utils_impl.py b/tensorflow/python/saved_model/utils_impl.py
index cddce29a08..20ff34fd8e 100644
--- a/tensorflow/python/saved_model/utils_impl.py
+++ b/tensorflow/python/saved_model/utils_impl.py
@@ -18,10 +18,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import os
+
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.lib.io import file_io
+from tensorflow.python.saved_model import constants
+from tensorflow.python.util import compat
from tensorflow.python.util.tf_export import tf_export
@@ -84,3 +89,45 @@ def get_tensor_from_tensor_info(tensor_info, graph=None, import_scope=None):
_get_tensor(tensor_info.coo_sparse.dense_shape_tensor_name))
else:
raise ValueError("Invalid TensorInfo.encoding: %s" % encoding)
+
+
+# Path helpers.
+
+
+def get_or_create_variables_dir(export_dir):
+ """Return variables sub-directory, or create one if it doesn't exist."""
+ variables_dir = get_variables_dir(export_dir)
+ if not file_io.file_exists(variables_dir):
+ file_io.recursive_create_dir(variables_dir)
+ return variables_dir
+
+
+def get_variables_dir(export_dir):
+ """Return variables sub-directory in the SavedModel."""
+ return os.path.join(
+ compat.as_text(export_dir),
+ compat.as_text(constants.VARIABLES_DIRECTORY))
+
+
+def get_variables_path(export_dir):
+ """Return the variables path, used as the prefix for checkpoint files."""
+ return os.path.join(
+ compat.as_text(get_variables_dir(export_dir)),
+ compat.as_text(constants.VARIABLES_FILENAME))
+
+
+def get_or_create_assets_dir(export_dir):
+ """Return assets sub-directory, or create one if it doesn't exist."""
+ assets_destination_dir = get_assets_dir(export_dir)
+
+ if not file_io.file_exists(assets_destination_dir):
+ file_io.recursive_create_dir(assets_destination_dir)
+
+ return assets_destination_dir
+
+
+def get_assets_dir(export_dir):
+ """Return path to asset directory in the SavedModel."""
+ return os.path.join(
+ compat.as_text(export_dir),
+ compat.as_text(constants.ASSETS_DIRECTORY))