aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/saved_model
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-01-02 18:19:21 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-02 18:28:12 -0800
commitda1588ccf5c62eccac8013673359ac15b43eb394 (patch)
treeca08b5eee9e81362aacdd15874cee00f2fb34ca8 /tensorflow/python/saved_model
parenta8041609f182db130953e123d7be4823ea13d05e (diff)
meta_graph export: Add support to strip default valued attributes.
Following APIs now accept an additional argument (`strip_default_attrs`) to enable/disable (default:disabled) stripping of default valued attributes in a NodeDef: o meta_graph: export_meta_graph, create_meta_graph. o saver: Saver.save, Saver.export_meta_graph. o builder: SavedModelBuilder.add_meta_graph, SavedModelBuilder.add_meta_graph_and_variables. o estimator: Estimator.export_savedmodel. Related changes: o Pywrap C++ AreAttrValuesEqual to compare two AttrValue instances. This allows for a single/canonical way of comparing AttrValues in C++/Python. o Add a utility method to meta_graph.py to get the node def by name in a graph def. o Update SavedModelBuilder documentation on relevance of strip_default_attrs flag. PiperOrigin-RevId: 180619001
Diffstat (limited to 'tensorflow/python/saved_model')
-rw-r--r--tensorflow/python/saved_model/README.md29
-rw-r--r--tensorflow/python/saved_model/builder_impl.py22
-rw-r--r--tensorflow/python/saved_model/saved_model_test.py131
3 files changed, 178 insertions, 4 deletions
diff --git a/tensorflow/python/saved_model/README.md b/tensorflow/python/saved_model/README.md
index 8c78013ffd..5eeaf73a43 100644
--- a/tensorflow/python/saved_model/README.md
+++ b/tensorflow/python/saved_model/README.md
@@ -117,6 +117,35 @@ with tf.Session(graph=tf.Graph()) as sess:
builder.save()
~~~
+#### Stripping Default valued attributes
+The SavedModelBuilder class allows users to control whether default-valued
+attributes must be stripped from the NodeDefs while adding a meta graph to the
+SavedModel bundle. Both `SavedModelBuilder.add_meta_graph_and_variables` and
+`SavedModelBuilder.add_meta_graph` methods accept a Boolean flag
+`strip_default_attrs` that controls this behavior.
+
+If `strip_default_attrs` is `False`, the exported MetaGraphDef will have the
+default valued attributes in all it's NodeDef instances. This can break forward
+compatibility with a sequence of events such as the following:
+
+* An existing Op (`Foo`) is updated to include a new attribute (`T`) with a
+ default (`bool`) at version 101.
+* A model producer (such as a Trainer) binary picks up this change
+ (version 101) to the OpDef and re-exports an existing model that uses Op `Foo`.
+* A model consumer (such as Tensorflow Serving) running an older binary
+ (version 100) doesn't have attribute `T` for Op `Foo`, but tries to import
+ this model. The model consumer doesn't recognize attribute `T` in a NodeDef
+ that uses Op `Foo` and therefore fails to load the model.
+
+By setting `strip_default_attrs` to `True`, the model producers can strip away
+any default valued attributes in the NodeDefs. This helps ensure that newly
+added attributes with defaults don't cause older model consumers to fail loading
+models regenerated with newer training binaries.
+
+TIP: If you care about forward compatibility, then set `strip_default_attrs`
+to `True` while using `SavedModelBuilder.add_meta_graph_and_variables` and
+`SavedModelBuilder.add_meta_graph`.
+
### Loader
The SavedModel loader is implemented in C++ and Python.
diff --git a/tensorflow/python/saved_model/builder_impl.py b/tensorflow/python/saved_model/builder_impl.py
index 16651ffebc..62ee53b816 100644
--- a/tensorflow/python/saved_model/builder_impl.py
+++ b/tensorflow/python/saved_model/builder_impl.py
@@ -239,7 +239,9 @@ class SavedModelBuilder(object):
assets_collection=None,
legacy_init_op=None,
clear_devices=False,
- main_op=None):
+ main_op=None,
+ strip_default_attrs=False):
+ # pylint: disable=line-too-long
"""Adds the current meta graph to the SavedModel.
Creates a Saver in the current scope and uses the Saver to export the meta
@@ -260,11 +262,15 @@ class SavedModelBuilder(object):
main_op: Op or group of ops to execute when the graph is loaded. Note
that when the main_op is specified it is run after the restore op at
load-time.
+ strip_default_attrs: Boolean. If `True`, default-valued attributes will be
+ removed from the NodeDefs. For a detailed guide, see
+ [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
Raises:
AssertionError: If the variables for the SavedModel have not been saved
yet, or if the graph already contains one or more legacy init ops.
"""
+ # pylint: enable=line-too-long
if not self._has_saved_variables:
raise AssertionError(
"Graph state including variables and assets has not been saved yet. "
@@ -299,7 +305,8 @@ class SavedModelBuilder(object):
# there are edge cases where that option breaks the graph. Until that is
# resolved, we just leave the option set to False for now.
# TODO(soergel): Reinstate clear_extraneous_savers=True when possible.
- meta_graph_def = saver.export_meta_graph(clear_devices=clear_devices)
+ meta_graph_def = saver.export_meta_graph(
+ clear_devices=clear_devices, strip_default_attrs=strip_default_attrs)
# Tag the meta graph def and add it to the SavedModel.
self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)
@@ -311,7 +318,9 @@ class SavedModelBuilder(object):
assets_collection=None,
legacy_init_op=None,
clear_devices=False,
- main_op=None):
+ main_op=None,
+ strip_default_attrs=False):
+ # pylint: disable=line-too-long
"""Adds the current meta graph to the SavedModel and saves variables.
Creates a Saver to save the variables from the provided session. Exports the
@@ -334,7 +343,11 @@ class SavedModelBuilder(object):
main_op: Op or group of ops to execute when the graph is loaded. Note
that when the main_op is specified it is run after the restore op at
load-time.
+ strip_default_attrs: Boolean. If `True`, default-valued attributes will be
+ removed from the NodeDefs. For a detailed guide, see
+ [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
"""
+ # pylint: enable=line-too-long
if self._has_saved_variables:
raise AssertionError("Graph state including variables and assets has "
"already been saved. Please invoke "
@@ -388,7 +401,8 @@ class SavedModelBuilder(object):
# there are edge cases where that option breaks the graph. Until that is
# resolved, we just leave the option set to False for now.
# TODO(soergel): Reinstate clear_extraneous_savers=True when possible.
- meta_graph_def = saver.export_meta_graph(clear_devices=clear_devices)
+ meta_graph_def = saver.export_meta_graph(
+ clear_devices=clear_devices, strip_default_attrs=strip_default_attrs)
# Tag the meta graph def and add it to the SavedModel.
self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)
diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py
index 92ca7dec6f..1ea619ff55 100644
--- a/tensorflow/python/saved_model/saved_model_test.py
+++ b/tensorflow/python/saved_model/saved_model_test.py
@@ -20,13 +20,17 @@ from __future__ import print_function
import os
+from tensorflow.core.framework import op_def_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
+from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
@@ -36,6 +40,7 @@ from tensorflow.python.platform import test
from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import loader
+from tensorflow.python.saved_model import loader_impl
from tensorflow.python.saved_model import main_op
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import tag_constants
@@ -865,6 +870,132 @@ class SavedModelTest(test.TestCase):
self.assertEqual(
42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
+ def testStripDefaultAttrs(self):
+ export_dir = os.path.join(test.get_temp_dir(), "test_strip_default_attrs")
+ builder = saved_model_builder.SavedModelBuilder(export_dir)
+
+ # Add a graph with two float32 variables and a Complex Op composing them
+ # with strip_default_attrs enabled.
+ with session.Session(graph=ops.Graph()) as sess:
+ real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real")
+ imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag")
+ math_ops.complex(real_num, imag_num, name="complex")
+ sess.run(variables.global_variables_initializer())
+ builder.add_meta_graph_and_variables(
+ sess, ["foo"], strip_default_attrs=True)
+
+ # Add a graph with the same float32 variables and a Complex Op composing
+ # them with strip_default_attrs disabled.
+ with session.Session(graph=ops.Graph()) as sess:
+ real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real")
+ imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag")
+ math_ops.complex(real_num, imag_num, name="complex")
+ sess.run(variables.global_variables_initializer())
+ builder.add_meta_graph(["bar"], strip_default_attrs=False)
+
+ # Save the SavedModel to disk in text format.
+ builder.save(as_text=True)
+
+ # Loading graph "foo" via the loader must restore the defaults for the
+ # "Complex" node based on the "Complex" OpDef in the Op registry.
+ sess = session.Session(graph=ops.Graph())
+ meta_graph_def = loader.load(sess, ["foo"], export_dir)
+ complex_node = test_util.get_node_def_from_graph("complex",
+ meta_graph_def.graph_def)
+ self.assertIn("T", complex_node.attr)
+ self.assertIn("Tout", complex_node.attr)
+
+ # Load graph "foo" from disk as-is to verify default attrs are stripped.
+ # pylint: disable=protected-access
+ saved_model_pb = loader_impl._parse_saved_model(export_dir)
+ self.assertIsNotNone(saved_model_pb)
+ # pylint: enable=protected-access
+
+ meta_graph_foo_def = None
+ meta_graph_bar_def = None
+ for meta_graph_def in saved_model_pb.meta_graphs:
+ if set(meta_graph_def.meta_info_def.tags) == set(["foo"]):
+ meta_graph_foo_def = meta_graph_def
+ elif set(meta_graph_def.meta_info_def.tags) == set(["bar"]):
+ meta_graph_bar_def = meta_graph_def
+
+ self.assertIsNotNone(meta_graph_foo_def)
+ self.assertIsNotNone(meta_graph_bar_def)
+
+ # "Complex" Op has 2 attributes with defaults:
+ # o "T" : float32. (input type)
+ # o "Tout" : complex64. (output type)
+
+ # "Complex" Op in graph "foo" shouldn't have attributes "T" and "Tout".
+ # Graph "foo" was saved with strip_default_attrs set to True.
+ node_def = test_util.get_node_def_from_graph("complex",
+ meta_graph_foo_def.graph_def)
+ self.assertNotIn("T", node_def.attr)
+ self.assertNotIn("Tout", node_def.attr)
+
+ # "Complex" Op in graph "bar" must have attributes "T" and "Tout".
+ # Graph "bar" was saved with strip_default_attrs set to False.
+ node_def = test_util.get_node_def_from_graph("complex",
+ meta_graph_bar_def.graph_def)
+ self.assertIn("T", node_def.attr)
+ self.assertIn("Tout", node_def.attr)
+
+ def testStripDefaultAttrsInconsistentConsumerDefaults(self):
+ export_dir = os.path.join(test.get_temp_dir(),
+ "test_strip_default_attrs_no_consumer_defaults")
+ builder = saved_model_builder.SavedModelBuilder(export_dir)
+
+ # Add a graph with two float32 variables and a Complex Op composing them
+ # with strip_default_attrs enabled. This must remove the following
+ # defaults for the "Complex" Op:
+ # o "T" : float32. (input type)
+ # o "Tout" : complex64. (output type)
+ with session.Session(graph=ops.Graph()) as sess:
+ real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real")
+ imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag")
+ math_ops.complex(real_num, imag_num, name="complex")
+ sess.run(variables.global_variables_initializer())
+ builder.add_meta_graph_and_variables(
+ sess, ["foo"], strip_default_attrs=True)
+
+ # Save the SavedModel to disk in text format.
+ builder.save(as_text=True)
+
+ # Update the Op registry to remove defaults for all attrs("T", "Tout") from
+ # the "Complex" OpDef.
+ complex_op_def = op_def_registry.get_registered_ops()["Complex"]
+ original_complex_op_def = op_def_pb2.OpDef()
+ original_complex_op_def.CopyFrom(complex_op_def)
+ for attr_def in complex_op_def.attr:
+ attr_def.ClearField("default_value")
+
+ # Loading the SavedModel via the loader must fail because the SavedModel
+ # does not have any attr values for the "Complex" node and the current
+ # op registry does not have have any default values for the "Complex" op.
+ sess = session.Session(graph=ops.Graph())
+ with self.assertRaisesRegexp(
+ ValueError,
+ "Expected one attr with name .*T(out)?.* in name: \"complex\".*"):
+ loader.load(sess, ["foo"], export_dir)
+
+ # Update the Op registry to change the defaults for attr "Tout"
+ # (complex64 -> complex128).
+ complex_op_def.CopyFrom(original_complex_op_def)
+ for attr_def in complex_op_def.attr:
+ if attr_def.name == "Tout":
+ attr_def.default_value.type = types_pb2.DT_COMPLEX128
+
+ # Loading the SavedModel via the loader must set "Tout" attr_value for the
+ # "Complex" node according to the latest defaults (complex128). This is
+ # expected to fail the model import as there is no OpKernel registered to
+ # handle attrs "T" (float32) and "Tout" (complex128).
+ sess = session.Session(graph=ops.Graph())
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ ".*No OpKernel was registered to support Op \'Complex\' with these "
+ "attrs..*"):
+ loader.load(sess, ["foo"], export_dir)
+
if __name__ == "__main__":
test.main()