aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow
diff options
context:
space:
mode:
authorGravatar David Soergel <soergel@google.com>2017-07-28 09:00:28 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-07-28 09:04:50 -0700
commit34c49b133dab0becf25da6ce7199131e6f726a04 (patch)
tree7e4b8fb38739c5adb012c29db5bee2bfca83e882 /tensorflow
parent123f84b8e0f4df5ecd97b005876f4b1cd1594ce9 (diff)
Provide MetaGraphDef transformer, and call it from SavedModel export.
A MetaGraphDef transformation consists of a GraphDef transformation provided by the Graph Transform Tool, followed by some surgery at the MetaGraphDef level to remove references to any nodes that were removed. This allows users to request Graph Transform Tool rewrites integrated with Estimator.export_savedmodel(). It also integrates graph freezing interleaved with those rewrites (even though that is not provided by GTT). A limitation, for now, is that all Variables and many of their associated Save and Restore Ops are retained even if they are unused and strip_unused_nodes is requested (pending further clarity on which ones may be safe to remove). PiperOrigin-RevId: 163475341
Diffstat (limited to 'tensorflow')
-rw-r--r--tensorflow/BUILD1
-rwxr-xr-xtensorflow/contrib/BUILD1
-rwxr-xr-xtensorflow/contrib/cmake/tf_python.cmake3
-rw-r--r--tensorflow/contrib/cmake/tf_tests.cmake1
-rw-r--r--tensorflow/contrib/learn/BUILD1
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/__init__.py1
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/estimator.py132
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/estimator_test.py124
-rw-r--r--tensorflow/contrib/meta_graph_transform/BUILD59
-rw-r--r--tensorflow/contrib/meta_graph_transform/__init__.py28
-rw-r--r--tensorflow/contrib/meta_graph_transform/meta_graph_transform.py453
-rw-r--r--tensorflow/contrib/meta_graph_transform/meta_graph_transform_test.py202
12 files changed, 968 insertions, 38 deletions
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index c61775e146..b1ecb0663f 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -280,6 +280,7 @@ filegroup(
"//tensorflow/contrib/linear_optimizer:all_files",
"//tensorflow/contrib/lookup:all_files",
"//tensorflow/contrib/losses:all_files",
+ "//tensorflow/contrib/meta_graph_transform:all_files",
"//tensorflow/contrib/metrics:all_files",
"//tensorflow/contrib/ndlstm:all_files",
"//tensorflow/contrib/nn:all_files",
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index 5cd4753664..8a77f1c04a 100755
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -47,6 +47,7 @@ py_library(
"//tensorflow/contrib/lookup:lookup_py",
"//tensorflow/contrib/losses:losses_py",
"//tensorflow/contrib/memory_stats:memory_stats_py",
+ "//tensorflow/contrib/meta_graph_transform",
"//tensorflow/contrib/metrics:metrics_py",
"//tensorflow/contrib/nccl:nccl_py",
"//tensorflow/contrib/ndlstm",
diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake
index f32e7adaec..9da3b5d77b 100755
--- a/tensorflow/contrib/cmake/tf_python.cmake
+++ b/tensorflow/contrib/cmake/tf_python.cmake
@@ -240,6 +240,8 @@ add_python_module("tensorflow/python/training")
add_python_module("tensorflow/python/user_ops")
add_python_module("tensorflow/python/util")
add_python_module("tensorflow/python/util/protobuf")
+add_python_module("tensorflow/tools")
+add_python_module("tensorflow/tools/graph_transforms")
add_python_module("tensorflow/contrib")
add_python_module("tensorflow/contrib/android")
add_python_module("tensorflow/contrib/android/java")
@@ -440,6 +442,7 @@ add_python_module("tensorflow/contrib/memory_stats/ops")
add_python_module("tensorflow/contrib/memory_stats/python")
add_python_module("tensorflow/contrib/memory_stats/python/kernel_tests")
add_python_module("tensorflow/contrib/memory_stats/python/ops")
+add_python_module("tensorflow/contrib/meta_graph_transform")
add_python_module("tensorflow/contrib/metrics")
add_python_module("tensorflow/contrib/metrics/kernels")
add_python_module("tensorflow/contrib/metrics/ops")
diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake
index c7d2ac7b56..216c10dc32 100644
--- a/tensorflow/contrib/cmake/tf_tests.cmake
+++ b/tensorflow/contrib/cmake/tf_tests.cmake
@@ -141,6 +141,7 @@ if (tensorflow_BUILD_PYTHON_TESTS)
"${tensorflow_source_dir}/tensorflow/python/debug/lib/*_test.py"
"${tensorflow_source_dir}/tensorflow/python/debug/wrappers/*_test.py"
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/*.py"
+ "${tensorflow_source_dir}/tensorflow/python/meta_graph_transform/*_test.py"
"${tensorflow_source_dir}/tensorflow/python/profiler/*_test.py"
"${tensorflow_source_dir}/tensorflow/python/profiler/internal/*_test.py"
"${tensorflow_source_dir}/tensorflow/python/saved_model/*_test.py"
diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD
index 01d769c80d..8e2ddea497 100644
--- a/tensorflow/contrib/learn/BUILD
+++ b/tensorflow/contrib/learn/BUILD
@@ -31,6 +31,7 @@ py_library(
"//tensorflow/contrib/linear_optimizer:sdca_ops_py",
"//tensorflow/contrib/lookup:lookup_py",
"//tensorflow/contrib/losses:losses_py",
+ "//tensorflow/contrib/meta_graph_transform",
"//tensorflow/contrib/metrics:metrics_py",
"//tensorflow/contrib/rnn:rnn_py",
"//tensorflow/contrib/session_bundle:exporter",
diff --git a/tensorflow/contrib/learn/python/learn/estimators/__init__.py b/tensorflow/contrib/learn/python/learn/estimators/__init__.py
index bba479a00e..4981750c94 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/__init__.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/__init__.py
@@ -303,6 +303,7 @@ from tensorflow.contrib.learn.python.learn.estimators.dnn_linear_combined import
from tensorflow.contrib.learn.python.learn.estimators.dynamic_rnn_estimator import DynamicRnnEstimator
from tensorflow.contrib.learn.python.learn.estimators.estimator import BaseEstimator
from tensorflow.contrib.learn.python.learn.estimators.estimator import Estimator
+from tensorflow.contrib.learn.python.learn.estimators.estimator import GraphRewriteSpec
from tensorflow.contrib.learn.python.learn.estimators.estimator import infer_real_valued_columns_from_input
from tensorflow.contrib.learn.python.learn.estimators.estimator import infer_real_valued_columns_from_input_fn
from tensorflow.contrib.learn.python.learn.estimators.estimator import SKCompat
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator.py b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
index b9201cc805..c184b14654 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator.py
@@ -20,6 +20,7 @@ from __future__ import division
from __future__ import print_function
import abc
+import collections
import copy
import os
import tempfile
@@ -34,7 +35,6 @@ from tensorflow.contrib.framework import deprecated
from tensorflow.contrib.framework import deprecated_args
from tensorflow.contrib.framework import list_variables
from tensorflow.contrib.framework import load_variable
-from tensorflow.contrib.framework.python.ops import variables as contrib_variables
from tensorflow.contrib.learn.python.learn import evaluable
from tensorflow.contrib.learn.python.learn import metric_spec
from tensorflow.contrib.learn.python.learn import monitors as monitor_lib
@@ -49,6 +49,7 @@ from tensorflow.contrib.learn.python.learn.estimators._sklearn import NotFittedE
from tensorflow.contrib.learn.python.learn.learn_io import data_feeder
from tensorflow.contrib.learn.python.learn.utils import export
from tensorflow.contrib.learn.python.learn.utils import saved_model_export_utils
+from tensorflow.contrib.meta_graph_transform import meta_graph_transform
from tensorflow.contrib.training.python.training import evaluation
from tensorflow.core.framework import summary_pb2
from tensorflow.core.protobuf import config_pb2
@@ -69,6 +70,7 @@ from tensorflow.python.training import device_setter
from tensorflow.python.training import monitored_session
from tensorflow.python.training import saver
from tensorflow.python.training import summary_io
+from tensorflow.python.training import training_util
from tensorflow.python.util import compat
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
@@ -346,12 +348,17 @@ def _write_dict_to_summary(output_dir,
value.simple_value = int(dictionary[key])
else:
logging.warn(
- 'Skipping summary for %s, must be a float, np.float32, np.int64, np.int32 or int.',
+ 'Skipping summary for %s, must be a float, np.float32, '
+ 'np.int64, np.int32 or int.',
key)
summary_writer.add_summary(summary_proto, current_global_step)
summary_writer.flush()
+GraphRewriteSpec = collections.namedtuple('GraphRewriteSpec',
+ ['tags', 'transforms'])
+
+
class BaseEstimator(
sklearn.BaseEstimator, evaluable.Evaluable, trainable.Trainable):
"""Abstract BaseEstimator class to train and evaluate TensorFlow models.
@@ -1229,7 +1236,8 @@ class Estimator(BaseEstimator):
default_output_alternative_key=None,
assets_extra=None,
as_text=False,
- checkpoint_path=None):
+ checkpoint_path=None,
+ graph_rewrite_specs=(GraphRewriteSpec((tag_constants.SERVING,), ()),)):
"""Exports inference graph as a SavedModel into given dir.
Args:
@@ -1249,6 +1257,10 @@ class Estimator(BaseEstimator):
as_text: whether to write the SavedModel proto in text format.
checkpoint_path: The checkpoint path to export. If None (the default),
the most recent checkpoint found within the model directory is chosen.
+ graph_rewrite_specs: an iterable of `GraphRewriteSpec`. Each element will
+ produce a separate MetaGraphDef within the exported SavedModel, tagged
+ and rewritten as specified. Defaults to a single entry using the
+ default serving tag ("serve") and no rewriting.
Returns:
The string path to the exported directory.
@@ -1259,8 +1271,20 @@ class Estimator(BaseEstimator):
if serving_input_fn is None:
raise ValueError('serving_input_fn must be defined.')
+ if not checkpoint_path:
+ # Locate the latest checkpoint
+ checkpoint_path = saver.latest_checkpoint(self._model_dir)
+ if not checkpoint_path:
+ raise NotFittedError("Couldn't find trained model at %s."
+ % self._model_dir)
+
+ export_dir = saved_model_export_utils.get_timestamped_export_dir(
+ export_dir_base)
+ builder = saved_model_builder.SavedModelBuilder(export_dir)
+
+ # Build the base graph
with ops.Graph().as_default() as g:
- contrib_variables.create_global_step(g)
+ training_util.create_global_step(g)
# Call the serving_input_fn and collect the input alternatives.
input_ops = serving_input_fn()
@@ -1281,55 +1305,87 @@ class Estimator(BaseEstimator):
saved_model_export_utils.get_output_alternatives(
model_fn_ops, default_output_alternative_key))
+ init_op = control_flow_ops.group(
+ variables.local_variables_initializer(),
+ resources.initialize_resources(resources.shared_resources()),
+ lookup_ops.tables_initializer())
+
# Build the SignatureDefs from all pairs of input and output alternatives
signature_def_map = saved_model_export_utils.build_all_signature_defs(
input_alternatives, output_alternatives,
actual_default_output_alternative_key)
- if not checkpoint_path:
- # Locate the latest checkpoint
- checkpoint_path = saver.latest_checkpoint(self._model_dir)
- if not checkpoint_path:
- raise NotFittedError("Couldn't find trained model at %s."
- % self._model_dir)
+ # Export the first MetaGraphDef with variables, assets etc.
+ with tf_session.Session('') as session:
- export_dir = saved_model_export_utils.get_timestamped_export_dir(
- export_dir_base)
+ # pylint: disable=protected-access
+ saveables = variables._all_saveable_objects()
+ # pylint: enable=protected-access
+
+ if (model_fn_ops.scaffold is not None and
+ model_fn_ops.scaffold.saver is not None):
+ saver_for_restore = model_fn_ops.scaffold.saver
+ elif saveables:
+ saver_for_restore = saver.Saver(saveables, sharded=True)
- if (model_fn_ops.scaffold is not None and
- model_fn_ops.scaffold.saver is not None):
- saver_for_restore = model_fn_ops.scaffold.saver
- else:
- saver_for_restore = saver.Saver(sharded=True)
- with tf_session.Session('') as session:
saver_for_restore.restore(session, checkpoint_path)
- init_op = control_flow_ops.group(
- variables.local_variables_initializer(),
- resources.initialize_resources(resources.shared_resources()),
- lookup_ops.tables_initializer())
# Perform the export
- builder = saved_model_builder.SavedModelBuilder(export_dir)
+ if not graph_rewrite_specs or graph_rewrite_specs[0].transforms:
+ raise ValueError('The first element of graph_rewrite_specs '
+ 'must specify no transforms.')
+ untransformed_tags = graph_rewrite_specs[0].tags
+
+ # TODO(soergel): switch to main_op or otherwise update when dust settles
builder.add_meta_graph_and_variables(
- session, [tag_constants.SERVING],
+ session, untransformed_tags,
signature_def_map=signature_def_map,
assets_collection=ops.get_collection(
ops.GraphKeys.ASSET_FILEPATHS),
legacy_init_op=init_op)
- builder.save(as_text)
-
- # Add the extra assets
- if assets_extra:
- assets_extra_path = os.path.join(compat.as_bytes(export_dir),
- compat.as_bytes('assets.extra'))
- for dest_relative, source in assets_extra.items():
- dest_absolute = os.path.join(compat.as_bytes(assets_extra_path),
- compat.as_bytes(dest_relative))
- dest_path = os.path.dirname(dest_absolute)
- gfile.MakeDirs(dest_path)
- gfile.Copy(source, dest_absolute)
-
- return export_dir
+
+ # pylint: disable=protected-access
+ base_meta_graph_def = builder._saved_model.meta_graphs[0]
+ # pylint: enable=protected-access
+
+ if graph_rewrite_specs[1:]:
+ # Prepare the input_names and output_names needed for the
+ # meta_graph_transform call below.
+ input_names = [tensor.name
+ for input_dict in input_alternatives.values()
+ for tensor in input_dict.values()]
+ output_names = [tensor.name
+ for output_alternative in output_alternatives.values()
+ for tensor in output_alternative[1].values()]
+
+ # Write the additional MetaGraphDefs
+ for graph_rewrite_spec in graph_rewrite_specs[1:]:
+
+ # TODO(soergel) consider moving most of this to saved_model.builder_impl
+ # as e.g. builder.add_rewritten_meta_graph(rewritten_graph_def, tags)
+
+ transformed_meta_graph_def = meta_graph_transform.meta_graph_transform(
+ base_meta_graph_def, input_names, output_names,
+ graph_rewrite_spec.transforms, graph_rewrite_spec.tags)
+
+ # pylint: disable=protected-access
+ meta_graph_def = builder._saved_model.meta_graphs.add()
+ # pylint: enable=protected-access
+ meta_graph_def.CopyFrom(transformed_meta_graph_def)
+
+ # Add the extra assets
+ if assets_extra:
+ assets_extra_path = os.path.join(compat.as_bytes(export_dir),
+ compat.as_bytes('assets.extra'))
+ for dest_relative, source in assets_extra.items():
+ dest_absolute = os.path.join(compat.as_bytes(assets_extra_path),
+ compat.as_bytes(dest_relative))
+ dest_path = os.path.dirname(dest_absolute)
+ gfile.MakeDirs(dest_path)
+ gfile.Copy(source, dest_absolute)
+
+ builder.save(as_text)
+ return export_dir
# For time of deprecation x,y from Estimator allow direct access.
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
index 54e6595aa8..855c44d518 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
@@ -1006,6 +1006,130 @@ class EstimatorTest(test.TestCase):
# cleanup
gfile.DeleteRecursively(tmpdir)
+ def test_export_savedmodel_with_graph_transforms(self):
+ tmpdir = tempfile.mkdtemp()
+ est, serving_input_fn = _build_estimator_for_export_tests(tmpdir)
+
+ extra_file_name = os.path.join(
+ compat.as_bytes(tmpdir), compat.as_bytes('my_extra_file'))
+ extra_file = gfile.GFile(extra_file_name, mode='w')
+ extra_file.write(EXTRA_FILE_CONTENT)
+ extra_file.close()
+ assets_extra = {'some/sub/directory/my_extra_file': extra_file_name}
+
+ export_dir_base = os.path.join(
+ compat.as_bytes(tmpdir), compat.as_bytes('export'))
+ export_dir = est.export_savedmodel(
+ export_dir_base, serving_input_fn, assets_extra=assets_extra,
+ graph_rewrite_specs=[
+ estimator.GraphRewriteSpec(['tag_1'], []),
+ estimator.GraphRewriteSpec(['tag_2', 'tag_3'],
+ ['strip_unused_nodes'])])
+
+ self.assertTrue(gfile.Exists(export_dir_base))
+ self.assertTrue(gfile.Exists(export_dir))
+ self.assertTrue(
+ gfile.Exists(
+ os.path.join(
+ compat.as_bytes(export_dir), compat.as_bytes(
+ 'saved_model.pb'))))
+ self.assertTrue(
+ gfile.Exists(
+ os.path.join(
+ compat.as_bytes(export_dir), compat.as_bytes('variables'))))
+ self.assertTrue(
+ gfile.Exists(
+ os.path.join(
+ compat.as_bytes(export_dir),
+ compat.as_bytes('variables/variables.index'))))
+ self.assertTrue(
+ gfile.Exists(
+ os.path.join(
+ compat.as_bytes(export_dir),
+ compat.as_bytes('variables/variables.data-00000-of-00001'))))
+
+ self.assertTrue(
+ gfile.Exists(
+ os.path.join(
+ compat.as_bytes(export_dir), compat.as_bytes('assets'))))
+ self.assertTrue(
+ gfile.Exists(
+ os.path.join(
+ compat.as_bytes(export_dir),
+ compat.as_bytes('assets/my_vocab_file'))))
+ self.assertEqual(
+ compat.as_bytes(VOCAB_FILE_CONTENT),
+ compat.as_bytes(
+ gfile.GFile(
+ os.path.join(
+ compat.as_bytes(export_dir),
+ compat.as_bytes('assets/my_vocab_file'))).read()))
+
+ expected_extra_path = os.path.join(
+ compat.as_bytes(export_dir),
+ compat.as_bytes('assets.extra/some/sub/directory/my_extra_file'))
+ self.assertTrue(
+ gfile.Exists(
+ os.path.join(
+ compat.as_bytes(export_dir), compat.as_bytes('assets.extra'))))
+ self.assertTrue(gfile.Exists(expected_extra_path))
+ self.assertEqual(
+ compat.as_bytes(EXTRA_FILE_CONTENT),
+ compat.as_bytes(gfile.GFile(expected_extra_path).read()))
+
+ expected_vocab_file = os.path.join(
+ compat.as_bytes(tmpdir), compat.as_bytes('my_vocab_file'))
+
+ # Restore, to validate that the export was well-formed.
+ # tag_1 is untransformed.
+ tags = ['tag_1']
+ with ops.Graph().as_default() as graph:
+ with session_lib.Session(graph=graph) as sess:
+ loader.load(sess, tags, export_dir)
+ assets = [
+ x.eval()
+ for x in graph.get_collection(ops.GraphKeys.ASSET_FILEPATHS)
+ ]
+ self.assertItemsEqual([expected_vocab_file], assets)
+ graph_ops = [x.name for x in graph.get_operations()]
+ self.assertTrue('input_example_tensor' in graph_ops)
+ self.assertTrue('ParseExample/ParseExample' in graph_ops)
+ self.assertTrue('linear/linear/feature/matmul' in graph_ops)
+ # Since there were no transforms, both save ops are still present.
+ self.assertTrue('save/SaveV2/tensor_names' in graph_ops)
+ self.assertTrue('save_1/SaveV2/tensor_names' in graph_ops)
+ # Since there were no transforms, the hash table lookup is still there.
+ self.assertTrue('hash_table_Lookup' in graph_ops)
+
+ # Restore, to validate that the export was well-formed.
+ # tag_2, tag_3 was subjected to strip_unused_nodes.
+ tags = ['tag_2', 'tag_3']
+ with ops.Graph().as_default() as graph:
+ with session_lib.Session(graph=graph) as sess:
+ loader.load(sess, tags, export_dir)
+ assets = [
+ x.eval()
+ for x in graph.get_collection(ops.GraphKeys.ASSET_FILEPATHS)
+ ]
+ self.assertItemsEqual([expected_vocab_file], assets)
+ graph_ops = [x.name for x in graph.get_operations()]
+ self.assertTrue('input_example_tensor' in graph_ops)
+ self.assertTrue('ParseExample/ParseExample' in graph_ops)
+ self.assertTrue('linear/linear/feature/matmul' in graph_ops)
+ # The Saver used to restore the checkpoint into the export Session
+ # was not added to the SAVERS collection, so strip_unused_nodes removes
+ # it. The one explicitly created in export_savedmodel is tracked in
+ # the MetaGraphDef saver_def field, so that one is retained.
+ # TODO(soergel): Make Savers sane again. I understand this is all a bit
+ # nuts but for now the test demonstrates what actually happens.
+ self.assertFalse('save/SaveV2/tensor_names' in graph_ops)
+ self.assertTrue('save_1/SaveV2/tensor_names' in graph_ops)
+ # The fake hash table lookup wasn't connected to anything; stripped.
+ self.assertFalse('hash_table_Lookup' in graph_ops)
+
+ # cleanup
+ gfile.DeleteRecursively(tmpdir)
+
class InferRealValuedColumnsTest(test.TestCase):
diff --git a/tensorflow/contrib/meta_graph_transform/BUILD b/tensorflow/contrib/meta_graph_transform/BUILD
new file mode 100644
index 0000000000..bae72aa76f
--- /dev/null
+++ b/tensorflow/contrib/meta_graph_transform/BUILD
@@ -0,0 +1,59 @@
+# Description:
+# Utility for applying the Graph Transform tool to a MetaGraphDef.
+
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load(
+ "//tensorflow:tensorflow.bzl",
+ "py_test",
+)
+
+py_library(
+ name = "meta_graph_transform",
+ srcs = [
+ "__init__.py",
+ "meta_graph_transform.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:ops",
+ "//tensorflow/python/saved_model:constants",
+ "//tensorflow/tools/graph_transforms:transform_graph_py",
+ ],
+)
+
+py_test(
+ name = "meta_graph_transform_test",
+ size = "small",
+ srcs = ["meta_graph_transform_test.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:private"],
+ deps = [
+ ":meta_graph_transform",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
+filegroup(
+ name = "py_srcs",
+ data = glob([
+ "**/*.py",
+ ]),
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/contrib/meta_graph_transform/__init__.py b/tensorflow/contrib/meta_graph_transform/__init__.py
new file mode 100644
index 0000000000..a58ac243a5
--- /dev/null
+++ b/tensorflow/contrib/meta_graph_transform/__init__.py
@@ -0,0 +1,28 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Utility for applying the Graph Transform tool to a MetaGraphDef."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.meta_graph_transform import meta_graph_transform
+from tensorflow.python.util.all_util import remove_undocumented
+
+
+_allowed_symbols = ['meta_graph_transform']
+
+remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/meta_graph_transform/meta_graph_transform.py b/tensorflow/contrib/meta_graph_transform/meta_graph_transform.py
new file mode 100644
index 0000000000..1b88e1776f
--- /dev/null
+++ b/tensorflow/contrib/meta_graph_transform/meta_graph_transform.py
@@ -0,0 +1,453 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Apply graph_transforms tool to MetaGraphDefs."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+from tensorflow.core.framework import graph_pb2 as _graph_pb2
+from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2
+from tensorflow.python.client import session as _session
+from tensorflow.python.framework import graph_util as _graph_util
+from tensorflow.python.framework import importer as _importer
+from tensorflow.python.framework import ops as _ops
+from tensorflow.python.saved_model import constants as _saved_model_constants
+from tensorflow.python.training import saver as _saver_lib
+from tensorflow.python.util import compat
+from tensorflow.tools import graph_transforms as _graph_transforms
+
+
+def _op_name(tensor_name):
+ """Get the op name from a tensor name."""
+ # control dependency inputs start with ^
+ if tensor_name[0] == '^':
+ tensor_name = tensor_name[1:]
+ if ':' in tensor_name:
+ op_name, _ = tensor_name.split(':')
+ return op_name
+ return tensor_name
+
+
+def _do_transforms(graph_def, input_names, output_names, initializer_names,
+ transforms, saver_def=None, checkpoint_path=None):
+ """Apply requested transforms to a GraphDef, including freezing.
+
+ This applies the Graph Transform Tool interleaved with graph freezing.
+
+ Args:
+ graph_def: A GraphDef proto to be transformed.
+ input_names: Names of input nodes.
+ output_names: Names of output nodes.
+ initializer_names: Names of "infrastructural" nodes (initializers, save and
+ restore ops, etc.) that should be retained even if they are not
+ transitively reachable from output nodes.
+ transforms: A list of strings naming the graph transforms to be applied in
+ order. These transform names are exactly those supported by the Graph
+ Transform Tool, with the addition of the 'freeze_graph' transform.
+ saver_def: A SaverDef proto used for restoring a checkpoint during freezing,
+ if needed (default None).
+ checkpoint_path: A path to a checkpoint to restore during freezing,
+ if needed (default None).
+ Returns:
+ The transformed GraphDef.
+ """
+ if not transforms:
+ transformed_graph_def = _graph_pb2.GraphDef()
+ transformed_graph_def.CopyFrom(graph_def)
+ return transformed_graph_def
+ else:
+ try:
+ freeze_index = transforms.index('freeze_graph')
+ except ValueError:
+ # No freeze_graph requested, so do all transforms in one go.
+ all_output_names = output_names + initializer_names
+ return _graph_transforms.TransformGraph(
+ graph_def, input_names, all_output_names, transforms)
+
+ # freeze_graph requested, possibly with transforms before and after.
+ phase_1_transforms = transforms[:freeze_index]
+ phase_2_transforms = transforms[freeze_index+1:]
+
+ graph_def = _do_transforms(
+ graph_def, input_names, output_names, initializer_names,
+ phase_1_transforms, saver_def, checkpoint_path)
+ output_node_names = [_op_name(x) for x in output_names]
+ graph_def = _freeze_graph_with_def_protos(
+ graph_def, output_node_names, saver_def, checkpoint_path)
+ # No need for saver or checkpoint anymore
+ return _do_transforms(
+ graph_def, input_names, output_names, [], phase_2_transforms)
+
+
+# forked and modified from freeze_graph.py
+def _freeze_graph_with_def_protos(
+ input_graph_def,
+ output_node_names,
+ input_saver_def,
+ input_checkpoint):
+ """Converts all variables in a graph and checkpoint into constants."""
+
+ with _ops.Graph().as_default():
+ _ = _importer.import_graph_def(input_graph_def, name='')
+
+ with _session.Session() as sess:
+ saver = _saver_lib.Saver(saver_def=input_saver_def)
+ saver.restore(sess, input_checkpoint)
+ output_graph_def = _graph_util.convert_variables_to_constants(
+ sess, input_graph_def, output_node_names)
+
+ return output_graph_def
+
+
+def _find_all_mandatory_retain_ops(base_meta_graph_def):
+ """Identify all infrastructural Ops, to ensure that they are retained.
+
+ We need to retain infrastructural Ops (init and saver stuff), in addition
+ to the desired outputs.
+
+ For now we retain *all* save and restore ops, variable initializers,
+ table initializers, and main init ops.
+ This means that strip_unused_nodes will not remove unused variables.
+
+ Args:
+ base_meta_graph_def: a GraphDef proto in which to identify nodes to retain.
+
+ Returns:
+ A list of node names to be retained.
+ """
+ # TODO(b/63447631): implement variable stripping.
+
+ initializer_names = []
+
+ # Primary SaverDef and SAVERS collection
+ saver_defs = []
+ if base_meta_graph_def.HasField('saver_def'):
+ saver_defs.append(base_meta_graph_def.saver_def)
+ saver_defs.extend(_get_all_protos_from_collection(
+ base_meta_graph_def, _ops.GraphKeys.SAVERS))
+ for saver_def in saver_defs:
+ initializer_names.append(saver_def.filename_tensor_name)
+ initializer_names.append(saver_def.save_tensor_name)
+ initializer_names.append(saver_def.restore_op_name)
+
+ # Variable initializers
+ variable_collections = [
+ _ops.GraphKeys.GLOBAL_VARIABLES,
+ _ops.GraphKeys.TRAINABLE_VARIABLES,
+ _ops.GraphKeys.MOVING_AVERAGE_VARIABLES,
+ _ops.GraphKeys.LOCAL_VARIABLES,
+ _ops.GraphKeys.MODEL_VARIABLES]
+ for var_coll in variable_collections:
+ variables = _get_all_protos_from_collection(base_meta_graph_def, var_coll)
+ var_init_names = [v.initializer_name for v in variables]
+ if var_init_names:
+ initializer_names.extend(var_init_names)
+
+ # Table initializers
+ op_names = _get_all_node_names_from_collection(
+ base_meta_graph_def, _ops.GraphKeys.TABLE_INITIALIZERS)
+ if op_names:
+ initializer_names.extend(op_names)
+
+ # Various init ops
+ various_init_op_collections = [_saved_model_constants.LEGACY_INIT_OP_KEY,
+ _saved_model_constants.MAIN_OP_KEY,
+ _ops.GraphKeys.INIT_OP,
+ _ops.GraphKeys.LOCAL_INIT_OP,
+ _ops.GraphKeys.READY_OP,
+ _ops.GraphKeys.READY_FOR_LOCAL_INIT_OP]
+ for op_coll in various_init_op_collections:
+ op_name = _get_single_node_name_from_collection(
+ base_meta_graph_def, op_coll)
+ if op_name:
+ initializer_names.append(op_name)
+
+ return initializer_names
+
+
+def _add_pruned_collection(base_meta_graph_def, meta_graph_def,
+ collection_name, removed_op_names):
+ """Copy collection to the transformed MetaGraphDef, omitting removed items."""
+
+ base_collection = base_meta_graph_def.collection_def[collection_name]
+ collection = meta_graph_def.collection_def[collection_name]
+
+ if base_collection.HasField('any_list'):
+ for any_value in base_collection.any_list.value:
+ # just search the serialized proto as a string
+ if not _is_removed_mentioned(any_value.value, removed_op_names):
+ copied_any = collection.any_list.value.add()
+ copied_any.CopyFrom(any_value)
+ elif base_collection.HasField('bytes_list'):
+ collection.bytes_list.value[:] = [
+ s for s in base_collection.bytes_list.value
+ if not _is_removed_mentioned(s, removed_op_names)]
+ elif base_collection.HasField('node_list'):
+ collection.node_list.value[:] = [
+ s for s in base_collection.node_list.value
+ if not _is_removed(s, removed_op_names)]
+ else:
+ collection.CopyFrom(base_collection)
+
+
+def _add_pruned_saver(base_meta_graph_def, meta_graph_def, removed_op_names):
+ """Copy the Saver into the transformed MetaGraphDef, if valid.
+
+ Currently this copies the Saver as is, after verifying that none of the
+ referenced Save & Restore ops were removed. A future version will modify
+ the Save and Restore ops themselves as needed to account for removed
+ Variables.
+
+ Args:
+ base_meta_graph_def: The untransformed MetaGraphDef.
+ meta_graph_def: The transformed MetaGraphDef being built.
+ removed_op_names: An iterable of names of ops that were removed.
+ """
+
+ # Note this does surgery on meta_graph_def.graph_def too, so that should have
+ # been copied already.
+ if base_meta_graph_def.HasField('saver_def'):
+ filename_tensor_name = base_meta_graph_def.saver_def.filename_tensor_name
+ save_tensor_name = base_meta_graph_def.saver_def.save_tensor_name
+ restore_op_name = base_meta_graph_def.saver_def.restore_op_name
+
+ _check_tensor_not_removed(filename_tensor_name, removed_op_names)
+ _check_tensor_not_removed(save_tensor_name, removed_op_names)
+ _check_tensor_not_removed(restore_op_name, removed_op_names)
+
+ # TODO(b/63447631): Once we strip unused variables, remove references to
+ # them from save and restore ops. Retain those ops only if they also refer
+ # to retained Variables.
+
+ # saver_name, restore_all = restore_op_name.rsplit('/', 1)
+ # if restore_all != 'restore_all':
+ # raise ValueError(
+ # 'SaverDef restore_op_name did not have expected form */restore_all')
+
+ # save_tensor_names_op_name = '{}/SaveV2/tensor_names'.format(saver_name)
+ # restore_tensor_names_op_name = (
+ # '{}/RestoreV2/tensor_names'.format(saver_name))
+
+ # save_tensor_names_op = _find_op(meta_graph_def.graph_def,
+ # save_tensor_names_op_name)
+ # save_tensor_names_value_tensor = save_tensor_names_op.attr['value'].tensor
+ # save_tensor_names_value_tensor.string_val[:] = [
+ # s for s in save_tensor_names_value_tensor.string_val
+ # if not _is_removed(s, removed_op_names)]
+
+ # restore_tensor_names_op = _find_op(
+ # meta_graph_def.graph_def, restore_tensor_names_op_name)
+ # restore_tensor_names_value_tensor = (
+ # restore_tensor_names_op.attr['value'].tensor)
+ # restore_tensor_names_value_tensor.string_val[:] = [
+ # s for s in restore_tensor_names_value_tensor.string_val
+ # if not _is_removed(s, removed_op_names)]
+
+ # if (save_tensor_names_value_tensor.string_val
+ # or restore_tensor_names_value_tensor.string_val):
+ meta_graph_def.saver_def.CopyFrom(base_meta_graph_def.saver_def)
+
+
+def _find_op(graph_def, op_name):
+ """Fetch a node from a GraphDef proto by name."""
+ for node_def in graph_def.node:
+ if node_def.name == op_name:
+ return node_def
+ return None
+
+
+def _add_pruned_signature(base_meta_graph_def, meta_graph_def,
+ signature_name, removed_op_names):
+ """Copy the named signature into the transformed MetaGraphDef, if valid.
+
+ If any input or output mentioned in the signature was removed by the graph
+ transform, the signature is silently omitted from the transformed
+ MetaGraphDef.
+
+ Args:
+ base_meta_graph_def: The untransformed MetaGraphDef.
+ meta_graph_def: The transformed MetaGraphDef being built.
+ signature_name: The name of the signature to copy.
+ removed_op_names: An iterable of names of ops that were removed.
+ """
+ try:
+ base_signature = base_meta_graph_def.signature_def[signature_name]
+ for key in base_signature.inputs:
+ _check_tensor_not_removed(base_signature.inputs[key].name,
+ removed_op_names)
+ for key in base_signature.outputs:
+ _check_tensor_not_removed(base_signature.outputs[key].name,
+ removed_op_names)
+ meta_graph_def.signature_def[signature_name].CopyFrom(base_signature)
+ except ValueError:
+ # exclude any signature that mentions a removed node
+ pass
+
+
+def _get_single_node_name_from_collection(meta_graph_def, collection_key):
+ """Obtain a node name that is the single element of a collection."""
+ if collection_key not in meta_graph_def.collection_def:
+ return None
+ collection = meta_graph_def.collection_def[collection_key]
+ if not collection.node_list.value:
+ raise ValueError(
+ 'Collection {} is present but type is not node_list.'.format(
+ collection_key))
+ if len(collection.node_list.value) != 1:
+ raise ValueError(
+ 'Collection {} is has {} elements; expected exactly one.'.format(
+ collection_key, collection.bytes_list))
+ return collection.node_list.value[0]
+
+
+def _get_all_node_names_from_collection(meta_graph_def, collection_key):
+ """Obtain node names from a collection."""
+ if collection_key not in meta_graph_def.collection_def:
+ return None
+ collection = meta_graph_def.collection_def[collection_key]
+ if not collection.node_list.value:
+ raise ValueError(
+ 'Collection {} is present but type is not node_list.'.format(
+ collection_key))
+ return collection.node_list.value
+
+
+def _get_all_protos_from_collection(meta_graph_def, collection_key):
+ """Obtain node names from a collection."""
+ if collection_key not in meta_graph_def.collection_def:
+ return []
+ collection = meta_graph_def.collection_def[collection_key]
+ if not collection.bytes_list.value:
+ raise ValueError(
+ 'Collection {} is present but type is not bytes_list.'.format(
+ collection_key))
+ proto_type = _ops.get_collection_proto_type(collection_key)
+ result = []
+ for value in collection.bytes_list.value:
+ proto = proto_type()
+ proto.ParseFromString(value)
+ result.append(proto)
+ return result
+
+
+def _is_removed(tensor_name, removed_op_names):
+ """Determine whether the named tensor is an output of a removed op."""
+ for removed_op_name in removed_op_names:
+ if tensor_name.startswith(removed_op_name):
+ return True
+ return False
+
+
+def _is_removed_mentioned(s, removed_op_names):
+ """Determine whether any removed op is mentioned in the given object.
+
+ This relies on the string representation of the object. This is used for
+ proto messages that may mention ops by name in nested fields. The string
+ representation of the proto includes those field values, so this string
+ search approach is sufficient.
+
+ Args:
+ s: an object to search for removed op names.
+ removed_op_names: An iterable of names of ops that were removed.
+
+ Returns:
+ True if any removed op is mentioned in the given object, False otherwise.
+ """
+ for removed_op_name in removed_op_names:
+ if removed_op_name in compat.as_str_any(s):
+ return True
+ return False
+
+
+def _check_tensor_not_removed(tensor_name, removed_op_names):
+ """Verify that the named tensor was not removed.
+
+ Args:
+ tensor_name: the name of a tensor to check.
+ removed_op_names: An iterable of names of ops that were removed.
+
+ Raises:
+ ValueError: if the tensor was removed.
+ """
+ if not tensor_name:
+ raise ValueError('Tensor name should not be empty')
+ if _is_removed(tensor_name, removed_op_names):
+ raise ValueError(
+ 'Expected Tensor, but it was removed: {}'.format(tensor_name))
+
+
+def meta_graph_transform(
+ base_meta_graph_def, input_names, output_names, transforms, tags,
+ checkpoint_path=None):
+ """Apply the Graph Transform tool to a MetaGraphDef.
+
+ Args:
+ base_meta_graph_def: A MetaGraphDef protocol buffer to transform.
+ input_names: Names of input nodes.
+ output_names: Names of output nodes.
+ transforms: A list of strings naming the graph transforms to be applied in
+ order. These transform names are exactly those supported by the Graph
+ Transform Tool, with the addition of the 'freeze_graph' transform.
+ tags: A list of tags with which to annotate the transformed MetaGraphDef.
+ checkpoint_path: A path to a checkpoint to restore during freezing,
+ if needed (default None).
+
+ Returns:
+ A new transformed MetaGraphDef protocol buffer.
+ """
+ meta_graph_def = _meta_graph_pb2.MetaGraphDef()
+
+ initializer_names = _find_all_mandatory_retain_ops(base_meta_graph_def)
+
+ transformed_graph_def = _do_transforms(
+ base_meta_graph_def.graph_def,
+ input_names,
+ output_names,
+ initializer_names,
+ transforms,
+ base_meta_graph_def.saver_def,
+ checkpoint_path)
+
+ meta_graph_def.graph_def.CopyFrom(transformed_graph_def)
+ meta_graph_def.meta_info_def.CopyFrom(base_meta_graph_def.meta_info_def)
+ meta_graph_def.meta_info_def.ClearField('tags')
+ for tag in tags:
+ meta_graph_def.meta_info_def.tags.append(tag)
+
+ base_op_names = [compat.as_str(node.name)
+ for node in base_meta_graph_def.graph_def.node]
+ retained_op_names = [compat.as_str(node.name)
+ for node in meta_graph_def.graph_def.node]
+ removed_op_names = set(base_op_names) - set(retained_op_names)
+
+ # Copy saver, excluding any pruned nodes
+ _add_pruned_saver(base_meta_graph_def, meta_graph_def, removed_op_names)
+
+ # Copy collections, excluding any pruned nodes
+ for collection_name in base_meta_graph_def.collection_def:
+ _add_pruned_collection(
+ base_meta_graph_def, meta_graph_def, collection_name,
+ removed_op_names)
+
+ # Copy signature_defs, excluding any pruned nodes
+ for signature_name in base_meta_graph_def.signature_def:
+ _add_pruned_signature(
+ base_meta_graph_def, meta_graph_def, signature_name,
+ removed_op_names)
+
+ return meta_graph_def
diff --git a/tensorflow/contrib/meta_graph_transform/meta_graph_transform_test.py b/tensorflow/contrib/meta_graph_transform/meta_graph_transform_test.py
new file mode 100644
index 0000000000..51f01df65b
--- /dev/null
+++ b/tensorflow/contrib/meta_graph_transform/meta_graph_transform_test.py
@@ -0,0 +1,202 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for MetaGraphDef Transform Tool."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from google.protobuf.any_pb2 import Any
+from tensorflow.contrib.meta_graph_transform import meta_graph_transform
+from tensorflow.core.framework import function_pb2
+from tensorflow.core.protobuf import meta_graph_pb2
+from tensorflow.python.client import session as tf_session
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+from tensorflow.python.training import saver
+from tensorflow.python.util import compat
+
+
+def _make_asset_file_def_any(node_name):
+ asset_file_def = meta_graph_pb2.AssetFileDef()
+ asset_file_def.tensor_info.name = node_name
+ any_message = Any()
+ any_message.Pack(asset_file_def)
+ return any_message
+
+
+class MetaGraphTransformTest(test.TestCase):
+
+ def test_meta_graph_transform(self):
+
+ with ops.Graph().as_default():
+ with tf_session.Session(''):
+ a = array_ops.placeholder(dtypes.int64, [1], name='a')
+ b = array_ops.placeholder(dtypes.int64, [1], name='b')
+ c = array_ops.placeholder(dtypes.int64, [1], name='c')
+ _ = a * b
+ _ = b * c
+ base_meta_graph_def = saver.export_meta_graph()
+
+ with ops.Graph().as_default():
+ with tf_session.Session(''):
+ a = array_ops.placeholder(dtypes.int64, [1], name='a')
+ b = array_ops.placeholder(dtypes.int64, [1], name='b')
+ _ = a * b
+ meta_info_def = meta_graph_pb2.MetaGraphDef.MetaInfoDef()
+ meta_info_def.tags.append('tag_ab')
+
+ expected_meta_graph_def = saver.export_meta_graph(
+ meta_info_def=meta_info_def)
+ # Graph rewriter clears versions field, so we expect that.
+ expected_meta_graph_def.graph_def.ClearField('versions')
+ # Graph rewriter adds an empty library field, so we expect that.
+ expected_meta_graph_def.graph_def.library.CopyFrom(
+ function_pb2.FunctionDefLibrary())
+
+ input_names = ['a', 'b']
+ output_names = ['mul:0']
+ transforms = ['strip_unused_nodes']
+ tags = ['tag_ab']
+ print('AAAAAA: {}'.format(base_meta_graph_def))
+ transformed_meta_graph_def = meta_graph_transform.meta_graph_transform(
+ base_meta_graph_def, input_names, output_names, transforms, tags)
+
+ self.assertEqual(expected_meta_graph_def, transformed_meta_graph_def)
+
+ def test_add_pruned_collection_node(self):
+ collection_name = 'node_collection'
+ base_meta_graph_def = meta_graph_pb2.MetaGraphDef()
+ base_meta_graph_def.collection_def[collection_name].node_list.value.extend(
+ ['node1', 'node2', 'node3', 'node4'])
+
+ meta_graph_def = meta_graph_pb2.MetaGraphDef()
+ removed_op_names = ['node2', 'node4', 'node5']
+ meta_graph_transform._add_pruned_collection(
+ base_meta_graph_def, meta_graph_def, collection_name, removed_op_names)
+
+ collection = meta_graph_def.collection_def[collection_name]
+
+ expected_nodes = ['node1', 'node3']
+ self.assertEqual(expected_nodes, collection.node_list.value)
+
+ def test_add_pruned_collection_int(self):
+ collection_name = 'int_collection'
+ base_meta_graph_def = meta_graph_pb2.MetaGraphDef()
+ base_meta_graph_def.collection_def[collection_name].int64_list.value[:] = (
+ [10, 20, 30, 40])
+
+ meta_graph_def = meta_graph_pb2.MetaGraphDef()
+ removed_op_names = ['node2', 'node4', 'node5']
+ meta_graph_transform._add_pruned_collection(
+ base_meta_graph_def, meta_graph_def, collection_name, removed_op_names)
+
+ collection = meta_graph_def.collection_def[collection_name]
+
+ expected_ints = [10, 20, 30, 40]
+ self.assertEqual(expected_ints, collection.int64_list.value)
+
+ def test_add_pruned_collection_proto_in_any_list(self):
+ collection_name = 'proto_collection'
+ base_meta_graph_def = meta_graph_pb2.MetaGraphDef()
+ base_meta_graph_def.collection_def[collection_name].any_list.value.extend(
+ [_make_asset_file_def_any('node1'),
+ _make_asset_file_def_any('node2'),
+ _make_asset_file_def_any('node3'),
+ _make_asset_file_def_any('node4')])
+
+ meta_graph_def = meta_graph_pb2.MetaGraphDef()
+ removed_op_names = ['node2', 'node4', 'node5']
+ meta_graph_transform._add_pruned_collection(
+ base_meta_graph_def, meta_graph_def, collection_name, removed_op_names)
+
+ collection = meta_graph_def.collection_def[collection_name]
+
+ expected_protos = [_make_asset_file_def_any('node1'),
+ _make_asset_file_def_any('node3')]
+ self.assertEqual(expected_protos, collection.any_list.value[:])
+
+ def test_add_pruned_collection_proto_in_bytes_list(self):
+ collection_name = 'proto_collection'
+ base_meta_graph_def = meta_graph_pb2.MetaGraphDef()
+ base_meta_graph_def.collection_def[collection_name].bytes_list.value.extend(
+ [compat.as_bytes(compat.as_str_any(_make_asset_file_def_any('node1'))),
+ compat.as_bytes(compat.as_str_any(_make_asset_file_def_any('node2'))),
+ compat.as_bytes(compat.as_str_any(_make_asset_file_def_any('node3'))),
+ compat.as_bytes(compat.as_str_any(_make_asset_file_def_any('node4')))])
+
+ meta_graph_def = meta_graph_pb2.MetaGraphDef()
+ removed_op_names = ['node2', 'node4', 'node5']
+ meta_graph_transform._add_pruned_collection(
+ base_meta_graph_def, meta_graph_def, collection_name, removed_op_names)
+
+ collection = meta_graph_def.collection_def[collection_name]
+
+ expected_values = [
+ compat.as_bytes(compat.as_str_any(_make_asset_file_def_any('node1'))),
+ compat.as_bytes(compat.as_str_any(_make_asset_file_def_any('node3')))]
+ self.assertEqual(expected_values, collection.bytes_list.value[:])
+
+ def test_add_pruned_saver(self):
+ base_meta_graph_def = meta_graph_pb2.MetaGraphDef()
+
+ base_meta_graph_def.saver_def.filename_tensor_name = 'node1'
+ base_meta_graph_def.saver_def.save_tensor_name = 'node3'
+ base_meta_graph_def.saver_def.restore_op_name = 'node6'
+
+ meta_graph_def = meta_graph_pb2.MetaGraphDef()
+ removed_op_names = ['node2', 'node4', 'node5']
+ meta_graph_transform._add_pruned_saver(base_meta_graph_def,
+ meta_graph_def,
+ removed_op_names)
+
+ # TODO(b/63447631): For now the saver is just copied unchanged
+ self.assertEqual(base_meta_graph_def.saver_def, meta_graph_def.saver_def)
+
+ def test_add_pruned_signature(self):
+ base_meta_graph_def = meta_graph_pb2.MetaGraphDef()
+
+ signature_name_keep = 'test_signature_keep'
+ base_sig_keep = base_meta_graph_def.signature_def[signature_name_keep]
+ base_sig_keep.inputs['input_1'].name = 'input_1'
+ base_sig_keep.outputs['output_1'].name = 'output_1'
+
+ signature_name_remove = 'test_signature_remove'
+ base_sig_remove = base_meta_graph_def.signature_def[signature_name_remove]
+ base_sig_remove.inputs['node2'].name = 'node2'
+ base_sig_remove.outputs['output_1'].name = 'output_1'
+
+ meta_graph_def = meta_graph_pb2.MetaGraphDef()
+ removed_op_names = ['node2', 'node4', 'node5']
+ meta_graph_transform._add_pruned_signature(base_meta_graph_def,
+ meta_graph_def,
+ signature_name_keep,
+ removed_op_names)
+ meta_graph_transform._add_pruned_signature(base_meta_graph_def,
+ meta_graph_def,
+ signature_name_remove,
+ removed_op_names)
+
+ self.assertTrue(signature_name_keep in meta_graph_def.signature_def)
+ sig_keep = meta_graph_def.signature_def[signature_name_keep]
+ self.assertEqual(base_sig_keep, sig_keep)
+
+ self.assertFalse(signature_name_remove in meta_graph_def.signature_def)
+
+
+if __name__ == '__main__':
+ test.main()