aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/estimator.py11
-rw-r--r--tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py36
-rw-r--r--tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py3
-rw-r--r--tensorflow/core/protobuf/meta_graph.proto4
-rw-r--r--tensorflow/core/public/version.h5
-rw-r--r--tensorflow/python/client/tf_session.i1
-rw-r--r--tensorflow/python/client/tf_session_helper.cc24
-rw-r--r--tensorflow/python/client/tf_session_helper.h7
-rw-r--r--tensorflow/python/estimator/estimator.py9
-rw-r--r--tensorflow/python/estimator/estimator_test.py61
-rw-r--r--tensorflow/python/framework/meta_graph.py81
-rw-r--r--tensorflow/python/framework/meta_graph_test.py103
-rw-r--r--tensorflow/python/framework/test_util.py18
-rw-r--r--tensorflow/python/framework/test_util_test.py7
-rw-r--r--tensorflow/python/saved_model/README.md29
-rw-r--r--tensorflow/python/saved_model/builder_impl.py22
-rw-r--r--tensorflow/python/saved_model/saved_model_test.py131
-rw-r--r--tensorflow/python/training/saver.py29
-rw-r--r--tensorflow/python/training/saver_test.py36
-rw-r--r--tensorflow/tools/api/golden/tensorflow.-meta-graph-def.-meta-info-def.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.estimator.-baseline-classifier.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.estimator.-baseline-regressor.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-classifier.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-regressor.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.estimator.-estimator.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.estimator.-linear-classifier.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.estimator.-linear-regressor.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.saved_model.builder.-saved-model-builder.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.train.-saver.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/tensorflow.train.pbtxt2
32 files changed, 613 insertions, 36 deletions
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
index 05ed8b3409..2395c7e717 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
@@ -1256,7 +1256,9 @@ class Estimator(BaseEstimator):
assets_extra=None,
as_text=False,
checkpoint_path=None,
- graph_rewrite_specs=(GraphRewriteSpec((tag_constants.SERVING,), ()),)):
+ graph_rewrite_specs=(GraphRewriteSpec((tag_constants.SERVING,), ()),),
+ strip_default_attrs=False):
+ # pylint: disable=line-too-long
"""Exports inference graph as a SavedModel into given dir.
Args:
@@ -1280,6 +1282,9 @@ class Estimator(BaseEstimator):
produce a separate MetaGraphDef within the exported SavedModel, tagged
and rewritten as specified. Defaults to a single entry using the
default serving tag ("serve") and no rewriting.
+ strip_default_attrs: Boolean. If `True`, default-valued attributes will be
+ removed from the NodeDefs. For a detailed guide, see
+ [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
Returns:
The string path to the exported directory.
@@ -1287,6 +1292,7 @@ class Estimator(BaseEstimator):
Raises:
ValueError: if an unrecognized export_type is requested.
"""
+ # pylint: enable=line-too-long
if serving_input_fn is None:
raise ValueError('serving_input_fn must be defined.')
@@ -1366,7 +1372,8 @@ class Estimator(BaseEstimator):
signature_def_map=signature_def_map,
assets_collection=ops.get_collection(
ops.GraphKeys.ASSET_FILEPATHS),
- legacy_init_op=init_op)
+ legacy_init_op=init_op,
+ strip_default_attrs=strip_default_attrs)
# pylint: disable=protected-access
base_meta_graph_def = builder._saved_model.meta_graphs[0]
diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py
index 4b404a8e20..03ec66b98b 100644
--- a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py
+++ b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils.py
@@ -390,7 +390,9 @@ def make_export_strategy(serving_input_fn,
default_output_alternative_key=None,
assets_extra=None,
as_text=False,
- exports_to_keep=5):
+ exports_to_keep=5,
+ strip_default_attrs=False):
+ # pylint: disable=line-too-long
"""Create an ExportStrategy for use with Experiment.
Args:
@@ -411,10 +413,14 @@ def make_export_strategy(serving_input_fn,
exports_to_keep: Number of exports to keep. Older exports will be
garbage-collected. Defaults to 5. Set to None to disable garbage
collection.
+ strip_default_attrs: Boolean. If `True`, default-valued attributes will be
+ removed from the NodeDefs. For a detailed guide, see
+ [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
Returns:
An ExportStrategy that can be passed to the Experiment constructor.
"""
+ # pylint: enable=line-too-long
def export_fn(estimator, export_dir_base, checkpoint_path=None):
"""Exports the given Estimator as a SavedModel.
@@ -443,7 +449,8 @@ def make_export_strategy(serving_input_fn,
serving_input_fn,
assets_extra=assets_extra,
as_text=as_text,
- checkpoint_path=checkpoint_path)
+ checkpoint_path=checkpoint_path,
+ strip_default_attrs=strip_default_attrs)
else:
export_result = estimator.export_savedmodel(
export_dir_base,
@@ -451,7 +458,8 @@ def make_export_strategy(serving_input_fn,
default_output_alternative_key=default_output_alternative_key,
assets_extra=assets_extra,
as_text=as_text,
- checkpoint_path=checkpoint_path)
+ checkpoint_path=checkpoint_path,
+ strip_default_attrs=strip_default_attrs)
garbage_collect_exports(export_dir_base, exports_to_keep)
return export_result
@@ -464,7 +472,9 @@ def make_parsing_export_strategy(feature_columns,
assets_extra=None,
as_text=False,
exports_to_keep=5,
- target_core=False):
+ target_core=False,
+ strip_default_attrs=False):
+ # pylint: disable=line-too-long
"""Create an ExportStrategy for use with Experiment, using `FeatureColumn`s.
Creates a SavedModel export that expects to be fed with a single string
@@ -492,10 +502,14 @@ def make_parsing_export_strategy(feature_columns,
target_core: If True, prepare an ExportStrategy for use with
tensorflow.python.estimator.*. If False (default), prepare an
ExportStrategy for use with tensorflow.contrib.learn.python.learn.*.
+ strip_default_attrs: Boolean. If `True`, default-valued attributes will be
+ removed from the NodeDefs. For a detailed guide, see
+ [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
Returns:
An ExportStrategy that can be passed to the Experiment constructor.
"""
+ # pylint: enable=line-too-long
feature_spec = feature_column.create_feature_spec_for_parsing(feature_columns)
if target_core:
serving_input_fn = (
@@ -508,7 +522,8 @@ def make_parsing_export_strategy(feature_columns,
default_output_alternative_key=default_output_alternative_key,
assets_extra=assets_extra,
as_text=as_text,
- exports_to_keep=exports_to_keep)
+ exports_to_keep=exports_to_keep,
+ strip_default_attrs=strip_default_attrs)
def _default_compare_fn(curr_best_eval_result, cand_eval_result):
@@ -584,7 +599,9 @@ class BestModelSelector(object):
def make_best_model_export_strategy(serving_input_fn,
exports_to_keep=1,
compare_fn=None,
- default_output_alternative_key=None):
+ default_output_alternative_key=None,
+ strip_default_attrs=False):
+ # pylint: disable=line-too-long
"""Creates an custom ExportStrategy for use with tf.contrib.learn.Experiment.
Args:
@@ -596,14 +613,19 @@ def make_best_model_export_strategy(serving_input_fn,
of evaluation result keyed by corresponding checkpoint path.
default_output_alternative_key: the key for default serving signature for
multi-headed inference graphs.
+ strip_default_attrs: Boolean. If `True`, default-valued attributes will be
+ removed from the NodeDefs. For a detailed guide, see
+ [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
Returns:
An ExportStrategy that can be passed to the Experiment constructor.
"""
+ # pylint: enable=line-too-long
best_model_export_strategy = make_export_strategy(
serving_input_fn,
exports_to_keep=exports_to_keep,
- default_output_alternative_key=default_output_alternative_key)
+ default_output_alternative_key=default_output_alternative_key,
+ strip_default_attrs=strip_default_attrs)
best_model_selector = BestModelSelector(compare_fn)
diff --git a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py
index 628eb254c3..531d9c672b 100644
--- a/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py
+++ b/tensorflow/contrib/learn/python/learn/utils/saved_model_export_utils_test.py
@@ -55,7 +55,8 @@ class TestEstimator(core_estimator.Estimator):
default_output_alternative_key=None,
assets_extra=None,
as_text=False,
- checkpoint_path=None):
+ checkpoint_path=None,
+ strip_default_attrs=False):
if not os.path.exists(export_dir):
os.makedirs(export_dir)
diff --git a/tensorflow/core/protobuf/meta_graph.proto b/tensorflow/core/protobuf/meta_graph.proto
index 47ec2aa1ef..fd86c0da12 100644
--- a/tensorflow/core/protobuf/meta_graph.proto
+++ b/tensorflow/core/protobuf/meta_graph.proto
@@ -61,6 +61,10 @@ message MetaGraphDef {
// graph. This will be populated by the framework, which will overwrite any
// user supplied value.
string tensorflow_git_version = 6;
+
+ // A flag to denote whether default-valued attrs have been stripped from
+ // the nodes in this graph_def.
+ bool stripped_default_attrs = 7;
}
MetaInfoDef meta_info_def = 1;
diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h
index d8e7df48c2..c037a9b122 100644
--- a/tensorflow/core/public/version.h
+++ b/tensorflow/core/public/version.h
@@ -91,10 +91,13 @@ limitations under the License.
// 24. Deprecate lookup ops (v1) ops in favor of v2 (30may2017)
// 25. Deprecate stack (v1) ops in favor of v2 (2017/6/15).
// 25. Deprecate RandomPoisson (v1) ops in favor of v2 (2017/10/25).
+// 26. Add a bool 'stripped_default_attrs' to MetaInfoDef indicating
+// whether default-valued attrs have been stripped from the nodes in the
+// GraphDef. (7dec2017)
#define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0
#define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0
-#define TF_GRAPH_DEF_VERSION 24
+#define TF_GRAPH_DEF_VERSION 25
// Checkpoint compatibility versions (the versions field in SavedSliceMeta).
//
diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i
index 3f1d63a543..1fd488e7b6 100644
--- a/tensorflow/python/client/tf_session.i
+++ b/tensorflow/python/client/tf_session.i
@@ -485,6 +485,7 @@ TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper{
%unignore tensorflow;
%unignore TF_Run;
%unignore EqualGraphDefWrapper;
+%unignore EqualAttrValueWrapper;
// Include the wrapper for TF_PRunSetup from tf_session_helper.h.
diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc
index 2b83141faa..361dbc22b0 100644
--- a/tensorflow/python/client/tf_session_helper.cc
+++ b/tensorflow/python/client/tf_session_helper.cc
@@ -21,10 +21,13 @@ limitations under the License.
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/allocator.h"
+#include "tensorflow/core/framework/attr_value.pb.h"
+#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/coding.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/equal_graph_def.h"
#include "tensorflow/python/lib/core/ndarray_tensor.h"
@@ -301,6 +304,27 @@ string EqualGraphDefWrapper(const string& actual, const string& expected) {
return EqualGraphDef(actual_def, expected_def, &diff) ? "" : diff;
}
+string EqualAttrValueWrapper(const string& actual, const string& expected) {
+ AttrValue actual_attr_value;
+ if (!actual_attr_value.ParseFromString(actual)) {
+ return "actual is not a valid serialized AttrValue";
+ }
+
+ AttrValue expected_attr_value;
+ if (!expected_attr_value.ParseFromString(expected)) {
+ return "expected is not a valid serialized AttrValue";
+ }
+
+ string diff;
+ if (!AreAttrValuesEqual(actual_attr_value, expected_attr_value)) {
+ diff = strings::Printf(
+ "Actual AttrValue %s does not match Expected AttrValue %s.",
+ SummarizeAttrValue(actual_attr_value).c_str(),
+ SummarizeAttrValue(expected_attr_value).c_str());
+ }
+ return diff;
+}
+
// Return value set to 6 inlined elements so it fits in a 64-byte cache line.
tensorflow::gtl::InlinedVector<int64_t, 6> TF_GraphGetTensorShapeHelper(
TF_Graph* graph, TF_Output output, TF_Status* out_status,
diff --git a/tensorflow/python/client/tf_session_helper.h b/tensorflow/python/client/tf_session_helper.h
index 8f2499b9a0..29d5b28f40 100644
--- a/tensorflow/python/client/tf_session_helper.h
+++ b/tensorflow/python/client/tf_session_helper.h
@@ -97,6 +97,13 @@ void TF_Reset_wrapper(const TF_SessionOptions* opt,
// for no difference.
string EqualGraphDefWrapper(const string& actual, const string& expected);
+// Convenience wrapper around AreAttrValuesEqual to make it easier to wrap.
+// The actual and expected strings must correspond to a serialized binary
+// representation of two AttrValue proto instances.
+// Returns an explanation if a difference is found, or the empty string
+// for no difference.
+string EqualAttrValueWrapper(const string& actual, const string& expected);
+
// Gets shape from C API Graph object.
//
// If shape is known, returns shape vector where -1 means "unknown
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index 1e3d6d5755..c72d37b442 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -461,7 +461,8 @@ class Estimator(object):
self, export_dir_base, serving_input_receiver_fn,
assets_extra=None,
as_text=False,
- checkpoint_path=None):
+ checkpoint_path=None,
+ strip_default_attrs=False):
# pylint: disable=line-too-long
"""Exports inference graph as a SavedModel into given dir.
@@ -503,6 +504,9 @@ class Estimator(object):
as_text: whether to write the SavedModel proto in text format.
checkpoint_path: The checkpoint path to export. If `None` (the default),
the most recent checkpoint found within the model directory is chosen.
+ strip_default_attrs: Boolean. If `True`, default-valued attributes will be
+ removed from the NodeDefs. For a detailed guide, see
+ [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
Returns:
The string path to the exported directory.
@@ -563,7 +567,8 @@ class Estimator(object):
signature_def_map=signature_def_map,
assets_collection=ops.get_collection(
ops.GraphKeys.ASSET_FILEPATHS),
- legacy_init_op=local_init_op)
+ legacy_init_op=local_init_op,
+ strip_default_attrs=strip_default_attrs)
builder.save(as_text)
# Add the extra assets
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py
index db64fbc9cc..58d0cb0018 100644
--- a/tensorflow/python/estimator/estimator_test.py
+++ b/tensorflow/python/estimator/estimator_test.py
@@ -40,6 +40,7 @@ from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
from tensorflow.python.layers import layers
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import array_ops
@@ -57,6 +58,7 @@ from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import loader
+from tensorflow.python.saved_model import loader_impl
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.summary import summary
from tensorflow.python.summary import summary_iterator
@@ -2050,6 +2052,65 @@ class EstimatorExportTest(test.TestCase):
gfile.DeleteRecursively(tmpdir)
+ def test_export_savedmodel_proto_strip_default_attrs(self):
+ tmpdir = tempfile.mkdtemp()
+ est = estimator.Estimator(model_fn=_model_fn_for_export_tests)
+ est.train(input_fn=dummy_input_fn, steps=1)
+ feature_spec = {'x': parsing_ops.VarLenFeature(dtype=dtypes.int64),
+ 'y': parsing_ops.VarLenFeature(dtype=dtypes.int64)}
+ serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
+ feature_spec)
+
+ # Perform the export.
+ export_dir_base = os.path.join(
+ compat.as_bytes(tmpdir), compat.as_bytes('export'))
+ export_dir_stripped = est.export_savedmodel(
+ export_dir_base, serving_input_receiver_fn, strip_default_attrs=True)
+ export_dir_not_stripped = est.export_savedmodel(
+ export_dir_base, serving_input_receiver_fn, strip_default_attrs=False)
+
+ # Load the SavedModel from disk as-is to verify default attrs
+ # are stripped. Reimporting the SavedModel via the loader causes the
+ # default attrs to be populated in the NodeDefs.
+
+ # pylint: disable=protected-access
+ saved_model_stripped_pb = loader_impl._parse_saved_model(
+ export_dir_stripped)
+ saved_model_not_stripped_pb = loader_impl._parse_saved_model(
+ export_dir_not_stripped)
+ self.assertIsNotNone(saved_model_stripped_pb)
+ self.assertIsNotNone(saved_model_not_stripped_pb)
+ # pylint: enable=protected-access
+
+ meta_graph_def_stripped = [
+ x for x in saved_model_stripped_pb.meta_graphs
+ if x.meta_info_def.tags == [tag_constants.SERVING]][0]
+ meta_graph_def_not_stripped = [
+ x for x in saved_model_not_stripped_pb.meta_graphs
+ if x.meta_info_def.tags == [tag_constants.SERVING]][0]
+
+ # "weight" node in graph is a "Variable" Op with 2 default valued attrs.
+ # o "container" : "".
+ # o "shared_name" : "".
+
+ # saved_model_stripped_pb was exported with strip_default_attrs set to True.
+ # "weight" node shouldn't have attributes "container" and "shared_name".
+ node_def = test_util.get_node_def_from_graph(
+ 'weight', meta_graph_def_stripped.graph_def)
+ self.assertNotIn('container', node_def.attr)
+ self.assertNotIn('shared_name', node_def.attr)
+
+ # saved_model_not_stripped_pb was exported with strip_default_attrs
+ # disabled. "weight" node should have attributes "container" and
+ # "shared_name".
+ node_def = test_util.get_node_def_from_graph(
+ 'weight', meta_graph_def_not_stripped.graph_def)
+ self.assertIn('container', node_def.attr)
+ self.assertIn('shared_name', node_def.attr)
+
+ # Clean up.
+ gfile.DeleteRecursively(tmpdir)
+
class EstimatorHookOrderingTest(test.TestCase):
diff --git a/tensorflow/python/framework/meta_graph.py b/tensorflow/python/framework/meta_graph.py
index c839d7a9a6..65032637d4 100644
--- a/tensorflow/python/framework/meta_graph.py
+++ b/tensorflow/python/framework/meta_graph.py
@@ -31,6 +31,7 @@ from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import op_def_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import saver_pb2
+from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.framework import graph_io
from tensorflow.python.framework import importer
@@ -442,6 +443,67 @@ def add_collection_def(meta_graph_def, key, graph=None,
return
+def _is_default_attr_value(op_def, attr_name, attr_value):
+ """Checks if given attribute matches the default value in the op def."""
+ for attr_def in op_def.attr:
+ if attr_def.name == attr_name:
+ if not attr_def.HasField("default_value"):
+ return False
+ # pywrap_tensorflow.EqualAttrValueWrapper returns an empty string
+ # if both arguments represent an equivalent AttrValue instance.
+ return not pywrap_tensorflow.EqualAttrValueWrapper(
+ attr_value.SerializeToString(),
+ attr_def.default_value.SerializeToString())
+ return False
+
+
+def _strip_graph_default_valued_attrs(meta_graph_def):
+ """Strips default valued attributes for node defs in given MetaGraphDef.
+
+ This method also sets `meta_info_def.stripped_default_attrs` in the given
+ `MetaGraphDef` proto to True.
+
+ Args:
+ meta_graph_def: `MetaGraphDef` protocol buffer
+
+ Returns:
+ None.
+ """
+ # Map function op names to their function definitions.
+ op_name_to_function = {}
+ for function_def in meta_graph_def.graph_def.library.function:
+ op_name_to_function[function_def.signature.name] = function_def
+
+ # Get all registered ops.
+ registered_ops = op_def_registry.get_registered_ops()
+
+ def _strip_node_default_valued_attrs(node_def):
+ """Removes default valued attributes from a single node def."""
+ if node_def.op in op_name_to_function or node_def.op not in registered_ops:
+ return
+ op_def = registered_ops[node_def.op]
+
+ attrs_to_strip = set()
+ for attr_name, attr_value in node_def.attr.items():
+ if _is_default_attr_value(op_def, attr_name, attr_value):
+ attrs_to_strip.add(attr_name)
+
+ for attr in attrs_to_strip:
+ del node_def.attr[attr]
+
+ # Process all NodeDef instances in graph_def.
+ for node_def in meta_graph_def.graph_def.node:
+ _strip_node_default_valued_attrs(node_def)
+
+ # Process all NodeDef instances in graph_def.library.function.
+ for function_def in meta_graph_def.graph_def.library.function:
+ for function_node_def in function_def.node_def:
+ _strip_node_default_valued_attrs(function_node_def)
+
+ # Tell consumers of this graph that default valued attrs have been stripped.
+ meta_graph_def.meta_info_def.stripped_default_attrs = True
+
+
def create_meta_graph_def(meta_info_def=None,
graph_def=None,
saver_def=None,
@@ -449,7 +511,9 @@ def create_meta_graph_def(meta_info_def=None,
graph=None,
export_scope=None,
exclude_nodes=None,
- clear_extraneous_savers=False):
+ clear_extraneous_savers=False,
+ strip_default_attrs=False):
+ # pylint: disable=line-too-long
"""Construct and returns a `MetaGraphDef` protocol buffer.
Args:
@@ -464,12 +528,17 @@ def create_meta_graph_def(meta_info_def=None,
clear_extraneous_savers: Remove any preexisting SaverDefs from the SAVERS
collection. Note this method does not alter the graph, so any
extraneous Save/Restore ops should have been removed already, as needed.
+ strip_default_attrs: Boolean. If `True`, default-valued attributes will be
+ removed from the NodeDefs. For a detailed guide, see
+ [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
+
Returns:
MetaGraphDef protocol buffer.
Raises:
TypeError: If the arguments are not of the correct proto buffer type.
"""
+ # pylint: enable=line-too-long
# Type check.
if graph and not isinstance(graph, ops.Graph):
raise TypeError("graph must be of type Graph, not %s", type(graph))
@@ -511,6 +580,10 @@ def create_meta_graph_def(meta_info_def=None,
stripped_op_list_for_graph(meta_graph_def.graph_def))
# pylint: enable=g-explicit-length-test
+ # Strip default valued attributes in graph_def.
+ if strip_default_attrs:
+ _strip_graph_default_valued_attrs(meta_graph_def)
+
# Adds saver_def.
if saver_def:
meta_graph_def.saver_def.MergeFrom(saver_def)
@@ -724,6 +797,7 @@ def export_scoped_meta_graph(filename=None,
clear_devices=False,
saver_def=None,
clear_extraneous_savers=False,
+ strip_default_attrs=False,
**kwargs):
"""Returns `MetaGraphDef` proto. Optionally writes it to filename.
@@ -752,6 +826,8 @@ def export_scoped_meta_graph(filename=None,
clear_extraneous_savers: Remove any Saver-related information from the
graph (both Save/Restore ops and SaverDefs) that are not associated
with the provided SaverDef.
+ strip_default_attrs: Set to true if default valued attributes must be
+ removed while exporting the GraphDef.
**kwargs: Optional keyed arguments, including meta_info_def and
collection_list.
@@ -837,6 +913,7 @@ def export_scoped_meta_graph(filename=None,
exclude_nodes=exclude_nodes,
clear_extraneous_savers=clear_extraneous_savers,
saver_def=saver_def,
+ strip_default_attrs=strip_default_attrs,
**kwargs)
if filename:
@@ -881,3 +958,5 @@ def copy_scoped_meta_graph(from_scope, to_scope,
graph=to_graph,
import_scope=to_scope)
return var_list
+
+
diff --git a/tensorflow/python/framework/meta_graph_test.py b/tensorflow/python/framework/meta_graph_test.py
index 4c22c913b8..ae8c9ea2a4 100644
--- a/tensorflow/python/framework/meta_graph_test.py
+++ b/tensorflow/python/framework/meta_graph_test.py
@@ -24,6 +24,7 @@ import random
import shutil
from tensorflow.core.framework import graph_pb2
+from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
@@ -154,6 +155,108 @@ class SimpleMetaGraphTest(test.TestCase):
op_list = meta_graph.stripped_op_list_for_graph(graph)
self.assertEqual(["Const"], [op.name for op in op_list.op])
+ def testDefaultAttrStripping(self):
+ """Verifies that default attributes are stripped from a graph def."""
+
+ # Complex Op has 2 attributes with defaults:
+ # o "T" : float32.
+ # o "Tout" : complex64.
+
+ # When inputs to the Complex Op are float32 instances, "T" maps to float32
+ # and "Tout" maps to complex64. Since these attr values map to their
+ # defaults, they must be stripped unless stripping of default attrs is
+ # disabled.
+ with self.test_session():
+ real_num = constant_op.constant(1.0, dtype=dtypes.float32, name="real")
+ imag_num = constant_op.constant(2.0, dtype=dtypes.float32, name="imag")
+ math_ops.complex(real_num, imag_num, name="complex")
+
+ # strip_default_attrs is enabled.
+ meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
+ graph_def=ops.get_default_graph().as_graph_def(),
+ strip_default_attrs=True)
+ node_def = test_util.get_node_def_from_graph("complex",
+ meta_graph_def.graph_def)
+ self.assertNotIn("T", node_def.attr)
+ self.assertNotIn("Tout", node_def.attr)
+ self.assertTrue(meta_graph_def.meta_info_def.stripped_default_attrs)
+
+ # strip_default_attrs is disabled.
+ meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
+ graph_def=ops.get_default_graph().as_graph_def(),
+ strip_default_attrs=False)
+ node_def = test_util.get_node_def_from_graph("complex",
+ meta_graph_def.graph_def)
+ self.assertIn("T", node_def.attr)
+ self.assertIn("Tout", node_def.attr)
+ self.assertFalse(meta_graph_def.meta_info_def.stripped_default_attrs)
+
+ # When inputs to the Complex Op are float64 instances, "T" maps to float64
+ # and "Tout" maps to complex128. Since these attr values don't map to their
+ # defaults, they must not be stripped.
+ with self.test_session(graph=ops.Graph()):
+ real_num = constant_op.constant(1.0, dtype=dtypes.float64, name="real")
+ imag_num = constant_op.constant(2.0, dtype=dtypes.float64, name="imag")
+ math_ops.complex(real_num, imag_num, name="complex")
+ meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
+ graph_def=ops.get_default_graph().as_graph_def(),
+ strip_default_attrs=True)
+ node_def = test_util.get_node_def_from_graph("complex",
+ meta_graph_def.graph_def)
+ self.assertEqual(node_def.attr["T"].type, dtypes.float64)
+ self.assertEqual(node_def.attr["Tout"].type, dtypes.complex128)
+ self.assertTrue(meta_graph_def.meta_info_def.stripped_default_attrs)
+
+ def testDefaultAttrStrippingNestedFunctions(self):
+ """Verifies that default attributes are stripped from function node defs."""
+ with self.test_session():
+ @function.Defun(dtypes.float32, dtypes.float32)
+ def f0(i, j):
+ return math_ops.complex(i, j, name="double_nested_complex")
+
+ @function.Defun(dtypes.float32, dtypes.float32)
+ def f1(i, j):
+ return f0(i, j)
+
+ _ = f1(constant_op.constant(1.0), constant_op.constant(2.0))
+ meta_graph_def, _ = meta_graph.export_scoped_meta_graph(
+ graph_def=ops.get_default_graph().as_graph_def(),
+ strip_default_attrs=True)
+
+ double_nested_complex_node_def = None
+ for function_def in meta_graph_def.graph_def.library.function:
+ for node_def in function_def.node_def:
+ if node_def.name == "double_nested_complex":
+ double_nested_complex_node_def = node_def
+ break
+ if double_nested_complex_node_def:
+ break
+
+ self.assertIsNotNone(double_nested_complex_node_def)
+ self.assertNotIn("T", double_nested_complex_node_def.attr)
+ self.assertNotIn("Tout", double_nested_complex_node_def.attr)
+ self.assertTrue(meta_graph_def.meta_info_def.stripped_default_attrs)
+
+ def testDefaultAttrStrippingUnregisteredOps(self):
+ """Verifies that nodes with un-registered ops are not stripped."""
+ graph_def = graph_pb2.GraphDef()
+ node = graph_def.node.add()
+ node.name = "node_with_unreg_op"
+ node.op = "unreg_op"
+ node.attr["attr_1"].i = 1
+
+ meta_info_def = meta_graph_pb2.MetaGraphDef.MetaInfoDef()
+ meta_info_def.stripped_op_list.op.add()
+
+ with self.test_session():
+ meta_graph_def = meta_graph.create_meta_graph_def(
+ meta_info_def=meta_info_def, graph_def=graph_def,
+ strip_default_attrs=True)
+ node_def = test_util.get_node_def_from_graph("node_with_unreg_op",
+ meta_graph_def.graph_def)
+ self.assertEqual(node_def.attr["attr_1"].i, 1)
+ self.assertTrue(meta_graph_def.meta_info_def.stripped_default_attrs)
+
class ScopedMetaGraphTest(test.TestCase):
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 7627fb3e69..5ac3053749 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -1369,3 +1369,21 @@ def create_local_cluster(num_workers, num_ps, protocol="grpc",
]
return workers, ps_servers
+
+
+def get_node_def_from_graph(node_name, graph_def):
+ """Returns the `NodeDef` instance for given node name in the graph def.
+
+ This method explores only the NodeDefs in `graph_def.node`.
+
+ Args:
+ node_name: Name of the NodeDef to search for.
+ graph_def: An instance of `GraphDef` proto.
+
+ Returns:
+ the `NodeDef` instance whose name field matches the given node_name or None.
+ """
+ for node_def in graph_def.node:
+ if node_def.name == node_name:
+ return node_def
+ return None
diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py
index f6aed118ca..4af717cca6 100644
--- a/tensorflow/python/framework/test_util_test.py
+++ b/tensorflow/python/framework/test_util_test.py
@@ -349,6 +349,13 @@ class TestUtilTest(test_util.TensorFlowTestCase):
self.assertEqual(expected, self.evaluate(nested))
+ def test_get_node_def_from_graph(self):
+ graph_def = graph_pb2.GraphDef()
+ node_foo = graph_def.node.add()
+ node_foo.name = "foo"
+ self.assertIs(test_util.get_node_def_from_graph("foo", graph_def), node_foo)
+ self.assertIsNone(test_util.get_node_def_from_graph("bar", graph_def))
+
class GarbageCollectionTest(test_util.TensorFlowTestCase):
diff --git a/tensorflow/python/saved_model/README.md b/tensorflow/python/saved_model/README.md
index 8c78013ffd..5eeaf73a43 100644
--- a/tensorflow/python/saved_model/README.md
+++ b/tensorflow/python/saved_model/README.md
@@ -117,6 +117,35 @@ with tf.Session(graph=tf.Graph()) as sess:
builder.save()
~~~
+#### Stripping Default valued attributes
+The SavedModelBuilder class allows users to control whether default-valued
+attributes must be stripped from the NodeDefs while adding a meta graph to the
+SavedModel bundle. Both `SavedModelBuilder.add_meta_graph_and_variables` and
+`SavedModelBuilder.add_meta_graph` methods accept a Boolean flag
+`strip_default_attrs` that controls this behavior.
+
+If `strip_default_attrs` is `False`, the exported MetaGraphDef will have the
+default valued attributes in all it's NodeDef instances. This can break forward
+compatibility with a sequence of events such as the following:
+
+* An existing Op (`Foo`) is updated to include a new attribute (`T`) with a
+ default (`bool`) at version 101.
+* A model producer (such as a Trainer) binary picks up this change
+ (version 101) to the OpDef and re-exports an existing model that uses Op `Foo`.
+* A model consumer (such as Tensorflow Serving) running an older binary
+ (version 100) doesn't have attribute `T` for Op `Foo`, but tries to import
+ this model. The model consumer doesn't recognize attribute `T` in a NodeDef
+ that uses Op `Foo` and therefore fails to load the model.
+
+By setting `strip_default_attrs` to `True`, the model producers can strip away
+any default valued attributes in the NodeDefs. This helps ensure that newly
+added attributes with defaults don't cause older model consumers to fail loading
+models regenerated with newer training binaries.
+
+TIP: If you care about forward compatibility, then set `strip_default_attrs`
+to `True` while using `SavedModelBuilder.add_meta_graph_and_variables` and
+`SavedModelBuilder.add_meta_graph`.
+
### Loader
The SavedModel loader is implemented in C++ and Python.
diff --git a/tensorflow/python/saved_model/builder_impl.py b/tensorflow/python/saved_model/builder_impl.py
index 16651ffebc..62ee53b816 100644
--- a/tensorflow/python/saved_model/builder_impl.py
+++ b/tensorflow/python/saved_model/builder_impl.py
@@ -239,7 +239,9 @@ class SavedModelBuilder(object):
assets_collection=None,
legacy_init_op=None,
clear_devices=False,
- main_op=None):
+ main_op=None,
+ strip_default_attrs=False):
+ # pylint: disable=line-too-long
"""Adds the current meta graph to the SavedModel.
Creates a Saver in the current scope and uses the Saver to export the meta
@@ -260,11 +262,15 @@ class SavedModelBuilder(object):
main_op: Op or group of ops to execute when the graph is loaded. Note
that when the main_op is specified it is run after the restore op at
load-time.
+ strip_default_attrs: Boolean. If `True`, default-valued attributes will be
+ removed from the NodeDefs. For a detailed guide, see
+ [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
Raises:
AssertionError: If the variables for the SavedModel have not been saved
yet, or if the graph already contains one or more legacy init ops.
"""
+ # pylint: enable=line-too-long
if not self._has_saved_variables:
raise AssertionError(
"Graph state including variables and assets has not been saved yet. "
@@ -299,7 +305,8 @@ class SavedModelBuilder(object):
# there are edge cases where that option breaks the graph. Until that is
# resolved, we just leave the option set to False for now.
# TODO(soergel): Reinstate clear_extraneous_savers=True when possible.
- meta_graph_def = saver.export_meta_graph(clear_devices=clear_devices)
+ meta_graph_def = saver.export_meta_graph(
+ clear_devices=clear_devices, strip_default_attrs=strip_default_attrs)
# Tag the meta graph def and add it to the SavedModel.
self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)
@@ -311,7 +318,9 @@ class SavedModelBuilder(object):
assets_collection=None,
legacy_init_op=None,
clear_devices=False,
- main_op=None):
+ main_op=None,
+ strip_default_attrs=False):
+ # pylint: disable=line-too-long
"""Adds the current meta graph to the SavedModel and saves variables.
Creates a Saver to save the variables from the provided session. Exports the
@@ -334,7 +343,11 @@ class SavedModelBuilder(object):
main_op: Op or group of ops to execute when the graph is loaded. Note
that when the main_op is specified it is run after the restore op at
load-time.
+ strip_default_attrs: Boolean. If `True`, default-valued attributes will be
+ removed from the NodeDefs. For a detailed guide, see
+ [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
"""
+ # pylint: enable=line-too-long
if self._has_saved_variables:
raise AssertionError("Graph state including variables and assets has "
"already been saved. Please invoke "
@@ -388,7 +401,8 @@ class SavedModelBuilder(object):
# there are edge cases where that option breaks the graph. Until that is
# resolved, we just leave the option set to False for now.
# TODO(soergel): Reinstate clear_extraneous_savers=True when possible.
- meta_graph_def = saver.export_meta_graph(clear_devices=clear_devices)
+ meta_graph_def = saver.export_meta_graph(
+ clear_devices=clear_devices, strip_default_attrs=strip_default_attrs)
# Tag the meta graph def and add it to the SavedModel.
self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)
diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py
index 92ca7dec6f..1ea619ff55 100644
--- a/tensorflow/python/saved_model/saved_model_test.py
+++ b/tensorflow/python/saved_model/saved_model_test.py
@@ -20,13 +20,17 @@ from __future__ import print_function
import os
+from tensorflow.core.framework import op_def_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
+from tensorflow.python.framework import op_def_registry
from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
from tensorflow.python.lib.io import file_io
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
@@ -36,6 +40,7 @@ from tensorflow.python.platform import test
from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import loader
+from tensorflow.python.saved_model import loader_impl
from tensorflow.python.saved_model import main_op
from tensorflow.python.saved_model import signature_def_utils
from tensorflow.python.saved_model import tag_constants
@@ -865,6 +870,132 @@ class SavedModelTest(test.TestCase):
self.assertEqual(
42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
+ def testStripDefaultAttrs(self):
+ export_dir = os.path.join(test.get_temp_dir(), "test_strip_default_attrs")
+ builder = saved_model_builder.SavedModelBuilder(export_dir)
+
+ # Add a graph with two float32 variables and a Complex Op composing them
+ # with strip_default_attrs enabled.
+ with session.Session(graph=ops.Graph()) as sess:
+ real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real")
+ imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag")
+ math_ops.complex(real_num, imag_num, name="complex")
+ sess.run(variables.global_variables_initializer())
+ builder.add_meta_graph_and_variables(
+ sess, ["foo"], strip_default_attrs=True)
+
+ # Add a graph with the same float32 variables and a Complex Op composing
+ # them with strip_default_attrs disabled.
+ with session.Session(graph=ops.Graph()) as sess:
+ real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real")
+ imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag")
+ math_ops.complex(real_num, imag_num, name="complex")
+ sess.run(variables.global_variables_initializer())
+ builder.add_meta_graph(["bar"], strip_default_attrs=False)
+
+ # Save the SavedModel to disk in text format.
+ builder.save(as_text=True)
+
+ # Loading graph "foo" via the loader must restore the defaults for the
+ # "Complex" node based on the "Complex" OpDef in the Op registry.
+ sess = session.Session(graph=ops.Graph())
+ meta_graph_def = loader.load(sess, ["foo"], export_dir)
+ complex_node = test_util.get_node_def_from_graph("complex",
+ meta_graph_def.graph_def)
+ self.assertIn("T", complex_node.attr)
+ self.assertIn("Tout", complex_node.attr)
+
+ # Load graph "foo" from disk as-is to verify default attrs are stripped.
+ # pylint: disable=protected-access
+ saved_model_pb = loader_impl._parse_saved_model(export_dir)
+ self.assertIsNotNone(saved_model_pb)
+ # pylint: enable=protected-access
+
+ meta_graph_foo_def = None
+ meta_graph_bar_def = None
+ for meta_graph_def in saved_model_pb.meta_graphs:
+ if set(meta_graph_def.meta_info_def.tags) == set(["foo"]):
+ meta_graph_foo_def = meta_graph_def
+ elif set(meta_graph_def.meta_info_def.tags) == set(["bar"]):
+ meta_graph_bar_def = meta_graph_def
+
+ self.assertIsNotNone(meta_graph_foo_def)
+ self.assertIsNotNone(meta_graph_bar_def)
+
+ # "Complex" Op has 2 attributes with defaults:
+ # o "T" : float32. (input type)
+ # o "Tout" : complex64. (output type)
+
+ # "Complex" Op in graph "foo" shouldn't have attributes "T" and "Tout".
+ # Graph "foo" was saved with strip_default_attrs set to True.
+ node_def = test_util.get_node_def_from_graph("complex",
+ meta_graph_foo_def.graph_def)
+ self.assertNotIn("T", node_def.attr)
+ self.assertNotIn("Tout", node_def.attr)
+
+ # "Complex" Op in graph "bar" must have attributes "T" and "Tout".
+ # Graph "bar" was saved with strip_default_attrs set to False.
+ node_def = test_util.get_node_def_from_graph("complex",
+ meta_graph_bar_def.graph_def)
+ self.assertIn("T", node_def.attr)
+ self.assertIn("Tout", node_def.attr)
+
+ def testStripDefaultAttrsInconsistentConsumerDefaults(self):
+ export_dir = os.path.join(test.get_temp_dir(),
+ "test_strip_default_attrs_no_consumer_defaults")
+ builder = saved_model_builder.SavedModelBuilder(export_dir)
+
+ # Add a graph with two float32 variables and a Complex Op composing them
+ # with strip_default_attrs enabled. This must remove the following
+ # defaults for the "Complex" Op:
+ # o "T" : float32. (input type)
+ # o "Tout" : complex64. (output type)
+ with session.Session(graph=ops.Graph()) as sess:
+ real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real")
+ imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag")
+ math_ops.complex(real_num, imag_num, name="complex")
+ sess.run(variables.global_variables_initializer())
+ builder.add_meta_graph_and_variables(
+ sess, ["foo"], strip_default_attrs=True)
+
+ # Save the SavedModel to disk in text format.
+ builder.save(as_text=True)
+
+ # Update the Op registry to remove defaults for all attrs("T", "Tout") from
+ # the "Complex" OpDef.
+ complex_op_def = op_def_registry.get_registered_ops()["Complex"]
+ original_complex_op_def = op_def_pb2.OpDef()
+ original_complex_op_def.CopyFrom(complex_op_def)
+ for attr_def in complex_op_def.attr:
+ attr_def.ClearField("default_value")
+
+ # Loading the SavedModel via the loader must fail because the SavedModel
+ # does not have any attr values for the "Complex" node and the current
+ # op registry does not have have any default values for the "Complex" op.
+ sess = session.Session(graph=ops.Graph())
+ with self.assertRaisesRegexp(
+ ValueError,
+ "Expected one attr with name .*T(out)?.* in name: \"complex\".*"):
+ loader.load(sess, ["foo"], export_dir)
+
+ # Update the Op registry to change the defaults for attr "Tout"
+ # (complex64 -> complex128).
+ complex_op_def.CopyFrom(original_complex_op_def)
+ for attr_def in complex_op_def.attr:
+ if attr_def.name == "Tout":
+ attr_def.default_value.type = types_pb2.DT_COMPLEX128
+
+ # Loading the SavedModel via the loader must set "Tout" attr_value for the
+ # "Complex" node according to the latest defaults (complex128). This is
+ # expected to fail the model import as there is no OpKernel registered to
+ # handle attrs "T" (float32) and "Tout" (complex128).
+ sess = session.Session(graph=ops.Graph())
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ ".*No OpKernel was registered to support Op \'Complex\' with these "
+ "attrs..*"):
+ loader.load(sess, ["foo"], export_dir)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index ba6301e785..2330229d56 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -1509,7 +1509,9 @@ class Saver(object):
latest_filename=None,
meta_graph_suffix="meta",
write_meta_graph=True,
- write_state=True):
+ write_state=True,
+ strip_default_attrs=False):
+ # pylint: disable=line-too-long
"""Saves variables.
This method runs the ops added by the constructor for saving variables.
@@ -1535,6 +1537,9 @@ class Saver(object):
graph file.
write_state: `Boolean` indicating whether or not to write the
`CheckpointStateProto`.
+ strip_default_attrs: Boolean. If `True`, default-valued attributes will be
+ removed from the NodeDefs. For a detailed guide, see
+ [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
Returns:
A string: path prefix used for the checkpoint files. If the saver is
@@ -1548,6 +1553,7 @@ class Saver(object):
collides with `save_path`.
RuntimeError: If save and restore ops weren't built.
"""
+ # pylint: enable=line-too-long
if not self._is_built and context.in_graph_mode():
raise RuntimeError(
"`build()` should be called before save if defer_build==True")
@@ -1618,7 +1624,8 @@ class Saver(object):
checkpoint_file, meta_graph_suffix=meta_graph_suffix)
if context.in_graph_mode():
with sess.graph.as_default():
- self.export_meta_graph(meta_graph_filename)
+ self.export_meta_graph(
+ meta_graph_filename, strip_default_attrs=strip_default_attrs)
if self._is_empty:
return None
@@ -1631,7 +1638,9 @@ class Saver(object):
as_text=False,
export_scope=None,
clear_devices=False,
- clear_extraneous_savers=False):
+ clear_extraneous_savers=False,
+ strip_default_attrs=False):
+ # pylint: disable=line-too-long
"""Writes `MetaGraphDef` to save_path/filename.
Args:
@@ -1644,10 +1653,14 @@ class Saver(object):
clear_extraneous_savers: Remove any Saver-related information from the
graph (both Save/Restore ops and SaverDefs) that are not associated
with this Saver.
+ strip_default_attrs: Boolean. If `True`, default-valued attributes will be
+ removed from the NodeDefs. For a detailed guide, see
+ [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
Returns:
A `MetaGraphDef` proto.
"""
+ # pylint: enable=line-too-long
return export_meta_graph(
filename=filename,
graph_def=ops.get_default_graph().as_graph_def(add_shapes=True),
@@ -1656,7 +1669,8 @@ class Saver(object):
as_text=as_text,
export_scope=export_scope,
clear_devices=clear_devices,
- clear_extraneous_savers=clear_extraneous_savers)
+ clear_extraneous_savers=clear_extraneous_savers,
+ strip_default_attrs=strip_default_attrs)
def restore(self, sess, save_path):
"""Restores previously saved variables.
@@ -1859,7 +1873,9 @@ def export_meta_graph(filename=None,
export_scope=None,
clear_devices=False,
clear_extraneous_savers=False,
+ strip_default_attrs=False,
**kwargs):
+ # pylint: disable=line-too-long
"""Returns `MetaGraphDef` proto. Optionally writes it to filename.
This function exports the graph, saver, and collection objects into
@@ -1885,6 +1901,9 @@ def export_meta_graph(filename=None,
clear_extraneous_savers: Remove any Saver-related information from the
graph (both Save/Restore ops and SaverDefs) that are not associated
with the provided SaverDef.
+ strip_default_attrs: Boolean. If `True`, default-valued attributes will be
+ removed from the NodeDefs. For a detailed guide, see
+ [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
**kwargs: Optional keyed arguments.
Returns:
@@ -1899,6 +1918,7 @@ def export_meta_graph(filename=None,
execution is enabled.
@end_compatibility
"""
+ # pylint: enable=line-too-long
if context.in_eager_mode():
raise RuntimeError("Exporting/importing meta graphs is not supported when "
"eager execution is enabled. No graph exists when eager "
@@ -1914,6 +1934,7 @@ def export_meta_graph(filename=None,
export_scope=export_scope,
clear_devices=clear_devices,
clear_extraneous_savers=clear_extraneous_savers,
+ strip_default_attrs=strip_default_attrs,
**kwargs)
return meta_graph_def
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index 207e4a2842..0889ac2516 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -2065,6 +2065,42 @@ class MetaGraphTest(test.TestCase):
self.assertEqual(o.summary, "")
self.assertEqual(o.description, "")
+ def testStripDefaultValuedAttrs(self):
+ """Verifies that default valued attrs are stripped, unless disabled."""
+
+ # With strip_default_attrs enabled, attributes "T" (float32) and "Tout"
+ # (complex64) in the "Complex" op must be removed.
+ with self.test_session():
+ real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real")
+ imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag")
+ math_ops.complex(real_num, imag_num, name="complex")
+
+ save = saver_module.Saver({"real_num": real_num, "imag_num": imag_num})
+ variables.global_variables_initializer()
+
+ meta_graph_def = save.export_meta_graph(strip_default_attrs=True)
+ node_def = test_util.get_node_def_from_graph("complex",
+ meta_graph_def.graph_def)
+ self.assertNotIn("T", node_def.attr)
+ self.assertNotIn("Tout", node_def.attr)
+
+ # With strip_default_attrs disabled, attributes "T" (float32) and "Tout"
+ # (complex64) in the "Complex" op must *not* be removed, even if they map
+ # to their defaults.
+ with self.test_session(graph=ops_lib.Graph()):
+ real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real")
+ imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag")
+ math_ops.complex(real_num, imag_num, name="complex")
+
+ save = saver_module.Saver({"real_num": real_num, "imag_num": imag_num})
+ variables.global_variables_initializer()
+
+ meta_graph_def = save.export_meta_graph(strip_default_attrs=False)
+ node_def = test_util.get_node_def_from_graph("complex",
+ meta_graph_def.graph_def)
+ self.assertIn("T", node_def.attr)
+ self.assertIn("Tout", node_def.attr)
+
def testImportIntoNamescope(self):
# Test that we can import a meta graph into a namescope.
test_dir = self._get_test_dir("import_into_namescope")
diff --git a/tensorflow/tools/api/golden/tensorflow.-meta-graph-def.-meta-info-def.pbtxt b/tensorflow/tools/api/golden/tensorflow.-meta-graph-def.-meta-info-def.pbtxt
index ebf49f434a..b0e9831154 100644
--- a/tensorflow/tools/api/golden/tensorflow.-meta-graph-def.-meta-info-def.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.-meta-graph-def.-meta-info-def.pbtxt
@@ -19,6 +19,10 @@ tf_class {
mtype: "<type \'int\'>"
}
member {
+ name: "STRIPPED_DEFAULT_ATTRS_FIELD_NUMBER"
+ mtype: "<type \'int\'>"
+ }
+ member {
name: "STRIPPED_OP_LIST_FIELD_NUMBER"
mtype: "<type \'int\'>"
}
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-classifier.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-classifier.pbtxt
index f5ed263f0e..ab697b1b95 100644
--- a/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-classifier.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-classifier.pbtxt
@@ -29,7 +29,7 @@ tf_class {
}
member_method {
name: "export_savedmodel"
- argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], "
}
member_method {
name: "get_variable_names"
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-regressor.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-regressor.pbtxt
index 61a29942c5..b73f6433e2 100644
--- a/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-regressor.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.estimator.-baseline-regressor.pbtxt
@@ -29,7 +29,7 @@ tf_class {
}
member_method {
name: "export_savedmodel"
- argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], "
}
member_method {
name: "get_variable_names"
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-classifier.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-classifier.pbtxt
index 16e3b24615..24db86c92b 100644
--- a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-classifier.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-classifier.pbtxt
@@ -29,7 +29,7 @@ tf_class {
}
member_method {
name: "export_savedmodel"
- argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], "
}
member_method {
name: "get_variable_names"
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt
index c6765ae277..47ee2ac51b 100644
--- a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-classifier.pbtxt
@@ -29,7 +29,7 @@ tf_class {
}
member_method {
name: "export_savedmodel"
- argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], "
}
member_method {
name: "get_variable_names"
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt
index e3a820db46..fbfaa69a1b 100644
--- a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-linear-combined-regressor.pbtxt
@@ -29,7 +29,7 @@ tf_class {
}
member_method {
name: "export_savedmodel"
- argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], "
}
member_method {
name: "get_variable_names"
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-regressor.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-regressor.pbtxt
index a4c8cf6671..faf55cda86 100644
--- a/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-regressor.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.estimator.-d-n-n-regressor.pbtxt
@@ -29,7 +29,7 @@ tf_class {
}
member_method {
name: "export_savedmodel"
- argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], "
}
member_method {
name: "get_variable_names"
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-estimator.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-estimator.pbtxt
index 787952eced..d0bf043754 100644
--- a/tensorflow/tools/api/golden/tensorflow.estimator.-estimator.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.estimator.-estimator.pbtxt
@@ -28,7 +28,7 @@ tf_class {
}
member_method {
name: "export_savedmodel"
- argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], "
}
member_method {
name: "get_variable_names"
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-linear-classifier.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-linear-classifier.pbtxt
index 99c03aa629..6aec1d3a51 100644
--- a/tensorflow/tools/api/golden/tensorflow.estimator.-linear-classifier.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.estimator.-linear-classifier.pbtxt
@@ -29,7 +29,7 @@ tf_class {
}
member_method {
name: "export_savedmodel"
- argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], "
}
member_method {
name: "get_variable_names"
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-linear-regressor.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-linear-regressor.pbtxt
index e2ab96d5b4..9d8c7bb138 100644
--- a/tensorflow/tools/api/golden/tensorflow.estimator.-linear-regressor.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.estimator.-linear-regressor.pbtxt
@@ -29,7 +29,7 @@ tf_class {
}
member_method {
name: "export_savedmodel"
- argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], "
}
member_method {
name: "get_variable_names"
diff --git a/tensorflow/tools/api/golden/tensorflow.saved_model.builder.-saved-model-builder.pbtxt b/tensorflow/tools/api/golden/tensorflow.saved_model.builder.-saved-model-builder.pbtxt
index 56d76902fd..ca8e5884b1 100644
--- a/tensorflow/tools/api/golden/tensorflow.saved_model.builder.-saved-model-builder.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.saved_model.builder.-saved-model-builder.pbtxt
@@ -8,11 +8,11 @@ tf_class {
}
member_method {
name: "add_meta_graph"
- argspec: "args=[\'self\', \'tags\', \'signature_def_map\', \'assets_collection\', \'legacy_init_op\', \'clear_devices\', \'main_op\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], "
+ argspec: "args=[\'self\', \'tags\', \'signature_def_map\', \'assets_collection\', \'legacy_init_op\', \'clear_devices\', \'main_op\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\', \'False\'], "
}
member_method {
name: "add_meta_graph_and_variables"
- argspec: "args=[\'self\', \'sess\', \'tags\', \'signature_def_map\', \'assets_collection\', \'legacy_init_op\', \'clear_devices\', \'main_op\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\'], "
+ argspec: "args=[\'self\', \'sess\', \'tags\', \'signature_def_map\', \'assets_collection\', \'legacy_init_op\', \'clear_devices\', \'main_op\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\', \'False\'], "
}
member_method {
name: "save"
diff --git a/tensorflow/tools/api/golden/tensorflow.train.-saver.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.-saver.pbtxt
index 04c11712cd..2cda458f46 100644
--- a/tensorflow/tools/api/golden/tensorflow.train.-saver.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.train.-saver.pbtxt
@@ -20,7 +20,7 @@ tf_class {
}
member_method {
name: "export_meta_graph"
- argspec: "args=[\'self\', \'filename\', \'collection_list\', \'as_text\', \'export_scope\', \'clear_devices\', \'clear_extraneous_savers\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'None\', \'False\', \'False\'], "
+ argspec: "args=[\'self\', \'filename\', \'collection_list\', \'as_text\', \'export_scope\', \'clear_devices\', \'clear_extraneous_savers\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'None\', \'False\', \'False\', \'False\'], "
}
member_method {
name: "from_proto"
@@ -36,7 +36,7 @@ tf_class {
}
member_method {
name: "save"
- argspec: "args=[\'self\', \'sess\', \'save_path\', \'global_step\', \'latest_filename\', \'meta_graph_suffix\', \'write_meta_graph\', \'write_state\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'meta\', \'True\', \'True\'], "
+ argspec: "args=[\'self\', \'sess\', \'save_path\', \'global_step\', \'latest_filename\', \'meta_graph_suffix\', \'write_meta_graph\', \'write_state\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'meta\', \'True\', \'True\', \'False\'], "
}
member_method {
name: "set_last_checkpoints"
diff --git a/tensorflow/tools/api/golden/tensorflow.train.pbtxt b/tensorflow/tools/api/golden/tensorflow.train.pbtxt
index 3ffc640730..b2ef17b39e 100644
--- a/tensorflow/tools/api/golden/tensorflow.train.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.train.pbtxt
@@ -282,7 +282,7 @@ tf_module {
}
member_method {
name: "export_meta_graph"
- argspec: "args=[\'filename\', \'meta_info_def\', \'graph_def\', \'saver_def\', \'collection_list\', \'as_text\', \'graph\', \'export_scope\', \'clear_devices\', \'clear_extraneous_savers\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\', \'False\', \'False\'], "
+ argspec: "args=[\'filename\', \'meta_info_def\', \'graph_def\', \'saver_def\', \'collection_list\', \'as_text\', \'graph\', \'export_scope\', \'clear_devices\', \'clear_extraneous_savers\', \'strip_default_attrs\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'False\', \'None\', \'None\', \'False\', \'False\', \'False\'], "
}
member_method {
name: "generate_checkpoint_state_proto"