aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/saved_model
diff options
context:
space:
mode:
authorGravatar Karmel Allison <karmel@google.com>2018-07-25 11:33:12 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-25 11:37:24 -0700
commit96c76b296768852dac94aaf006beab2e637cbbb6 (patch)
tree94c22b8cf2f6098f975b327996b211bd352bb11c /tensorflow/python/saved_model
parent972920262107f5900abd79611e9432ddc6cd810b (diff)
The SavedModel legacy_init_op and main_op are functionally equivalent. Here, we remove duplicated code paths by mapping legacy_init_op into main_op in the SavedModelBuilder, and we deprecate the legacy_init_op arg. Note that the loader will still look for both, so old SavedModels will still load without trouble.
PiperOrigin-RevId: 206026743
Diffstat (limited to 'tensorflow/python/saved_model')
-rw-r--r--tensorflow/python/saved_model/builder_impl.py76
-rw-r--r--tensorflow/python/saved_model/loader_impl.py42
-rw-r--r--tensorflow/python/saved_model/saved_model_test.py28
3 files changed, 71 insertions, 75 deletions
diff --git a/tensorflow/python/saved_model/builder_impl.py b/tensorflow/python/saved_model/builder_impl.py
index e58be804c2..8c985a7c2f 100644
--- a/tensorflow/python/saved_model/builder_impl.py
+++ b/tensorflow/python/saved_model/builder_impl.py
@@ -34,6 +34,7 @@ from tensorflow.python.platform import tf_logging
from tensorflow.python.saved_model import constants
from tensorflow.python.training import saver as tf_saver
from tensorflow.python.util import compat
+from tensorflow.python.util.deprecation import deprecated_args
from tensorflow.python.util.tf_export import tf_export
@@ -133,39 +134,32 @@ class SavedModelBuilder(object):
tf_logging.info("Assets written to: %s",
compat.as_text(assets_destination_dir))
- def _maybe_add_legacy_init_op(self, legacy_init_op=None):
- """Add legacy init op to the SavedModel.
+ def _maybe_add_main_op(self, main_op):
+ """Adds main op to the SavedModel.
Args:
- legacy_init_op: Optional legacy init op to support backward compatibility.
+ main_op: Main op to run as part of graph initialization. If None, no
+ main op will be added to the graph.
Raises:
- TypeError if legacy init op is not of type `Operation`.
- AssertionError if the graph already contains one or more legacy init ops.
+ TypeError: if main op is provided but is not of type `Operation`.
+ ValueError: if the Graph already contains an init op.
"""
- if legacy_init_op is not None:
- if not isinstance(legacy_init_op, ops.Operation):
- raise TypeError("legacy_init_op needs to be an Operation: %r" %
- legacy_init_op)
- if ops.get_collection(constants.LEGACY_INIT_OP_KEY):
- raise AssertionError(
- "graph already contains one or more legacy init ops under the "
- "collection {}.".format(constants.LEGACY_INIT_OP_KEY))
- ops.add_to_collection(constants.LEGACY_INIT_OP_KEY, legacy_init_op)
-
- def _add_main_op(self, main_op):
- """Add main op to the SavedModel.
+ if main_op is None:
+ return
- Args:
- main_op: Main op to run as part of graph initialization.
+ if not isinstance(main_op, ops.Operation):
+ raise TypeError("main_op needs to be an Operation: %r" % main_op)
- Raises:
- TypeError if main op is not of type `Operation`.
- """
- if main_op is not None:
- if not isinstance(main_op, ops.Operation):
- raise TypeError("main_op needs to be an Operation: %r" % main_op)
- ops.add_to_collection(constants.MAIN_OP_KEY, main_op)
+ # Validate that no other init ops have been added to this graph already.
+ # We check main_op and legacy_init_op for thoroughness and explicitness.
+ for init_op_key in (constants.MAIN_OP_KEY, constants.LEGACY_INIT_OP_KEY):
+ if ops.get_collection(init_op_key):
+ raise ValueError(
+ "Graph already contains one or more main ops under the "
+ "collection {}.".format(init_op_key))
+
+ ops.add_to_collection(constants.MAIN_OP_KEY, main_op)
def _add_train_op(self, train_op):
"""Add train op to the SavedModel.
@@ -257,16 +251,12 @@ class SavedModelBuilder(object):
self._validate_tensor_info(outputs[outputs_key])
def _add_collections(
- self, assets_collection, legacy_init_op, main_op, train_op):
+ self, assets_collection, main_op, train_op):
"""Add asset and op collections to be saved."""
# Save asset files and write them to disk, if any.
self._save_and_write_assets(assets_collection)
- if main_op is None:
- # Add legacy init op to the SavedModel.
- self._maybe_add_legacy_init_op(legacy_init_op)
- else:
- self._add_main_op(main_op)
+ self._maybe_add_main_op(main_op)
self._add_train_op(train_op)
@@ -282,6 +272,9 @@ class SavedModelBuilder(object):
allow_empty=True)
return saver
+ @deprecated_args(None,
+ "Pass your op to the equivalent parameter main_op instead.",
+ "legacy_init_op")
def add_meta_graph(self,
tags,
signature_def_map=None,
@@ -306,7 +299,7 @@ class SavedModelBuilder(object):
that this collection should be a subset of the assets saved as part of
the first meta graph in the SavedModel.
legacy_init_op: Legacy support for op or group of ops to execute after the
- restore op upon a load.
+ restore op upon a load. Deprecated; please use main_op instead.
clear_devices: Set to true if the device info on the default graph should
be cleared.
main_op: Op or group of ops to execute when the graph is loaded. Note
@@ -333,8 +326,12 @@ class SavedModelBuilder(object):
# properly populated.
self._validate_signature_def_map(signature_def_map)
+ # legacy_init_op is deprecated, and going away in TF 2.0.
+ # Re-mapping to main_op, as treatment is identical regardless.
+ main_op = main_op or legacy_init_op
+
# Add assets and ops
- self._add_collections(assets_collection, legacy_init_op, main_op, None)
+ self._add_collections(assets_collection, main_op, None)
saver = self._maybe_create_saver(saver)
@@ -351,6 +348,9 @@ class SavedModelBuilder(object):
# Tag the meta graph def and add it to the SavedModel.
self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)
+ @deprecated_args(None,
+ "Pass your op to the equivalent parameter main_op instead.",
+ "legacy_init_op")
def add_meta_graph_and_variables(self,
sess,
tags,
@@ -378,7 +378,7 @@ class SavedModelBuilder(object):
def.
assets_collection: Assets collection to be saved with SavedModel.
legacy_init_op: Legacy support for op or group of ops to execute after the
- restore op upon a load.
+ restore op upon a load. Deprecated; please use main_op instead.
clear_devices: Set to true if the device info on the default graph should
be cleared.
main_op: Op or group of ops to execute when the graph is loaded. Note
@@ -402,8 +402,12 @@ class SavedModelBuilder(object):
# properly populated.
self._validate_signature_def_map(signature_def_map)
+ # legacy_init_op is deprecated, and going away in TF 2.0.
+ # Re-mapping to main_op, as treatment is identical regardless.
+ main_op = main_op or legacy_init_op
+
# Add assets and ops
- self._add_collections(assets_collection, legacy_init_op, main_op, None)
+ self._add_collections(assets_collection, main_op, None)
# Create the variables sub-directory, if it does not exist.
variables_dir = os.path.join(
diff --git a/tensorflow/python/saved_model/loader_impl.py b/tensorflow/python/saved_model/loader_impl.py
index e5f649fdab..fb70c91c29 100644
--- a/tensorflow/python/saved_model/loader_impl.py
+++ b/tensorflow/python/saved_model/loader_impl.py
@@ -116,11 +116,14 @@ def _get_asset_tensors(export_dir, meta_graph_def_to_load, import_scope=None):
return asset_tensor_dict
-def _get_main_op_tensor(meta_graph_def_to_load):
+def _get_main_op_tensor(
+ meta_graph_def_to_load, init_op_key=constants.MAIN_OP_KEY):
"""Gets the main op tensor, if one exists.
Args:
meta_graph_def_to_load: The meta graph def from the SavedModel to be loaded.
+ init_op_key: name of collection to check; should be one of MAIN_OP_KEY
+ or the deprecated LEGACY_INIT_OP_KEY
Returns:
The main op tensor, if it exists and `None` otherwise.
@@ -131,38 +134,15 @@ def _get_main_op_tensor(meta_graph_def_to_load):
"""
collection_def = meta_graph_def_to_load.collection_def
main_op_tensor = None
- if constants.MAIN_OP_KEY in collection_def:
- main_ops = collection_def[constants.MAIN_OP_KEY].node_list.value
+ if init_op_key in collection_def:
+ main_ops = collection_def[init_op_key].node_list.value
if len(main_ops) != 1:
- raise RuntimeError("Expected exactly one SavedModel main op.")
- main_op_tensor = ops.get_collection(constants.MAIN_OP_KEY)[0]
+ raise RuntimeError("Expected exactly one SavedModel main op. "
+ "Found: {}".format(main_ops))
+ main_op_tensor = ops.get_collection(init_op_key)[0]
return main_op_tensor
-def _get_legacy_init_op_tensor(meta_graph_def_to_load):
- """Gets the legacy init op tensor, if one exists.
-
- Args:
- meta_graph_def_to_load: The meta graph def from the SavedModel to be loaded.
-
- Returns:
- The legacy init op tensor, if it exists and `None` otherwise.
-
- Raises:
- RuntimeError: If the collection def corresponding to the legacy init op key
- has other than exactly one tensor.
- """
- collection_def = meta_graph_def_to_load.collection_def
- legacy_init_op_tensor = None
- if constants.LEGACY_INIT_OP_KEY in collection_def:
- legacy_init_ops = collection_def[
- constants.LEGACY_INIT_OP_KEY].node_list.value
- if len(legacy_init_ops) != 1:
- raise RuntimeError("Expected exactly one legacy serving init op.")
- legacy_init_op_tensor = ops.get_collection(constants.LEGACY_INIT_OP_KEY)[0]
- return legacy_init_op_tensor
-
-
@tf_export("saved_model.loader.maybe_saved_model_directory")
def maybe_saved_model_directory(export_dir):
"""Checks whether the provided export directory could contain a SavedModel.
@@ -340,8 +320,8 @@ class SavedModelLoader(object):
self._export_dir, meta_graph_def, import_scope=import_scope)
main_op_tensor = (
- _get_main_op_tensor(meta_graph_def) or
- (_get_legacy_init_op_tensor(meta_graph_def)))
+ _get_main_op_tensor(meta_graph_def, constants.MAIN_OP_KEY) or
+ _get_main_op_tensor(meta_graph_def, constants.LEGACY_INIT_OP_KEY))
if main_op_tensor is not None:
sess.run(fetches=[main_op_tensor], feed_dict=asset_tensors_dictionary)
diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py
index fb4732aca2..00b669fc97 100644
--- a/tensorflow/python/saved_model/saved_model_test.py
+++ b/tensorflow/python/saved_model/saved_model_test.py
@@ -846,9 +846,19 @@ class SavedModelTest(test.TestCase):
def testLegacyInitOpWithNonEmptyCollection(self):
export_dir = self._get_export_dir(
"test_legacy_init_op_with_non_empty_collection")
+ self._testInitOpsWithNonEmptyCollection(
+ export_dir, constants.LEGACY_INIT_OP_KEY)
+
+ def testMainOpWithNonEmptyCollection(self):
+ export_dir = self._get_export_dir(
+ "test_main_op_with_non_empty_collection")
+ self._testInitOpsWithNonEmptyCollection(export_dir, constants.MAIN_OP_KEY)
+
+ def _testInitOpsWithNonEmptyCollection(self, export_dir, key):
builder = saved_model_builder.SavedModelBuilder(export_dir)
- with self.test_session(graph=ops.Graph()) as sess:
+ g = ops.Graph()
+ with self.test_session(graph=g) as sess:
# Initialize variable `v1` to 1.
v1 = variables.Variable(1, name="v1")
ops.add_to_collection("v", v1)
@@ -857,19 +867,21 @@ class SavedModelTest(test.TestCase):
v2 = variables.Variable(42, name="v2", trainable=False, collections=[])
ops.add_to_collection("v", v2)
- # Set up an assignment op to be run as part of the legacy_init_op.
+ # Set up an assignment op to be run as part of the init op.
assign_v2 = state_ops.assign(v2, v1)
- legacy_init_op = control_flow_ops.group(assign_v2, name="legacy_init_op")
+ init_op = control_flow_ops.group(assign_v2, name="init_op")
sess.run(variables.global_variables_initializer())
- ops.add_to_collection(constants.LEGACY_INIT_OP_KEY,
- control_flow_ops.no_op())
- # AssertionError should be raised since the LEGACY_INIT_OP_KEY collection
+ ops.add_to_collection(key, control_flow_ops.no_op())
+ # ValueError should be raised since the LEGACY_INIT_OP_KEY collection
# is not empty and we don't support multiple init ops.
- with self.assertRaises(AssertionError):
+ with self.assertRaisesRegexp(ValueError, "Graph already contains"):
builder.add_meta_graph_and_variables(
- sess, ["foo"], legacy_init_op=legacy_init_op)
+ sess, ["foo"], legacy_init_op=init_op)
+ # We shouldn't be able to add as MAIN_OP, either.
+ with self.assertRaisesRegexp(ValueError, "Graph already contains"):
+ builder.add_meta_graph_and_variables(sess, ["foo"], main_op=init_op)
def testTrainOp(self):
export_dir = self._get_export_dir("test_train_op")