aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/meta_graph_transform
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-09-16 16:28:55 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-09-16 16:32:17 -0700
commit20e12d9c160ca2c5a20d238fba2d54e1e16741a5 (patch)
tree3a1f004e4c3b02fd976fb8f7e86777d6560e246b /tensorflow/contrib/meta_graph_transform
parentc2749f90b08314c3ae47289ebe803a28f601ad49 (diff)
Enabling sparsify_gather to be called before freeze.
PiperOrigin-RevId: 168971905
Diffstat (limited to 'tensorflow/contrib/meta_graph_transform')
-rw-r--r--tensorflow/contrib/meta_graph_transform/meta_graph_transform.py351
-rw-r--r--tensorflow/contrib/meta_graph_transform/meta_graph_transform_test.py333
2 files changed, 609 insertions, 75 deletions
diff --git a/tensorflow/contrib/meta_graph_transform/meta_graph_transform.py b/tensorflow/contrib/meta_graph_transform/meta_graph_transform.py
index e7849d49a7..72494f54f5 100644
--- a/tensorflow/contrib/meta_graph_transform/meta_graph_transform.py
+++ b/tensorflow/contrib/meta_graph_transform/meta_graph_transform.py
@@ -20,6 +20,8 @@ from __future__ import division
from __future__ import print_function
+import re
+
from tensorflow.core.framework import graph_pb2 as _graph_pb2
from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2
from tensorflow.python.client import session as _session
@@ -32,7 +34,8 @@ from tensorflow.python.util import compat
from tensorflow.tools import graph_transforms as _graph_transforms
-_FREEZE_GRAPH_TRANSFORM_NAME = 'freeze_graph'
+_FREEZE_GRAPH_TRANSFORM = 'freeze_graph'
+_SPARSIFY_GATHER_TRANSFORM = 'sparsify_gather'
def _op_name(tensor_name):
@@ -46,11 +49,27 @@ def _op_name(tensor_name):
return tensor_name
-def _do_transforms(graph_def, input_names, output_names, initializer_names,
- transforms, saver_def=None, checkpoint_path=None):
- """Apply requested transforms to a GraphDef, including freezing.
+def _get_shared_init_op(initializer_names):
+ """Obtain the shared init op name, if it exists.
+
+ Args:
+ initializer_names: Dictionary of the "infrastructural" nodes (initializers,
+ save and restore ops, etc.). The keys in this dictionary
+ indicate the collection where these nodes were obtained from.
+
+ Returns:
+ A string indicating the shared init op name or none if None if none exists.
+ """
+ return_value = initializer_names.get(_saved_model_constants.MAIN_OP_KEY, None)
+ if not return_value:
+ return_value = initializer_names.get(
+ _saved_model_constants.LEGACY_INIT_OP_KEY, None)
+ return str(return_value[0]) if return_value else None
+
- This applies the Graph Transform Tool interleaved with graph freezing.
+def _gtt_transforms(graph_def, input_names, output_names, initializer_names,
+ transforms):
+ """Pass through gtt transforms, applying them to the graph_def.
Args:
graph_def: A GraphDef proto to be transformed.
@@ -61,12 +80,7 @@ def _do_transforms(graph_def, input_names, output_names, initializer_names,
transitively reachable from output nodes. The keys in this dictionary
indicate the collection where these nodes were obtained from.
transforms: A list of strings naming the graph transforms to be applied in
- order. These transform names are exactly those supported by the Graph
- Transform Tool, with the addition of the 'freeze_graph' transform.
- saver_def: A SaverDef proto used for restoring a checkpoint during freezing,
- if needed (default None).
- checkpoint_path: A path to a checkpoint to restore during freezing,
- if needed (default None).
+ order.
Returns:
The transformed GraphDef.
"""
@@ -74,42 +88,233 @@ def _do_transforms(graph_def, input_names, output_names, initializer_names,
transformed_graph_def = _graph_pb2.GraphDef()
transformed_graph_def.CopyFrom(graph_def)
return transformed_graph_def
- else:
- try:
- freeze_index = transforms.index(_FREEZE_GRAPH_TRANSFORM_NAME)
- except ValueError:
- # No freeze_graph requested, so do all transforms in one go.
- initializer_names_flat = sorted(
- [k for l in initializer_names.values() for k in l])
- all_output_names = output_names + initializer_names_flat
- return _graph_transforms.TransformGraph(
- graph_def, input_names, all_output_names, transforms)
-
- # freeze_graph requested, possibly with transforms before and after.
- phase_1_transforms = transforms[:freeze_index]
- phase_2_transforms = transforms[freeze_index+1:]
-
- graph_def = _do_transforms(
- graph_def, input_names, output_names, initializer_names,
- phase_1_transforms, saver_def, checkpoint_path)
- output_node_names = [_op_name(x) for x in output_names]
- graph_def = _freeze_graph_with_def_protos(
- graph_def, output_node_names,
- initializer_names[_ops.GraphKeys.TABLE_INITIALIZERS],
- initializer_names[_saved_model_constants.LEGACY_INIT_OP_KEY][0],
- saver_def, checkpoint_path)
- # No need for saver or checkpoint anymore
- pruned_initializer_names = {}
- # Freeze graph will prune all initializers and shared init nodes if table
- # initializers are not present. Handle this case in future GTT transforms.
- if initializer_names[_ops.GraphKeys.TABLE_INITIALIZERS]:
- pruned_initializer_names[_ops.GraphKeys.TABLE_INITIALIZERS] = (
- initializer_names[_ops.GraphKeys.TABLE_INITIALIZERS])
+
+ initializer_names_flat = sorted(
+ [k for l in initializer_names.values() for k in l])
+ all_output_names = output_names + initializer_names_flat
+ return _graph_transforms.TransformGraph(graph_def, input_names,
+ all_output_names, transforms)
+
+
+def _freeze_transform(graph_def, output_names, initializer_names, saver_def,
+ checkpoint_path):
+ """Handle the freeze transform.
+
+ Determine which initializer nodes should be retained by the freeze transform.
+ Retain those nodes and return an updated dictionary containing them.
+
+ Args:
+ graph_def: A GraphDef proto to be transformed.
+ output_names: Names of output nodes.
+ initializer_names: Dictionary of the "infrastructural" nodes (initializers,
+ save and restore ops, etc.). The keys in this dictionary
+ indicate the collection where these nodes were obtained from.
+ saver_def: A SaverDef proto used for restoring a checkpoint during freezing,
+ if needed (default None).
+ checkpoint_path: A path to a checkpoint to restore during freezing,
+ if needed (default None).
+
+ Returns:
+ A tuple containing the GraphDef and a Dict of pruned initializer nodes.
+ """
+ table_initializers = initializer_names.get(_ops.GraphKeys.TABLE_INITIALIZERS,
+ [])
+ shared_init_op = _get_shared_init_op(initializer_names)
+
+ graph_def = _freeze_graph_with_def_protos(graph_def, output_names,
+ table_initializers, shared_init_op,
+ saver_def, checkpoint_path)
+ pruned_initializer_names = {}
+ # Freeze graph prunes all initializers and shared init nodes that are not
+ # explicitly maintained. Create new initializer_names dictionary to reflect
+ # this.
+ if table_initializers:
+ pruned_initializer_names[_ops.GraphKeys.TABLE_INITIALIZERS] = (
+ table_initializers)
+ if _saved_model_constants.LEGACY_INIT_OP_KEY in initializer_names:
pruned_initializer_names[_saved_model_constants.LEGACY_INIT_OP_KEY] = (
initializer_names[_saved_model_constants.LEGACY_INIT_OP_KEY])
+ if _saved_model_constants.MAIN_OP_KEY in initializer_names:
+ pruned_initializer_names[_saved_model_constants.MAIN_OP_KEY] = (
+ initializer_names[_saved_model_constants.MAIN_OP_KEY])
+ return (graph_def, pruned_initializer_names)
+
+
+def _clean_save_and_restore(graph_def, op, removed_op_names):
+ """Clean the specified save and restore op.
+
+ Updates the dtypes attribute of the save / restore op and the associated name
+ and shape tensors to remove entries for variables that have been removed.
+
+ Args:
+ graph_def: A GraphDef proto to be transformed.
+ op: The save or restore op to update.
+ removed_op_names: List of op names that have been removed.
+ """
+ name = op.name + '/tensor_names'
+ shape = op.name + '/shape_and_slices'
+ name_op = _find_op(graph_def, name)
+ shape_op = _find_op(graph_def, shape)
+ name_op_value_tensor = name_op.attr['value'].tensor
+ shape_op_value_tensor = shape_op.attr['value'].tensor
+ names = []
+ shapes = []
+ dtypes = []
+ for index, value in enumerate(name_op_value_tensor.string_val):
+ if not _is_removed(compat.as_str(value), removed_op_names):
+ names.append(value)
+ shapes.append(shape_op_value_tensor.string_val[index])
+ dtypes.append(op.attr['dtypes'].list.type[index])
+ name_op_value_tensor.string_val[:] = names
+ name_op_value_tensor.tensor_shape.dim[0].size = len(names)
+ shape_op_value_tensor.string_val[:] = shapes
+ shape_op_value_tensor.tensor_shape.dim[0].size = len(shapes)
+ op.attr['dtypes'].list.type[:] = dtypes
+
+ name_op.attr['_output_shapes'].list.shape[0].dim[0].size = len(names)
+ shape_op.attr['_output_shapes'].list.shape[0].dim[0].size = len(shapes)
+
+
+def _sparsify_gather_transform(graph_def, input_names, output_names,
+ initializer_names, checkpoint_path):
+ """Handle the sparsify gather transform.
+
+ Provides the transform the checkpoint and keeps track of the newly created
+ initializer nodes.
+
+ Args:
+ graph_def: A GraphDef proto to be transformed.
+ input_names: Names of input nodes.
+ output_names: Names of output nodes.
+ initializer_names: Dictionary of the "infrastructural" nodes (initializers,
+ save and restore ops, etc.). The keys in this dictionary
+ indicate the collection where these nodes were obtained from.
+ checkpoint_path: A path to a checkpoint.
+
+ Returns:
+ A tuple containing the GraphDef and a Dict of updated initializer nodes.
+ Raises:
+ ValueError: if the restore_op_name does not have the expected format.
+ """
+ # Ensure that sparsify_shared_init_op is unique.
+ sparsify_shared_init_op = 'sparify_gather_init_op'
+ while _find_op(graph_def, sparsify_shared_init_op):
+ sparsify_shared_init_op += '_1'
+
+ input_flag = ''
+ if checkpoint_path:
+ input_flag = 'input_checkpoint="%s", ' % checkpoint_path
+
+ sparsify_cmd = [
+ 'sparsify_gather(%sgroup_init_node="%s")' % (input_flag,
+ sparsify_shared_init_op)
+ ]
+
+ starting_op_names = [node.name for node in graph_def.node]
+
+ graph_def = _gtt_transforms(graph_def, input_names, output_names,
+ initializer_names, sparsify_cmd)
+ ending_op_names = [node.name for node in graph_def.node]
+ removed_op_names = list(set(starting_op_names) - set(ending_op_names))
+ removed_op_names.sort()
+
+ for op_index, op_name in enumerate(removed_op_names):
+ op_name_parts = op_name.rsplit('/', 1)
+ # Remove part to get the checkpoint names used by the saver.
+ if len(op_name_parts) == 2 and op_name_parts[1].startswith('part_'):
+ removed_op_names[op_index] = op_name_parts[0]
+ else:
+ removed_op_names[op_index] = op_name
+
+ # Obtain newly created table inits from gtt sparsify transform.
+ added_table_inits = []
+ for index, node in enumerate(graph_def.node):
+ if node.name == sparsify_shared_init_op:
+ added_table_inits = [n.lstrip('^') for n in node.input]
+
+ table_initializers = initializer_names.get(
+ _ops.GraphKeys.TABLE_INITIALIZERS, [])
+ table_initializers.extend(added_table_inits)
+ initializer_names[_ops.GraphKeys.TABLE_INITIALIZERS] = table_initializers
+
+ del graph_def.node[index]
+ break
+
+ # Add inits to existing shared init op.
+ node = _find_op(graph_def, _get_shared_init_op(initializer_names))
+ for init in added_table_inits:
+ node.input.append('^' + init)
+
+ # Update saver.
+ for node in graph_def.node:
+ if node.name.endswith('SaveV2'):
+ _clean_save_and_restore(graph_def, node, removed_op_names)
+
+ return (graph_def, initializer_names)
+
+
+def _do_transforms(graph_def,
+ input_names,
+ output_names,
+ initializer_names,
+ transforms,
+ saver_def=None,
+ checkpoint_path=None):
+ """Apply requested transforms to a GraphDef, including freezing.
+
+ Args:
+ graph_def: A GraphDef proto to be transformed.
+ input_names: Names of input nodes.
+ output_names: Names of output nodes.
+ initializer_names: Dictionary of the "infrastructural" nodes (initializers,
+ save and restore ops, etc.) that should be retained even if they are not
+ transitively reachable from output nodes. The keys in this dictionary
+ indicate the collection where these nodes were obtained from.
+ transforms: A list of strings naming the graph transforms to be applied in
+ order. These transform names are exactly those supported by the Graph
+ Transform Tool, with the addition of the 'freeze_graph' and
+ 'sparsify_gather' transforms.
+ saver_def: A SaverDef proto used for restoring a checkpoint during freezing,
+ if needed (default None).
+ checkpoint_path: A path to a checkpoint to restore during freezing,
+ if needed (default None).
+ Returns:
+ A tuple containing the GraphDef and a Dict of updated initializer nodes.
+ """
+ transformed_graph_def = _graph_pb2.GraphDef()
+ transformed_graph_def.CopyFrom(graph_def)
+ transformed_initializer_names = initializer_names.copy()
- return _do_transforms(graph_def, input_names, output_names,
- pruned_initializer_names, phase_2_transforms)
+ if not transforms:
+ return transformed_graph_def, transformed_initializer_names
+
+ current_gtt_transforms = []
+ for t in transforms:
+ if t == _FREEZE_GRAPH_TRANSFORM:
+ transformed_graph_def = _gtt_transforms(
+ transformed_graph_def, input_names, output_names,
+ transformed_initializer_names, current_gtt_transforms)
+ output_node_names = [_op_name(x) for x in output_names]
+ transformed_graph_def, transformed_initializer_names = _freeze_transform(
+ transformed_graph_def, output_node_names,
+ transformed_initializer_names, saver_def, checkpoint_path)
+ current_gtt_transforms = []
+ elif t == _SPARSIFY_GATHER_TRANSFORM:
+ transformed_graph_def = _gtt_transforms(
+ transformed_graph_def, input_names, output_names,
+ transformed_initializer_names, current_gtt_transforms)
+ transformed_graph_def, transformed_initializer_names = (
+ _sparsify_gather_transform(
+ transformed_graph_def, input_names, output_names,
+ transformed_initializer_names, checkpoint_path))
+ current_gtt_transforms = []
+ else:
+ current_gtt_transforms.append(t)
+
+ transformed_graph_def = _gtt_transforms(
+ transformed_graph_def, input_names, output_names,
+ transformed_initializer_names, current_gtt_transforms)
+ return transformed_graph_def, transformed_initializer_names
def _connect_to_shared_init_op(graph_def, shared_init_op_name,
@@ -296,7 +501,8 @@ def _add_pruned_saver(base_meta_graph_def, meta_graph_def, removed_op_names):
# TODO(b/63447631): Once we strip unused variables, remove references to
# them from save and restore ops. Retain those ops only if they also refer
- # to retained Variables.
+ # to retained Variables. See if we can use _clean_save_and_restore() for
+ # this.
# saver_name, restore_all = restore_op_name.rsplit('/', 1)
# if restore_all != 'restore_all':
@@ -412,7 +618,7 @@ def _get_all_protos_from_collection(meta_graph_def, collection_key):
def _is_removed(tensor_name, removed_op_names):
"""Determine whether the named tensor is an output of a removed op."""
for removed_op_name in removed_op_names:
- if tensor_name.startswith(removed_op_name):
+ if tensor_name.split(':')[0] == removed_op_name:
return True
return False
@@ -432,9 +638,17 @@ def _is_removed_mentioned(s, removed_op_names):
Returns:
True if any removed op is mentioned in the given object, False otherwise.
"""
+ # A common approach taken by some of the transforms in gtt is to add new nodes
+ # that have the same prefix as the node they are removing. For example, if
+ # the original node name was /foo, they may remove that node and add in
+ # /foo/bar. This regex ensures that we handle these two nodes
+ # as separate entities. It matches on nodes having names in the form of
+ # '/foo/bar_x' as well as nodes having names in the form of 'foo.'
+ s_names = re.findall(r'((?:[\/]?[a-zA-Z]+[0-9\_]*)*)', compat.as_str_any(s))
for removed_op_name in removed_op_names:
- if removed_op_name in compat.as_str_any(s):
- return True
+ for s_name in s_names:
+ if s_name.endswith(removed_op_name):
+ return True
return False
@@ -455,6 +669,32 @@ def _check_tensor_not_removed(tensor_name, removed_op_names):
'Expected Tensor, but it was removed: {}'.format(tensor_name))
+def _add_new_inits_to_collection(meta_graph_def, updated_initializer_names):
+ """Add new inits to collection.
+
+ Args:
+ meta_graph_def: The MetaGraphDef protocol buffer to update.
+ updated_initializer_names: Dictionary of the updated "infrastructural" nodes
+ (initializers, save and restore ops, etc.). The keys in this dictionary
+ indicate the collection where these nodes were obtained from.
+
+ Raises:
+ ValueError: if the tensor was removed.
+ """
+ # TODO(dzats): Extend this to support all collections.
+ if _ops.GraphKeys.TABLE_INITIALIZERS in updated_initializer_names:
+ orig_table_inits = _get_all_node_names_from_collection(
+ meta_graph_def, _ops.GraphKeys.TABLE_INITIALIZERS)
+ orig_table_inits = orig_table_inits if orig_table_inits else []
+ updated_table_inits = updated_initializer_names[
+ _ops.GraphKeys.TABLE_INITIALIZERS]
+ new_table_inits = list(set(updated_table_inits) - set(orig_table_inits))
+ new_table_inits.sort()
+ meta_graph_def.collection_def[
+ _ops.GraphKeys.TABLE_INITIALIZERS].node_list.value.extend(
+ new_table_inits)
+
+
def meta_graph_transform(
base_meta_graph_def, input_names, output_names, transforms, tags,
checkpoint_path=None):
@@ -478,13 +718,9 @@ def meta_graph_transform(
initializer_names = _find_all_mandatory_retain_ops(base_meta_graph_def)
- transformed_graph_def = _do_transforms(
- base_meta_graph_def.graph_def,
- input_names,
- output_names,
- initializer_names,
- transforms,
- base_meta_graph_def.saver_def,
+ transformed_graph_def, updated_initializer_names = _do_transforms(
+ base_meta_graph_def.graph_def, input_names, output_names,
+ initializer_names, transforms, base_meta_graph_def.saver_def,
checkpoint_path)
meta_graph_def.graph_def.CopyFrom(transformed_graph_def)
@@ -503,7 +739,7 @@ def meta_graph_transform(
# TODO(b/63447631): Revisit this once the problem is addressed. Currently
# _add_pruned_saver assumes that the save and restore nodes have not been
# removed but freeze_graph (correctly) removes them.
- if _FREEZE_GRAPH_TRANSFORM_NAME not in transforms:
+ if _FREEZE_GRAPH_TRANSFORM not in transforms:
_add_pruned_saver(base_meta_graph_def, meta_graph_def, removed_op_names)
# Copy collections, excluding any pruned nodes
@@ -512,6 +748,9 @@ def meta_graph_transform(
base_meta_graph_def, meta_graph_def, collection_name,
removed_op_names)
+ # Append newly added initalizers to collection.
+ _add_new_inits_to_collection(meta_graph_def, updated_initializer_names)
+
# Copy signature_defs, excluding any pruned nodes
for signature_name in base_meta_graph_def.signature_def:
_add_pruned_signature(
diff --git a/tensorflow/contrib/meta_graph_transform/meta_graph_transform_test.py b/tensorflow/contrib/meta_graph_transform/meta_graph_transform_test.py
index fdbe8eef3b..f3aec86be1 100644
--- a/tensorflow/contrib/meta_graph_transform/meta_graph_transform_test.py
+++ b/tensorflow/contrib/meta_graph_transform/meta_graph_transform_test.py
@@ -23,7 +23,9 @@ from tensorflow.contrib.meta_graph_transform import meta_graph_transform
from tensorflow.core.framework import function_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
+from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import meta_graph_pb2
+from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.client import session as tf_session
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -82,14 +84,292 @@ class MetaGraphTransformTest(test.TestCase):
self.assertEqual(expected_meta_graph_def, transformed_meta_graph_def)
+ def test_get_shared_init_op(self):
+ main_op = 'main_op'
+ legacy_op = 'legacy_op'
+
+ legacy_only = {saved_model_constants.LEGACY_INIT_OP_KEY: [legacy_op]}
+ main_and_legacy = {
+ saved_model_constants.MAIN_OP_KEY: [main_op],
+ saved_model_constants.LEGACY_INIT_OP_KEY: [legacy_op]
+ }
+ self.assertEqual(meta_graph_transform._get_shared_init_op({}), None)
+ self.assertEqual(
+ meta_graph_transform._get_shared_init_op(main_and_legacy), main_op)
+ self.assertEqual(
+ meta_graph_transform._get_shared_init_op(legacy_only), legacy_op)
+
@test.mock.patch.object(graph_transforms, 'TransformGraph')
+ def test_gtt_transforms(self, graph_transform_mock):
+ graph_def = graph_pb2.GraphDef()
+ graph_def.node.extend([node_def_pb2.NodeDef(name='z1', op='NoOp')])
+ input_names = ['i1', 'i2']
+ output_names = ['o1', 'o2']
+ init_nodes = ['init1', 'init2']
+ initializer_names = {'init': init_nodes}
+ transforms = ['t1', 't2']
+
+ expected_graph = graph_pb2.GraphDef()
+ expected_graph.node.extend([node_def_pb2.NodeDef(name='n1', op='NoOp')])
+ graph_transform_mock.return_value = expected_graph
+ transformed_graph_def = (meta_graph_transform._gtt_transforms(
+ graph_def, input_names, output_names, initializer_names, transforms))
+
+ self.assertEqual(transformed_graph_def, expected_graph)
+ graph_transform_mock.assert_called_once_with(
+ graph_def, input_names, output_names + init_nodes, transforms)
+
@test.mock.patch.object(meta_graph_transform, '_freeze_graph_with_def_protos')
- def test_freeze(self, freeze_mock, graph_transform_mock):
+ def test_freeze_transform(self, freeze_mock):
+ graph_def = graph_pb2.GraphDef()
+ graph_def.node.extend([node_def_pb2.NodeDef(name='z1', op='NoOp')])
+ output_names = ['o1', 'o2']
+ table_init_names = ['t1', 't2']
+ main_op = 'main_op'
+ legacy_op = 'legacy_op'
+ initializer_names = {
+ 'foo_init': ['init1', 'init2'],
+ ops.GraphKeys.TABLE_INITIALIZERS: table_init_names,
+ saved_model_constants.MAIN_OP_KEY: [main_op],
+ saved_model_constants.LEGACY_INIT_OP_KEY: [legacy_op]
+ }
+ expected_graph_def = graph_pb2.GraphDef()
+ graph_def.node.extend([node_def_pb2.NodeDef(name='n1', op='NoOp')])
+ freeze_mock.return_value = expected_graph_def
+ saver_def = saver_pb2.SaverDef()
+ saver_def.filename_tensor_name = 'f1'
+ checkpoint_path = '/checkpoint/path'
+ transformed_graph_def, transformed_initializer_names = (
+ meta_graph_transform._freeze_transform(graph_def, output_names,
+ initializer_names, saver_def,
+ checkpoint_path))
+ self.assertEqual(transformed_graph_def, expected_graph_def)
+ expected_initializer_names = {
+ ops.GraphKeys.TABLE_INITIALIZERS: table_init_names,
+ saved_model_constants.MAIN_OP_KEY: [main_op],
+ saved_model_constants.LEGACY_INIT_OP_KEY: [legacy_op]
+ }
+ self.assertEqual(transformed_initializer_names, expected_initializer_names)
+ freeze_mock.assert_called_once_with(graph_def, output_names,
+ table_init_names, main_op, saver_def,
+ checkpoint_path)
+
+ def test_clean_save_and_restore(self):
+ graph_def = graph_pb2.GraphDef()
+ save_name = 'save_1/SaveV2'
+ save_tensor_name = save_name + '/tensor_names'
+ save_tensor_shape = save_name + '/shape_and_slices'
+ save_op = graph_def.node.add()
+ save_op.name = save_name
+ save_op.op = 'NoOp'
+ save_name_op = graph_def.node.add()
+ save_name_op.name = save_tensor_name
+ save_name_op.op = 'NoOp'
+ save_shape_op = graph_def.node.add()
+ save_shape_op.name = save_tensor_shape
+ save_shape_op.op = 'NoOp'
+
+ types = [types_pb2.DT_INT32, types_pb2.DT_FLOAT, types_pb2.DT_INT32]
+ names = [
+ compat.as_bytes('/foo'),
+ compat.as_bytes('/bar'),
+ compat.as_bytes('/baz')
+ ]
+ shapes = [
+ compat.as_bytes('100 10 0,100:0,10'),
+ compat.as_bytes('150 11 0,150:0,11'),
+ compat.as_bytes('101 12 0,101:0,12')
+ ]
+
+ expected_types = [types[0], types[2]]
+ expected_names = [names[0], names[2]]
+ expected_shapes = [shapes[0], shapes[2]]
+
+ save_op.attr['dtypes'].list.type[:] = types
+ save_name_op.attr['value'].tensor.string_val[:] = names
+ save_name_op.attr['value'].tensor.tensor_shape.dim.add().size = len(names)
+ save_name_op.attr['_output_shapes'].list.shape.add().dim.add().size = len(
+ names)
+
+ save_shape_op.attr['value'].tensor.string_val[:] = shapes
+ save_shape_op.attr['value'].tensor.tensor_shape.dim.add().size = len(shapes)
+ save_shape_op.attr['_output_shapes'].list.shape.add().dim.add().size = len(
+ shapes)
+
+ meta_graph_transform._clean_save_and_restore(graph_def, save_op, ['/bar'])
+ self.assertEqual(save_op.attr['dtypes'].list.type[:], expected_types)
+ self.assertEqual(save_name_op.attr['value'].tensor.string_val[:],
+ expected_names)
+ self.assertEqual(save_name_op.attr['value'].tensor.tensor_shape.dim[0].size,
+ len(expected_names))
+ self.assertEqual(
+ save_name_op.attr['_output_shapes'].list.shape[0].dim[0].size,
+ len(expected_names))
+
+ self.assertEqual(save_shape_op.attr['value'].tensor.string_val[:],
+ expected_shapes)
+ self.assertEqual(
+ save_shape_op.attr['value'].tensor.tensor_shape.dim[0].size,
+ len(expected_shapes))
+ self.assertEqual(
+ save_shape_op.attr['_output_shapes'].list.shape[0].dim[0].size,
+ len(expected_shapes))
+
+ @test.mock.patch.object(meta_graph_transform, '_clean_save_and_restore')
+ @test.mock.patch.object(meta_graph_transform, '_gtt_transforms')
+ def test_sparsify_gather_transform(self, gtt_mock, clean_save_restore_mock):
+ # Initial graph def.
+ graph_def = graph_pb2.GraphDef()
+ variable_op = graph_def.node.add()
+ variable_op.name = '/foo/part_1'
+
+ constant_op = graph_def.node.add()
+ constant_op.name = '/bar'
+
+ # Transformed graph def.
+ transformed_graph_def = graph_pb2.GraphDef()
+ constant_op = transformed_graph_def.node.add()
+ constant_op.name = '/foo'
+
+ sparsify_shared_init_op_name = 'sparify_gather_init_op'
+ new_table_init_names = ['table1', 'table2']
+ init_op = transformed_graph_def.node.add()
+ init_op.name = sparsify_shared_init_op_name
+ init_op.input.extend(['^' + f for f in new_table_init_names])
+
+ saver_op = transformed_graph_def.node.add()
+ saver_op.name = 'save_1/SaveV2'
+
+ orig_table_init_names = ['orig_table_init_1', 'orig_table_init_2']
+
+ legacy_op_name = 'legacy_op'
+ legacy_op = transformed_graph_def.node.add()
+ legacy_op.name = legacy_op_name
+ legacy_op.input.extend(['^' + f for f in orig_table_init_names])
+
+ input_names = ['i1', 'i2']
+ output_names = ['o1', 'o2']
+
+ initializer_names = {
+ 'foo_init': ['init1', 'init2'],
+ ops.GraphKeys.TABLE_INITIALIZERS: orig_table_init_names,
+ saved_model_constants.LEGACY_INIT_OP_KEY: [legacy_op_name]
+ }
+ checkpoint_path = '/path/to/checkpoint'
+
+ expected_initializer_names = {
+ 'foo_init': ['init1', 'init2'],
+ ops.GraphKeys.TABLE_INITIALIZERS: (
+ orig_table_init_names + new_table_init_names),
+ saved_model_constants.LEGACY_INIT_OP_KEY: [legacy_op_name]
+ }
+
+ expected_sparsify_cmd = [
+ 'sparsify_gather(input_checkpoint="%s", group_init_node="%s")' %
+ (checkpoint_path, sparsify_shared_init_op_name)
+ ]
+
+ # Expected graph def.
+ expected_graph_def = graph_pb2.GraphDef()
+ constant_op = expected_graph_def.node.add()
+ constant_op.name = '/foo'
+
+ saver_op = expected_graph_def.node.add()
+ saver_op.name = 'save_1/SaveV2'
+
+ legacy_op_name = 'legacy_op'
+ legacy_op = expected_graph_def.node.add()
+ legacy_op.name = legacy_op_name
+ legacy_op.input.extend(
+ ['^' + f for f in orig_table_init_names + new_table_init_names])
+
+ gtt_mock.return_value = transformed_graph_def
+ graph_def_result, init_names_result = (
+ meta_graph_transform._sparsify_gather_transform(
+ graph_def, input_names, output_names, initializer_names,
+ checkpoint_path))
+
+ gtt_mock.assert_called_once_with(graph_def, input_names, output_names,
+ initializer_names, expected_sparsify_cmd)
+
+ clean_save_restore_mock.assert_called_once_with(transformed_graph_def,
+ saver_op, ['/bar', '/foo'])
+
+ self.assertEqual(expected_graph_def, graph_def_result)
+ self.assertEqual(expected_initializer_names, init_names_result)
+
+ @test.mock.patch.object(meta_graph_transform, '_gtt_transforms')
+ @test.mock.patch.object(meta_graph_transform, '_freeze_transform')
+ @test.mock.patch.object(meta_graph_transform, '_sparsify_gather_transform')
+ def test_do_transforms(self, sparsify_mock, freeze_mock, gtt_mock):
+ graph_def = graph_pb2.GraphDef()
+ constant_op = graph_def.node.add()
+ constant_op.name = 'c1'
+
+ input_names = ['i1', 'i2']
+ output_names = ['o1', 'o2']
+ initializer_names = {
+ 'foo_init': ['init1', 'init2'],
+ ops.GraphKeys.TABLE_INITIALIZERS: ['table1'],
+ saved_model_constants.LEGACY_INIT_OP_KEY: ['legacy_op']
+ }
+
+ transforms = ['foo', 'freeze_graph', 'bar', 'sparsify_gather', 'baz']
+
+ sparsify_mock.return_value = (graph_def, initializer_names)
+ freeze_mock.return_value = (graph_def, initializer_names)
+ gtt_mock.return_value = graph_def
+
+ graph_def_result, initializer_names_result = (
+ meta_graph_transform._do_transforms(graph_def, input_names,
+ output_names, initializer_names,
+ transforms))
+
+ sparsify_mock.assert_called_once_with(graph_def, input_names, output_names,
+ initializer_names, None)
+
+ freeze_mock.assert_called_once_with(graph_def, output_names,
+ initializer_names, None, None)
+
+ gtt_mock.assert_has_calls([
+ test.mock.call(graph_def, input_names, output_names, initializer_names,
+ ['foo']),
+ test.mock.call(graph_def, input_names, output_names, initializer_names,
+ ['bar']),
+ test.mock.call(graph_def, input_names, output_names, initializer_names,
+ ['baz'])
+ ])
+ self.assertEqual(graph_def_result, graph_def)
+ self.assertEqual(initializer_names, initializer_names_result)
+
+ def test_add_new_inits_to_collection(self):
+ meta_graph_def = meta_graph_pb2.MetaGraphDef()
+
+ orig_table_inits = ['t1', 't2']
+ new_table_inits = ['t3', 't4']
+
+ meta_graph_def.collection_def[
+ ops.GraphKeys.TABLE_INITIALIZERS].node_list.value.extend(
+ orig_table_inits)
+ updated_init_names = {
+ ops.GraphKeys.TABLE_INITIALIZERS: orig_table_inits + new_table_inits
+ }
+
+ meta_graph_transform._add_new_inits_to_collection(meta_graph_def,
+ updated_init_names)
+
+ self.assertEqual(meta_graph_def.collection_def[
+ ops.GraphKeys.TABLE_INITIALIZERS].node_list.value,
+ orig_table_inits + new_table_inits)
+
+ @test.mock.patch.object(graph_transforms, 'TransformGraph')
+ @test.mock.patch.object(meta_graph_transform, '_freeze_graph_with_def_protos')
+ def test_freeze_then_sparsify(self, freeze_mock, graph_transform_mock):
tag_name = 'tag'
input_nodes = 'input_nodes'
output_nodes = 'output_nodes'
freeze_transform = 'freeze_graph'
- sparsify_transform = 'sparsify_graph'
+ sparsify_transform = 'sparsify_gather'
base_meta_graph_def = meta_graph_pb2.MetaGraphDef()
@@ -141,8 +421,9 @@ class MetaGraphTransformTest(test.TestCase):
base_meta_graph_def.graph_def, [output_nodes], [table_init_name],
group_deps_name, base_meta_graph_def.saver_def, None)
graph_transform_mock.assert_called_once_with(
- transformed_graph_def, [input_nodes],
- [output_nodes, group_deps_name, table_init_name], [sparsify_transform])
+ transformed_graph_def, [input_nodes], [
+ output_nodes, group_deps_name, table_init_name
+ ], [sparsify_transform + '(group_init_node="sparify_gather_init_op")'])
def test_connect_to_shared_init_op(self):
group_deps_name = 'group_deps'
@@ -166,19 +447,20 @@ class MetaGraphTransformTest(test.TestCase):
self.assertEqual(expected_graph_def_2, orig_graph_def)
def test_add_pruned_collection_node(self):
+ # Note: This also tests _is_removed().
collection_name = 'node_collection'
base_meta_graph_def = meta_graph_pb2.MetaGraphDef()
base_meta_graph_def.collection_def[collection_name].node_list.value.extend(
- ['node1', 'node2', 'node3', 'node4'])
+ ['node1', 'node2', 'node3', 'node4', '/a/a_1', '/b/b_1'])
meta_graph_def = meta_graph_pb2.MetaGraphDef()
- removed_op_names = ['node2', 'node4', 'node5']
+ removed_op_names = ['node2', 'node4', 'node5', '/a', '/b/b_1']
meta_graph_transform._add_pruned_collection(
base_meta_graph_def, meta_graph_def, collection_name, removed_op_names)
collection = meta_graph_def.collection_def[collection_name]
- expected_nodes = ['node1', 'node3']
+ expected_nodes = ['node1', 'node3', '/a/a_1']
self.assertEqual(expected_nodes, collection.node_list.value)
def test_add_pruned_collection_int(self):
@@ -188,7 +470,7 @@ class MetaGraphTransformTest(test.TestCase):
[10, 20, 30, 40])
meta_graph_def = meta_graph_pb2.MetaGraphDef()
- removed_op_names = ['node2', 'node4', 'node5']
+ removed_op_names = ['node2', 'node4', 'node5', '/a', '/b/b_1']
meta_graph_transform._add_pruned_collection(
base_meta_graph_def, meta_graph_def, collection_name, removed_op_names)
@@ -198,36 +480,47 @@ class MetaGraphTransformTest(test.TestCase):
self.assertEqual(expected_ints, collection.int64_list.value)
def test_add_pruned_collection_proto_in_any_list(self):
+ # Note: This also tests _is_removed_mentioned().
collection_name = 'proto_collection'
base_meta_graph_def = meta_graph_pb2.MetaGraphDef()
- base_meta_graph_def.collection_def[collection_name].any_list.value.extend(
- [_make_asset_file_def_any('node1'),
- _make_asset_file_def_any('node2'),
- _make_asset_file_def_any('node3'),
- _make_asset_file_def_any('node4')])
+ base_meta_graph_def.collection_def[collection_name].any_list.value.extend([
+ _make_asset_file_def_any('node1'),
+ _make_asset_file_def_any('node2'),
+ _make_asset_file_def_any('node3'),
+ _make_asset_file_def_any('node4'),
+ _make_asset_file_def_any('/a/a_1'),
+ _make_asset_file_def_any('/b/b_1')
+ ])
meta_graph_def = meta_graph_pb2.MetaGraphDef()
- removed_op_names = ['node2', 'node4', 'node5']
+ removed_op_names = ['node2', 'node4', 'node5', '/a', '/b/b_1']
meta_graph_transform._add_pruned_collection(
base_meta_graph_def, meta_graph_def, collection_name, removed_op_names)
collection = meta_graph_def.collection_def[collection_name]
- expected_protos = [_make_asset_file_def_any('node1'),
- _make_asset_file_def_any('node3')]
+ expected_protos = [
+ _make_asset_file_def_any('node1'),
+ _make_asset_file_def_any('node3'),
+ _make_asset_file_def_any('/a/a_1'),
+ ]
self.assertEqual(expected_protos, collection.any_list.value[:])
def test_add_pruned_collection_proto_in_bytes_list(self):
+ # Note: This also tests _is_removed_mentioned().
collection_name = 'proto_collection'
base_meta_graph_def = meta_graph_pb2.MetaGraphDef()
base_meta_graph_def.collection_def[collection_name].bytes_list.value.extend(
[compat.as_bytes(compat.as_str_any(_make_asset_file_def_any('node1'))),
compat.as_bytes(compat.as_str_any(_make_asset_file_def_any('node2'))),
compat.as_bytes(compat.as_str_any(_make_asset_file_def_any('node3'))),
- compat.as_bytes(compat.as_str_any(_make_asset_file_def_any('node4')))])
+ compat.as_bytes(compat.as_str_any(_make_asset_file_def_any('node4'))),
+ compat.as_bytes(compat.as_str_any(_make_asset_file_def_any('/a/a_1'))),
+ compat.as_bytes(compat.as_str_any(_make_asset_file_def_any('/b/b_1')))
+ ])
meta_graph_def = meta_graph_pb2.MetaGraphDef()
- removed_op_names = ['node2', 'node4', 'node5']
+ removed_op_names = ['node2', 'node4', 'node5', '/a', '/b/b_1']
meta_graph_transform._add_pruned_collection(
base_meta_graph_def, meta_graph_def, collection_name, removed_op_names)
@@ -235,7 +528,9 @@ class MetaGraphTransformTest(test.TestCase):
expected_values = [
compat.as_bytes(compat.as_str_any(_make_asset_file_def_any('node1'))),
- compat.as_bytes(compat.as_str_any(_make_asset_file_def_any('node3')))]
+ compat.as_bytes(compat.as_str_any(_make_asset_file_def_any('node3'))),
+ compat.as_bytes(compat.as_str_any(_make_asset_file_def_any('/a/a_1'))),
+ ]
self.assertEqual(expected_values, collection.bytes_list.value[:])
def test_add_pruned_saver(self):