diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-01-02 18:19:21 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-01-02 18:28:12 -0800 |
commit | da1588ccf5c62eccac8013673359ac15b43eb394 (patch) | |
tree | ca08b5eee9e81362aacdd15874cee00f2fb34ca8 /tensorflow/python/saved_model | |
parent | a8041609f182db130953e123d7be4823ea13d05e (diff) |
meta_graph export: Add support to strip default valued attributes.
Following APIs now accept an additional argument (`strip_default_attrs`) to
enable/disable (default:disabled) stripping of default valued attributes in a NodeDef:
o meta_graph: export_meta_graph, create_meta_graph.
o saver: Saver.save, Saver.export_meta_graph.
o builder: SavedModelBuilder.add_meta_graph,
SavedModelBuilder.add_meta_graph_and_variables.
o estimator: Estimator.export_savedmodel.
Related changes:
o Pywrap C++ AreAttrValuesEqual to compare two AttrValue instances.
This allows for a single/canonical way of comparing AttrValues in C++/Python.
o Add a utility method to meta_graph.py to get the node def by name in a graph def.
o Update SavedModelBuilder documentation on relevance of strip_default_attrs flag.
PiperOrigin-RevId: 180619001
Diffstat (limited to 'tensorflow/python/saved_model')
-rw-r--r-- | tensorflow/python/saved_model/README.md | 29 | ||||
-rw-r--r-- | tensorflow/python/saved_model/builder_impl.py | 22 | ||||
-rw-r--r-- | tensorflow/python/saved_model/saved_model_test.py | 131 |
3 files changed, 178 insertions, 4 deletions
diff --git a/tensorflow/python/saved_model/README.md b/tensorflow/python/saved_model/README.md index 8c78013ffd..5eeaf73a43 100644 --- a/tensorflow/python/saved_model/README.md +++ b/tensorflow/python/saved_model/README.md @@ -117,6 +117,35 @@ with tf.Session(graph=tf.Graph()) as sess: builder.save() ~~~ +#### Stripping Default valued attributes +The SavedModelBuilder class allows users to control whether default-valued +attributes must be stripped from the NodeDefs while adding a meta graph to the +SavedModel bundle. Both `SavedModelBuilder.add_meta_graph_and_variables` and +`SavedModelBuilder.add_meta_graph` methods accept a Boolean flag +`strip_default_attrs` that controls this behavior. + +If `strip_default_attrs` is `False`, the exported MetaGraphDef will have the +default valued attributes in all it's NodeDef instances. This can break forward +compatibility with a sequence of events such as the following: + +* An existing Op (`Foo`) is updated to include a new attribute (`T`) with a + default (`bool`) at version 101. +* A model producer (such as a Trainer) binary picks up this change + (version 101) to the OpDef and re-exports an existing model that uses Op `Foo`. +* A model consumer (such as Tensorflow Serving) running an older binary + (version 100) doesn't have attribute `T` for Op `Foo`, but tries to import + this model. The model consumer doesn't recognize attribute `T` in a NodeDef + that uses Op `Foo` and therefore fails to load the model. + +By setting `strip_default_attrs` to `True`, the model producers can strip away +any default valued attributes in the NodeDefs. This helps ensure that newly +added attributes with defaults don't cause older model consumers to fail loading +models regenerated with newer training binaries. + +TIP: If you care about forward compatibility, then set `strip_default_attrs` +to `True` while using `SavedModelBuilder.add_meta_graph_and_variables` and +`SavedModelBuilder.add_meta_graph`. + ### Loader The SavedModel loader is implemented in C++ and Python. diff --git a/tensorflow/python/saved_model/builder_impl.py b/tensorflow/python/saved_model/builder_impl.py index 16651ffebc..62ee53b816 100644 --- a/tensorflow/python/saved_model/builder_impl.py +++ b/tensorflow/python/saved_model/builder_impl.py @@ -239,7 +239,9 @@ class SavedModelBuilder(object): assets_collection=None, legacy_init_op=None, clear_devices=False, - main_op=None): + main_op=None, + strip_default_attrs=False): + # pylint: disable=line-too-long """Adds the current meta graph to the SavedModel. Creates a Saver in the current scope and uses the Saver to export the meta @@ -260,11 +262,15 @@ class SavedModelBuilder(object): main_op: Op or group of ops to execute when the graph is loaded. Note that when the main_op is specified it is run after the restore op at load-time. + strip_default_attrs: Boolean. If `True`, default-valued attributes will be + removed from the NodeDefs. For a detailed guide, see + [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). Raises: AssertionError: If the variables for the SavedModel have not been saved yet, or if the graph already contains one or more legacy init ops. """ + # pylint: enable=line-too-long if not self._has_saved_variables: raise AssertionError( "Graph state including variables and assets has not been saved yet. " @@ -299,7 +305,8 @@ class SavedModelBuilder(object): # there are edge cases where that option breaks the graph. Until that is # resolved, we just leave the option set to False for now. # TODO(soergel): Reinstate clear_extraneous_savers=True when possible. - meta_graph_def = saver.export_meta_graph(clear_devices=clear_devices) + meta_graph_def = saver.export_meta_graph( + clear_devices=clear_devices, strip_default_attrs=strip_default_attrs) # Tag the meta graph def and add it to the SavedModel. self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map) @@ -311,7 +318,9 @@ class SavedModelBuilder(object): assets_collection=None, legacy_init_op=None, clear_devices=False, - main_op=None): + main_op=None, + strip_default_attrs=False): + # pylint: disable=line-too-long """Adds the current meta graph to the SavedModel and saves variables. Creates a Saver to save the variables from the provided session. Exports the @@ -334,7 +343,11 @@ class SavedModelBuilder(object): main_op: Op or group of ops to execute when the graph is loaded. Note that when the main_op is specified it is run after the restore op at load-time. + strip_default_attrs: Boolean. If `True`, default-valued attributes will be + removed from the NodeDefs. For a detailed guide, see + [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes). """ + # pylint: enable=line-too-long if self._has_saved_variables: raise AssertionError("Graph state including variables and assets has " "already been saved. Please invoke " @@ -388,7 +401,8 @@ class SavedModelBuilder(object): # there are edge cases where that option breaks the graph. Until that is # resolved, we just leave the option set to False for now. # TODO(soergel): Reinstate clear_extraneous_savers=True when possible. - meta_graph_def = saver.export_meta_graph(clear_devices=clear_devices) + meta_graph_def = saver.export_meta_graph( + clear_devices=clear_devices, strip_default_attrs=strip_default_attrs) # Tag the meta graph def and add it to the SavedModel. self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map) diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py index 92ca7dec6f..1ea619ff55 100644 --- a/tensorflow/python/saved_model/saved_model_test.py +++ b/tensorflow/python/saved_model/saved_model_test.py @@ -20,13 +20,17 @@ from __future__ import print_function import os +from tensorflow.core.framework import op_def_pb2 from tensorflow.core.framework import types_pb2 from tensorflow.core.protobuf import config_pb2 from tensorflow.core.protobuf import meta_graph_pb2 from tensorflow.python.client import session from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import errors +from tensorflow.python.framework import op_def_registry from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util from tensorflow.python.lib.io import file_io from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops @@ -36,6 +40,7 @@ from tensorflow.python.platform import test from tensorflow.python.saved_model import builder as saved_model_builder from tensorflow.python.saved_model import constants from tensorflow.python.saved_model import loader +from tensorflow.python.saved_model import loader_impl from tensorflow.python.saved_model import main_op from tensorflow.python.saved_model import signature_def_utils from tensorflow.python.saved_model import tag_constants @@ -865,6 +870,132 @@ class SavedModelTest(test.TestCase): self.assertEqual( 42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval()) + def testStripDefaultAttrs(self): + export_dir = os.path.join(test.get_temp_dir(), "test_strip_default_attrs") + builder = saved_model_builder.SavedModelBuilder(export_dir) + + # Add a graph with two float32 variables and a Complex Op composing them + # with strip_default_attrs enabled. + with session.Session(graph=ops.Graph()) as sess: + real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real") + imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag") + math_ops.complex(real_num, imag_num, name="complex") + sess.run(variables.global_variables_initializer()) + builder.add_meta_graph_and_variables( + sess, ["foo"], strip_default_attrs=True) + + # Add a graph with the same float32 variables and a Complex Op composing + # them with strip_default_attrs disabled. + with session.Session(graph=ops.Graph()) as sess: + real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real") + imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag") + math_ops.complex(real_num, imag_num, name="complex") + sess.run(variables.global_variables_initializer()) + builder.add_meta_graph(["bar"], strip_default_attrs=False) + + # Save the SavedModel to disk in text format. + builder.save(as_text=True) + + # Loading graph "foo" via the loader must restore the defaults for the + # "Complex" node based on the "Complex" OpDef in the Op registry. + sess = session.Session(graph=ops.Graph()) + meta_graph_def = loader.load(sess, ["foo"], export_dir) + complex_node = test_util.get_node_def_from_graph("complex", + meta_graph_def.graph_def) + self.assertIn("T", complex_node.attr) + self.assertIn("Tout", complex_node.attr) + + # Load graph "foo" from disk as-is to verify default attrs are stripped. + # pylint: disable=protected-access + saved_model_pb = loader_impl._parse_saved_model(export_dir) + self.assertIsNotNone(saved_model_pb) + # pylint: enable=protected-access + + meta_graph_foo_def = None + meta_graph_bar_def = None + for meta_graph_def in saved_model_pb.meta_graphs: + if set(meta_graph_def.meta_info_def.tags) == set(["foo"]): + meta_graph_foo_def = meta_graph_def + elif set(meta_graph_def.meta_info_def.tags) == set(["bar"]): + meta_graph_bar_def = meta_graph_def + + self.assertIsNotNone(meta_graph_foo_def) + self.assertIsNotNone(meta_graph_bar_def) + + # "Complex" Op has 2 attributes with defaults: + # o "T" : float32. (input type) + # o "Tout" : complex64. (output type) + + # "Complex" Op in graph "foo" shouldn't have attributes "T" and "Tout". + # Graph "foo" was saved with strip_default_attrs set to True. + node_def = test_util.get_node_def_from_graph("complex", + meta_graph_foo_def.graph_def) + self.assertNotIn("T", node_def.attr) + self.assertNotIn("Tout", node_def.attr) + + # "Complex" Op in graph "bar" must have attributes "T" and "Tout". + # Graph "bar" was saved with strip_default_attrs set to False. + node_def = test_util.get_node_def_from_graph("complex", + meta_graph_bar_def.graph_def) + self.assertIn("T", node_def.attr) + self.assertIn("Tout", node_def.attr) + + def testStripDefaultAttrsInconsistentConsumerDefaults(self): + export_dir = os.path.join(test.get_temp_dir(), + "test_strip_default_attrs_no_consumer_defaults") + builder = saved_model_builder.SavedModelBuilder(export_dir) + + # Add a graph with two float32 variables and a Complex Op composing them + # with strip_default_attrs enabled. This must remove the following + # defaults for the "Complex" Op: + # o "T" : float32. (input type) + # o "Tout" : complex64. (output type) + with session.Session(graph=ops.Graph()) as sess: + real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real") + imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag") + math_ops.complex(real_num, imag_num, name="complex") + sess.run(variables.global_variables_initializer()) + builder.add_meta_graph_and_variables( + sess, ["foo"], strip_default_attrs=True) + + # Save the SavedModel to disk in text format. + builder.save(as_text=True) + + # Update the Op registry to remove defaults for all attrs("T", "Tout") from + # the "Complex" OpDef. + complex_op_def = op_def_registry.get_registered_ops()["Complex"] + original_complex_op_def = op_def_pb2.OpDef() + original_complex_op_def.CopyFrom(complex_op_def) + for attr_def in complex_op_def.attr: + attr_def.ClearField("default_value") + + # Loading the SavedModel via the loader must fail because the SavedModel + # does not have any attr values for the "Complex" node and the current + # op registry does not have have any default values for the "Complex" op. + sess = session.Session(graph=ops.Graph()) + with self.assertRaisesRegexp( + ValueError, + "Expected one attr with name .*T(out)?.* in name: \"complex\".*"): + loader.load(sess, ["foo"], export_dir) + + # Update the Op registry to change the defaults for attr "Tout" + # (complex64 -> complex128). + complex_op_def.CopyFrom(original_complex_op_def) + for attr_def in complex_op_def.attr: + if attr_def.name == "Tout": + attr_def.default_value.type = types_pb2.DT_COMPLEX128 + + # Loading the SavedModel via the loader must set "Tout" attr_value for the + # "Complex" node according to the latest defaults (complex128). This is + # expected to fail the model import as there is no OpKernel registered to + # handle attrs "T" (float32) and "Tout" (complex128). + sess = session.Session(graph=ops.Graph()) + with self.assertRaisesRegexp( + errors.InvalidArgumentError, + ".*No OpKernel was registered to support Op \'Complex\' with these " + "attrs..*"): + loader.load(sess, ["foo"], export_dir) + if __name__ == "__main__": test.main() |