diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-09-16 16:28:55 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-09-16 16:32:17 -0700 |
commit | 20e12d9c160ca2c5a20d238fba2d54e1e16741a5 (patch) | |
tree | 3a1f004e4c3b02fd976fb8f7e86777d6560e246b /tensorflow/contrib/meta_graph_transform | |
parent | c2749f90b08314c3ae47289ebe803a28f601ad49 (diff) |
Enabling sparsify_gather to be called before freeze.
PiperOrigin-RevId: 168971905
Diffstat (limited to 'tensorflow/contrib/meta_graph_transform')
-rw-r--r-- | tensorflow/contrib/meta_graph_transform/meta_graph_transform.py | 351 | ||||
-rw-r--r-- | tensorflow/contrib/meta_graph_transform/meta_graph_transform_test.py | 333 |
2 files changed, 609 insertions, 75 deletions
diff --git a/tensorflow/contrib/meta_graph_transform/meta_graph_transform.py b/tensorflow/contrib/meta_graph_transform/meta_graph_transform.py index e7849d49a7..72494f54f5 100644 --- a/tensorflow/contrib/meta_graph_transform/meta_graph_transform.py +++ b/tensorflow/contrib/meta_graph_transform/meta_graph_transform.py @@ -20,6 +20,8 @@ from __future__ import division from __future__ import print_function +import re + from tensorflow.core.framework import graph_pb2 as _graph_pb2 from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2 from tensorflow.python.client import session as _session @@ -32,7 +34,8 @@ from tensorflow.python.util import compat from tensorflow.tools import graph_transforms as _graph_transforms -_FREEZE_GRAPH_TRANSFORM_NAME = 'freeze_graph' +_FREEZE_GRAPH_TRANSFORM = 'freeze_graph' +_SPARSIFY_GATHER_TRANSFORM = 'sparsify_gather' def _op_name(tensor_name): @@ -46,11 +49,27 @@ def _op_name(tensor_name): return tensor_name -def _do_transforms(graph_def, input_names, output_names, initializer_names, - transforms, saver_def=None, checkpoint_path=None): - """Apply requested transforms to a GraphDef, including freezing. +def _get_shared_init_op(initializer_names): + """Obtain the shared init op name, if it exists. + + Args: + initializer_names: Dictionary of the "infrastructural" nodes (initializers, + save and restore ops, etc.). The keys in this dictionary + indicate the collection where these nodes were obtained from. + + Returns: + A string indicating the shared init op name or none if None if none exists. + """ + return_value = initializer_names.get(_saved_model_constants.MAIN_OP_KEY, None) + if not return_value: + return_value = initializer_names.get( + _saved_model_constants.LEGACY_INIT_OP_KEY, None) + return str(return_value[0]) if return_value else None + - This applies the Graph Transform Tool interleaved with graph freezing. +def _gtt_transforms(graph_def, input_names, output_names, initializer_names, + transforms): + """Pass through gtt transforms, applying them to the graph_def. Args: graph_def: A GraphDef proto to be transformed. @@ -61,12 +80,7 @@ def _do_transforms(graph_def, input_names, output_names, initializer_names, transitively reachable from output nodes. The keys in this dictionary indicate the collection where these nodes were obtained from. transforms: A list of strings naming the graph transforms to be applied in - order. These transform names are exactly those supported by the Graph - Transform Tool, with the addition of the 'freeze_graph' transform. - saver_def: A SaverDef proto used for restoring a checkpoint during freezing, - if needed (default None). - checkpoint_path: A path to a checkpoint to restore during freezing, - if needed (default None). + order. Returns: The transformed GraphDef. """ @@ -74,42 +88,233 @@ def _do_transforms(graph_def, input_names, output_names, initializer_names, transformed_graph_def = _graph_pb2.GraphDef() transformed_graph_def.CopyFrom(graph_def) return transformed_graph_def - else: - try: - freeze_index = transforms.index(_FREEZE_GRAPH_TRANSFORM_NAME) - except ValueError: - # No freeze_graph requested, so do all transforms in one go. - initializer_names_flat = sorted( - [k for l in initializer_names.values() for k in l]) - all_output_names = output_names + initializer_names_flat - return _graph_transforms.TransformGraph( - graph_def, input_names, all_output_names, transforms) - - # freeze_graph requested, possibly with transforms before and after. - phase_1_transforms = transforms[:freeze_index] - phase_2_transforms = transforms[freeze_index+1:] - - graph_def = _do_transforms( - graph_def, input_names, output_names, initializer_names, - phase_1_transforms, saver_def, checkpoint_path) - output_node_names = [_op_name(x) for x in output_names] - graph_def = _freeze_graph_with_def_protos( - graph_def, output_node_names, - initializer_names[_ops.GraphKeys.TABLE_INITIALIZERS], - initializer_names[_saved_model_constants.LEGACY_INIT_OP_KEY][0], - saver_def, checkpoint_path) - # No need for saver or checkpoint anymore - pruned_initializer_names = {} - # Freeze graph will prune all initializers and shared init nodes if table - # initializers are not present. Handle this case in future GTT transforms. - if initializer_names[_ops.GraphKeys.TABLE_INITIALIZERS]: - pruned_initializer_names[_ops.GraphKeys.TABLE_INITIALIZERS] = ( - initializer_names[_ops.GraphKeys.TABLE_INITIALIZERS]) + + initializer_names_flat = sorted( + [k for l in initializer_names.values() for k in l]) + all_output_names = output_names + initializer_names_flat + return _graph_transforms.TransformGraph(graph_def, input_names, + all_output_names, transforms) + + +def _freeze_transform(graph_def, output_names, initializer_names, saver_def, + checkpoint_path): + """Handle the freeze transform. + + Determine which initializer nodes should be retained by the freeze transform. + Retain those nodes and return an updated dictionary containing them. + + Args: + graph_def: A GraphDef proto to be transformed. + output_names: Names of output nodes. + initializer_names: Dictionary of the "infrastructural" nodes (initializers, + save and restore ops, etc.). The keys in this dictionary + indicate the collection where these nodes were obtained from. + saver_def: A SaverDef proto used for restoring a checkpoint during freezing, + if needed (default None). + checkpoint_path: A path to a checkpoint to restore during freezing, + if needed (default None). + + Returns: + A tuple containing the GraphDef and a Dict of pruned initializer nodes. + """ + table_initializers = initializer_names.get(_ops.GraphKeys.TABLE_INITIALIZERS, + []) + shared_init_op = _get_shared_init_op(initializer_names) + + graph_def = _freeze_graph_with_def_protos(graph_def, output_names, + table_initializers, shared_init_op, + saver_def, checkpoint_path) + pruned_initializer_names = {} + # Freeze graph prunes all initializers and shared init nodes that are not + # explicitly maintained. Create new initializer_names dictionary to reflect + # this. + if table_initializers: + pruned_initializer_names[_ops.GraphKeys.TABLE_INITIALIZERS] = ( + table_initializers) + if _saved_model_constants.LEGACY_INIT_OP_KEY in initializer_names: pruned_initializer_names[_saved_model_constants.LEGACY_INIT_OP_KEY] = ( initializer_names[_saved_model_constants.LEGACY_INIT_OP_KEY]) + if _saved_model_constants.MAIN_OP_KEY in initializer_names: + pruned_initializer_names[_saved_model_constants.MAIN_OP_KEY] = ( + initializer_names[_saved_model_constants.MAIN_OP_KEY]) + return (graph_def, pruned_initializer_names) + + +def _clean_save_and_restore(graph_def, op, removed_op_names): + """Clean the specified save and restore op. + + Updates the dtypes attribute of the save / restore op and the associated name + and shape tensors to remove entries for variables that have been removed. + + Args: + graph_def: A GraphDef proto to be transformed. + op: The save or restore op to update. + removed_op_names: List of op names that have been removed. + """ + name = op.name + '/tensor_names' + shape = op.name + '/shape_and_slices' + name_op = _find_op(graph_def, name) + shape_op = _find_op(graph_def, shape) + name_op_value_tensor = name_op.attr['value'].tensor + shape_op_value_tensor = shape_op.attr['value'].tensor + names = [] + shapes = [] + dtypes = [] + for index, value in enumerate(name_op_value_tensor.string_val): + if not _is_removed(compat.as_str(value), removed_op_names): + names.append(value) + shapes.append(shape_op_value_tensor.string_val[index]) + dtypes.append(op.attr['dtypes'].list.type[index]) + name_op_value_tensor.string_val[:] = names + name_op_value_tensor.tensor_shape.dim[0].size = len(names) + shape_op_value_tensor.string_val[:] = shapes + shape_op_value_tensor.tensor_shape.dim[0].size = len(shapes) + op.attr['dtypes'].list.type[:] = dtypes + + name_op.attr['_output_shapes'].list.shape[0].dim[0].size = len(names) + shape_op.attr['_output_shapes'].list.shape[0].dim[0].size = len(shapes) + + +def _sparsify_gather_transform(graph_def, input_names, output_names, + initializer_names, checkpoint_path): + """Handle the sparsify gather transform. + + Provides the transform the checkpoint and keeps track of the newly created + initializer nodes. + + Args: + graph_def: A GraphDef proto to be transformed. + input_names: Names of input nodes. + output_names: Names of output nodes. + initializer_names: Dictionary of the "infrastructural" nodes (initializers, + save and restore ops, etc.). The keys in this dictionary + indicate the collection where these nodes were obtained from. + checkpoint_path: A path to a checkpoint. + + Returns: + A tuple containing the GraphDef and a Dict of updated initializer nodes. + Raises: + ValueError: if the restore_op_name does not have the expected format. + """ + # Ensure that sparsify_shared_init_op is unique. + sparsify_shared_init_op = 'sparify_gather_init_op' + while _find_op(graph_def, sparsify_shared_init_op): + sparsify_shared_init_op += '_1' + + input_flag = '' + if checkpoint_path: + input_flag = 'input_checkpoint="%s", ' % checkpoint_path + + sparsify_cmd = [ + 'sparsify_gather(%sgroup_init_node="%s")' % (input_flag, + sparsify_shared_init_op) + ] + + starting_op_names = [node.name for node in graph_def.node] + + graph_def = _gtt_transforms(graph_def, input_names, output_names, + initializer_names, sparsify_cmd) + ending_op_names = [node.name for node in graph_def.node] + removed_op_names = list(set(starting_op_names) - set(ending_op_names)) + removed_op_names.sort() + + for op_index, op_name in enumerate(removed_op_names): + op_name_parts = op_name.rsplit('/', 1) + # Remove part to get the checkpoint names used by the saver. + if len(op_name_parts) == 2 and op_name_parts[1].startswith('part_'): + removed_op_names[op_index] = op_name_parts[0] + else: + removed_op_names[op_index] = op_name + + # Obtain newly created table inits from gtt sparsify transform. + added_table_inits = [] + for index, node in enumerate(graph_def.node): + if node.name == sparsify_shared_init_op: + added_table_inits = [n.lstrip('^') for n in node.input] + + table_initializers = initializer_names.get( + _ops.GraphKeys.TABLE_INITIALIZERS, []) + table_initializers.extend(added_table_inits) + initializer_names[_ops.GraphKeys.TABLE_INITIALIZERS] = table_initializers + + del graph_def.node[index] + break + + # Add inits to existing shared init op. + node = _find_op(graph_def, _get_shared_init_op(initializer_names)) + for init in added_table_inits: + node.input.append('^' + init) + + # Update saver. + for node in graph_def.node: + if node.name.endswith('SaveV2'): + _clean_save_and_restore(graph_def, node, removed_op_names) + + return (graph_def, initializer_names) + + +def _do_transforms(graph_def, + input_names, + output_names, + initializer_names, + transforms, + saver_def=None, + checkpoint_path=None): + """Apply requested transforms to a GraphDef, including freezing. + + Args: + graph_def: A GraphDef proto to be transformed. + input_names: Names of input nodes. + output_names: Names of output nodes. + initializer_names: Dictionary of the "infrastructural" nodes (initializers, + save and restore ops, etc.) that should be retained even if they are not + transitively reachable from output nodes. The keys in this dictionary + indicate the collection where these nodes were obtained from. + transforms: A list of strings naming the graph transforms to be applied in + order. These transform names are exactly those supported by the Graph + Transform Tool, with the addition of the 'freeze_graph' and + 'sparsify_gather' transforms. + saver_def: A SaverDef proto used for restoring a checkpoint during freezing, + if needed (default None). + checkpoint_path: A path to a checkpoint to restore during freezing, + if needed (default None). + Returns: + A tuple containing the GraphDef and a Dict of updated initializer nodes. + """ + transformed_graph_def = _graph_pb2.GraphDef() + transformed_graph_def.CopyFrom(graph_def) + transformed_initializer_names = initializer_names.copy() - return _do_transforms(graph_def, input_names, output_names, - pruned_initializer_names, phase_2_transforms) + if not transforms: + return transformed_graph_def, transformed_initializer_names + + current_gtt_transforms = [] + for t in transforms: + if t == _FREEZE_GRAPH_TRANSFORM: + transformed_graph_def = _gtt_transforms( + transformed_graph_def, input_names, output_names, + transformed_initializer_names, current_gtt_transforms) + output_node_names = [_op_name(x) for x in output_names] + transformed_graph_def, transformed_initializer_names = _freeze_transform( + transformed_graph_def, output_node_names, + transformed_initializer_names, saver_def, checkpoint_path) + current_gtt_transforms = [] + elif t == _SPARSIFY_GATHER_TRANSFORM: + transformed_graph_def = _gtt_transforms( + transformed_graph_def, input_names, output_names, + transformed_initializer_names, current_gtt_transforms) + transformed_graph_def, transformed_initializer_names = ( + _sparsify_gather_transform( + transformed_graph_def, input_names, output_names, + transformed_initializer_names, checkpoint_path)) + current_gtt_transforms = [] + else: + current_gtt_transforms.append(t) + + transformed_graph_def = _gtt_transforms( + transformed_graph_def, input_names, output_names, + transformed_initializer_names, current_gtt_transforms) + return transformed_graph_def, transformed_initializer_names def _connect_to_shared_init_op(graph_def, shared_init_op_name, @@ -296,7 +501,8 @@ def _add_pruned_saver(base_meta_graph_def, meta_graph_def, removed_op_names): # TODO(b/63447631): Once we strip unused variables, remove references to # them from save and restore ops. Retain those ops only if they also refer - # to retained Variables. + # to retained Variables. See if we can use _clean_save_and_restore() for + # this. # saver_name, restore_all = restore_op_name.rsplit('/', 1) # if restore_all != 'restore_all': @@ -412,7 +618,7 @@ def _get_all_protos_from_collection(meta_graph_def, collection_key): def _is_removed(tensor_name, removed_op_names): """Determine whether the named tensor is an output of a removed op.""" for removed_op_name in removed_op_names: - if tensor_name.startswith(removed_op_name): + if tensor_name.split(':')[0] == removed_op_name: return True return False @@ -432,9 +638,17 @@ def _is_removed_mentioned(s, removed_op_names): Returns: True if any removed op is mentioned in the given object, False otherwise. """ + # A common approach taken by some of the transforms in gtt is to add new nodes + # that have the same prefix as the node they are removing. For example, if + # the original node name was /foo, they may remove that node and add in + # /foo/bar. This regex ensures that we handle these two nodes + # as separate entities. It matches on nodes having names in the form of + # '/foo/bar_x' as well as nodes having names in the form of 'foo.' + s_names = re.findall(r'((?:[\/]?[a-zA-Z]+[0-9\_]*)*)', compat.as_str_any(s)) for removed_op_name in removed_op_names: - if removed_op_name in compat.as_str_any(s): - return True + for s_name in s_names: + if s_name.endswith(removed_op_name): + return True return False @@ -455,6 +669,32 @@ def _check_tensor_not_removed(tensor_name, removed_op_names): 'Expected Tensor, but it was removed: {}'.format(tensor_name)) +def _add_new_inits_to_collection(meta_graph_def, updated_initializer_names): + """Add new inits to collection. + + Args: + meta_graph_def: The MetaGraphDef protocol buffer to update. + updated_initializer_names: Dictionary of the updated "infrastructural" nodes + (initializers, save and restore ops, etc.). The keys in this dictionary + indicate the collection where these nodes were obtained from. + + Raises: + ValueError: if the tensor was removed. + """ + # TODO(dzats): Extend this to support all collections. + if _ops.GraphKeys.TABLE_INITIALIZERS in updated_initializer_names: + orig_table_inits = _get_all_node_names_from_collection( + meta_graph_def, _ops.GraphKeys.TABLE_INITIALIZERS) + orig_table_inits = orig_table_inits if orig_table_inits else [] + updated_table_inits = updated_initializer_names[ + _ops.GraphKeys.TABLE_INITIALIZERS] + new_table_inits = list(set(updated_table_inits) - set(orig_table_inits)) + new_table_inits.sort() + meta_graph_def.collection_def[ + _ops.GraphKeys.TABLE_INITIALIZERS].node_list.value.extend( + new_table_inits) + + def meta_graph_transform( base_meta_graph_def, input_names, output_names, transforms, tags, checkpoint_path=None): @@ -478,13 +718,9 @@ def meta_graph_transform( initializer_names = _find_all_mandatory_retain_ops(base_meta_graph_def) - transformed_graph_def = _do_transforms( - base_meta_graph_def.graph_def, - input_names, - output_names, - initializer_names, - transforms, - base_meta_graph_def.saver_def, + transformed_graph_def, updated_initializer_names = _do_transforms( + base_meta_graph_def.graph_def, input_names, output_names, + initializer_names, transforms, base_meta_graph_def.saver_def, checkpoint_path) meta_graph_def.graph_def.CopyFrom(transformed_graph_def) @@ -503,7 +739,7 @@ def meta_graph_transform( # TODO(b/63447631): Revisit this once the problem is addressed. Currently # _add_pruned_saver assumes that the save and restore nodes have not been # removed but freeze_graph (correctly) removes them. - if _FREEZE_GRAPH_TRANSFORM_NAME not in transforms: + if _FREEZE_GRAPH_TRANSFORM not in transforms: _add_pruned_saver(base_meta_graph_def, meta_graph_def, removed_op_names) # Copy collections, excluding any pruned nodes @@ -512,6 +748,9 @@ def meta_graph_transform( base_meta_graph_def, meta_graph_def, collection_name, removed_op_names) + # Append newly added initalizers to collection. + _add_new_inits_to_collection(meta_graph_def, updated_initializer_names) + # Copy signature_defs, excluding any pruned nodes for signature_name in base_meta_graph_def.signature_def: _add_pruned_signature( diff --git a/tensorflow/contrib/meta_graph_transform/meta_graph_transform_test.py b/tensorflow/contrib/meta_graph_transform/meta_graph_transform_test.py index fdbe8eef3b..f3aec86be1 100644 --- a/tensorflow/contrib/meta_graph_transform/meta_graph_transform_test.py +++ b/tensorflow/contrib/meta_graph_transform/meta_graph_transform_test.py @@ -23,7 +23,9 @@ from tensorflow.contrib.meta_graph_transform import meta_graph_transform from tensorflow.core.framework import function_pb2 from tensorflow.core.framework import graph_pb2 from tensorflow.core.framework import node_def_pb2 +from tensorflow.core.framework import types_pb2 from tensorflow.core.protobuf import meta_graph_pb2 +from tensorflow.core.protobuf import saver_pb2 from tensorflow.python.client import session as tf_session from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops @@ -82,14 +84,292 @@ class MetaGraphTransformTest(test.TestCase): self.assertEqual(expected_meta_graph_def, transformed_meta_graph_def) + def test_get_shared_init_op(self): + main_op = 'main_op' + legacy_op = 'legacy_op' + + legacy_only = {saved_model_constants.LEGACY_INIT_OP_KEY: [legacy_op]} + main_and_legacy = { + saved_model_constants.MAIN_OP_KEY: [main_op], + saved_model_constants.LEGACY_INIT_OP_KEY: [legacy_op] + } + self.assertEqual(meta_graph_transform._get_shared_init_op({}), None) + self.assertEqual( + meta_graph_transform._get_shared_init_op(main_and_legacy), main_op) + self.assertEqual( + meta_graph_transform._get_shared_init_op(legacy_only), legacy_op) + @test.mock.patch.object(graph_transforms, 'TransformGraph') + def test_gtt_transforms(self, graph_transform_mock): + graph_def = graph_pb2.GraphDef() + graph_def.node.extend([node_def_pb2.NodeDef(name='z1', op='NoOp')]) + input_names = ['i1', 'i2'] + output_names = ['o1', 'o2'] + init_nodes = ['init1', 'init2'] + initializer_names = {'init': init_nodes} + transforms = ['t1', 't2'] + + expected_graph = graph_pb2.GraphDef() + expected_graph.node.extend([node_def_pb2.NodeDef(name='n1', op='NoOp')]) + graph_transform_mock.return_value = expected_graph + transformed_graph_def = (meta_graph_transform._gtt_transforms( + graph_def, input_names, output_names, initializer_names, transforms)) + + self.assertEqual(transformed_graph_def, expected_graph) + graph_transform_mock.assert_called_once_with( + graph_def, input_names, output_names + init_nodes, transforms) + @test.mock.patch.object(meta_graph_transform, '_freeze_graph_with_def_protos') - def test_freeze(self, freeze_mock, graph_transform_mock): + def test_freeze_transform(self, freeze_mock): + graph_def = graph_pb2.GraphDef() + graph_def.node.extend([node_def_pb2.NodeDef(name='z1', op='NoOp')]) + output_names = ['o1', 'o2'] + table_init_names = ['t1', 't2'] + main_op = 'main_op' + legacy_op = 'legacy_op' + initializer_names = { + 'foo_init': ['init1', 'init2'], + ops.GraphKeys.TABLE_INITIALIZERS: table_init_names, + saved_model_constants.MAIN_OP_KEY: [main_op], + saved_model_constants.LEGACY_INIT_OP_KEY: [legacy_op] + } + expected_graph_def = graph_pb2.GraphDef() + graph_def.node.extend([node_def_pb2.NodeDef(name='n1', op='NoOp')]) + freeze_mock.return_value = expected_graph_def + saver_def = saver_pb2.SaverDef() + saver_def.filename_tensor_name = 'f1' + checkpoint_path = '/checkpoint/path' + transformed_graph_def, transformed_initializer_names = ( + meta_graph_transform._freeze_transform(graph_def, output_names, + initializer_names, saver_def, + checkpoint_path)) + self.assertEqual(transformed_graph_def, expected_graph_def) + expected_initializer_names = { + ops.GraphKeys.TABLE_INITIALIZERS: table_init_names, + saved_model_constants.MAIN_OP_KEY: [main_op], + saved_model_constants.LEGACY_INIT_OP_KEY: [legacy_op] + } + self.assertEqual(transformed_initializer_names, expected_initializer_names) + freeze_mock.assert_called_once_with(graph_def, output_names, + table_init_names, main_op, saver_def, + checkpoint_path) + + def test_clean_save_and_restore(self): + graph_def = graph_pb2.GraphDef() + save_name = 'save_1/SaveV2' + save_tensor_name = save_name + '/tensor_names' + save_tensor_shape = save_name + '/shape_and_slices' + save_op = graph_def.node.add() + save_op.name = save_name + save_op.op = 'NoOp' + save_name_op = graph_def.node.add() + save_name_op.name = save_tensor_name + save_name_op.op = 'NoOp' + save_shape_op = graph_def.node.add() + save_shape_op.name = save_tensor_shape + save_shape_op.op = 'NoOp' + + types = [types_pb2.DT_INT32, types_pb2.DT_FLOAT, types_pb2.DT_INT32] + names = [ + compat.as_bytes('/foo'), + compat.as_bytes('/bar'), + compat.as_bytes('/baz') + ] + shapes = [ + compat.as_bytes('100 10 0,100:0,10'), + compat.as_bytes('150 11 0,150:0,11'), + compat.as_bytes('101 12 0,101:0,12') + ] + + expected_types = [types[0], types[2]] + expected_names = [names[0], names[2]] + expected_shapes = [shapes[0], shapes[2]] + + save_op.attr['dtypes'].list.type[:] = types + save_name_op.attr['value'].tensor.string_val[:] = names + save_name_op.attr['value'].tensor.tensor_shape.dim.add().size = len(names) + save_name_op.attr['_output_shapes'].list.shape.add().dim.add().size = len( + names) + + save_shape_op.attr['value'].tensor.string_val[:] = shapes + save_shape_op.attr['value'].tensor.tensor_shape.dim.add().size = len(shapes) + save_shape_op.attr['_output_shapes'].list.shape.add().dim.add().size = len( + shapes) + + meta_graph_transform._clean_save_and_restore(graph_def, save_op, ['/bar']) + self.assertEqual(save_op.attr['dtypes'].list.type[:], expected_types) + self.assertEqual(save_name_op.attr['value'].tensor.string_val[:], + expected_names) + self.assertEqual(save_name_op.attr['value'].tensor.tensor_shape.dim[0].size, + len(expected_names)) + self.assertEqual( + save_name_op.attr['_output_shapes'].list.shape[0].dim[0].size, + len(expected_names)) + + self.assertEqual(save_shape_op.attr['value'].tensor.string_val[:], + expected_shapes) + self.assertEqual( + save_shape_op.attr['value'].tensor.tensor_shape.dim[0].size, + len(expected_shapes)) + self.assertEqual( + save_shape_op.attr['_output_shapes'].list.shape[0].dim[0].size, + len(expected_shapes)) + + @test.mock.patch.object(meta_graph_transform, '_clean_save_and_restore') + @test.mock.patch.object(meta_graph_transform, '_gtt_transforms') + def test_sparsify_gather_transform(self, gtt_mock, clean_save_restore_mock): + # Initial graph def. + graph_def = graph_pb2.GraphDef() + variable_op = graph_def.node.add() + variable_op.name = '/foo/part_1' + + constant_op = graph_def.node.add() + constant_op.name = '/bar' + + # Transformed graph def. + transformed_graph_def = graph_pb2.GraphDef() + constant_op = transformed_graph_def.node.add() + constant_op.name = '/foo' + + sparsify_shared_init_op_name = 'sparify_gather_init_op' + new_table_init_names = ['table1', 'table2'] + init_op = transformed_graph_def.node.add() + init_op.name = sparsify_shared_init_op_name + init_op.input.extend(['^' + f for f in new_table_init_names]) + + saver_op = transformed_graph_def.node.add() + saver_op.name = 'save_1/SaveV2' + + orig_table_init_names = ['orig_table_init_1', 'orig_table_init_2'] + + legacy_op_name = 'legacy_op' + legacy_op = transformed_graph_def.node.add() + legacy_op.name = legacy_op_name + legacy_op.input.extend(['^' + f for f in orig_table_init_names]) + + input_names = ['i1', 'i2'] + output_names = ['o1', 'o2'] + + initializer_names = { + 'foo_init': ['init1', 'init2'], + ops.GraphKeys.TABLE_INITIALIZERS: orig_table_init_names, + saved_model_constants.LEGACY_INIT_OP_KEY: [legacy_op_name] + } + checkpoint_path = '/path/to/checkpoint' + + expected_initializer_names = { + 'foo_init': ['init1', 'init2'], + ops.GraphKeys.TABLE_INITIALIZERS: ( + orig_table_init_names + new_table_init_names), + saved_model_constants.LEGACY_INIT_OP_KEY: [legacy_op_name] + } + + expected_sparsify_cmd = [ + 'sparsify_gather(input_checkpoint="%s", group_init_node="%s")' % + (checkpoint_path, sparsify_shared_init_op_name) + ] + + # Expected graph def. + expected_graph_def = graph_pb2.GraphDef() + constant_op = expected_graph_def.node.add() + constant_op.name = '/foo' + + saver_op = expected_graph_def.node.add() + saver_op.name = 'save_1/SaveV2' + + legacy_op_name = 'legacy_op' + legacy_op = expected_graph_def.node.add() + legacy_op.name = legacy_op_name + legacy_op.input.extend( + ['^' + f for f in orig_table_init_names + new_table_init_names]) + + gtt_mock.return_value = transformed_graph_def + graph_def_result, init_names_result = ( + meta_graph_transform._sparsify_gather_transform( + graph_def, input_names, output_names, initializer_names, + checkpoint_path)) + + gtt_mock.assert_called_once_with(graph_def, input_names, output_names, + initializer_names, expected_sparsify_cmd) + + clean_save_restore_mock.assert_called_once_with(transformed_graph_def, + saver_op, ['/bar', '/foo']) + + self.assertEqual(expected_graph_def, graph_def_result) + self.assertEqual(expected_initializer_names, init_names_result) + + @test.mock.patch.object(meta_graph_transform, '_gtt_transforms') + @test.mock.patch.object(meta_graph_transform, '_freeze_transform') + @test.mock.patch.object(meta_graph_transform, '_sparsify_gather_transform') + def test_do_transforms(self, sparsify_mock, freeze_mock, gtt_mock): + graph_def = graph_pb2.GraphDef() + constant_op = graph_def.node.add() + constant_op.name = 'c1' + + input_names = ['i1', 'i2'] + output_names = ['o1', 'o2'] + initializer_names = { + 'foo_init': ['init1', 'init2'], + ops.GraphKeys.TABLE_INITIALIZERS: ['table1'], + saved_model_constants.LEGACY_INIT_OP_KEY: ['legacy_op'] + } + + transforms = ['foo', 'freeze_graph', 'bar', 'sparsify_gather', 'baz'] + + sparsify_mock.return_value = (graph_def, initializer_names) + freeze_mock.return_value = (graph_def, initializer_names) + gtt_mock.return_value = graph_def + + graph_def_result, initializer_names_result = ( + meta_graph_transform._do_transforms(graph_def, input_names, + output_names, initializer_names, + transforms)) + + sparsify_mock.assert_called_once_with(graph_def, input_names, output_names, + initializer_names, None) + + freeze_mock.assert_called_once_with(graph_def, output_names, + initializer_names, None, None) + + gtt_mock.assert_has_calls([ + test.mock.call(graph_def, input_names, output_names, initializer_names, + ['foo']), + test.mock.call(graph_def, input_names, output_names, initializer_names, + ['bar']), + test.mock.call(graph_def, input_names, output_names, initializer_names, + ['baz']) + ]) + self.assertEqual(graph_def_result, graph_def) + self.assertEqual(initializer_names, initializer_names_result) + + def test_add_new_inits_to_collection(self): + meta_graph_def = meta_graph_pb2.MetaGraphDef() + + orig_table_inits = ['t1', 't2'] + new_table_inits = ['t3', 't4'] + + meta_graph_def.collection_def[ + ops.GraphKeys.TABLE_INITIALIZERS].node_list.value.extend( + orig_table_inits) + updated_init_names = { + ops.GraphKeys.TABLE_INITIALIZERS: orig_table_inits + new_table_inits + } + + meta_graph_transform._add_new_inits_to_collection(meta_graph_def, + updated_init_names) + + self.assertEqual(meta_graph_def.collection_def[ + ops.GraphKeys.TABLE_INITIALIZERS].node_list.value, + orig_table_inits + new_table_inits) + + @test.mock.patch.object(graph_transforms, 'TransformGraph') + @test.mock.patch.object(meta_graph_transform, '_freeze_graph_with_def_protos') + def test_freeze_then_sparsify(self, freeze_mock, graph_transform_mock): tag_name = 'tag' input_nodes = 'input_nodes' output_nodes = 'output_nodes' freeze_transform = 'freeze_graph' - sparsify_transform = 'sparsify_graph' + sparsify_transform = 'sparsify_gather' base_meta_graph_def = meta_graph_pb2.MetaGraphDef() @@ -141,8 +421,9 @@ class MetaGraphTransformTest(test.TestCase): base_meta_graph_def.graph_def, [output_nodes], [table_init_name], group_deps_name, base_meta_graph_def.saver_def, None) graph_transform_mock.assert_called_once_with( - transformed_graph_def, [input_nodes], - [output_nodes, group_deps_name, table_init_name], [sparsify_transform]) + transformed_graph_def, [input_nodes], [ + output_nodes, group_deps_name, table_init_name + ], [sparsify_transform + '(group_init_node="sparify_gather_init_op")']) def test_connect_to_shared_init_op(self): group_deps_name = 'group_deps' @@ -166,19 +447,20 @@ class MetaGraphTransformTest(test.TestCase): self.assertEqual(expected_graph_def_2, orig_graph_def) def test_add_pruned_collection_node(self): + # Note: This also tests _is_removed(). collection_name = 'node_collection' base_meta_graph_def = meta_graph_pb2.MetaGraphDef() base_meta_graph_def.collection_def[collection_name].node_list.value.extend( - ['node1', 'node2', 'node3', 'node4']) + ['node1', 'node2', 'node3', 'node4', '/a/a_1', '/b/b_1']) meta_graph_def = meta_graph_pb2.MetaGraphDef() - removed_op_names = ['node2', 'node4', 'node5'] + removed_op_names = ['node2', 'node4', 'node5', '/a', '/b/b_1'] meta_graph_transform._add_pruned_collection( base_meta_graph_def, meta_graph_def, collection_name, removed_op_names) collection = meta_graph_def.collection_def[collection_name] - expected_nodes = ['node1', 'node3'] + expected_nodes = ['node1', 'node3', '/a/a_1'] self.assertEqual(expected_nodes, collection.node_list.value) def test_add_pruned_collection_int(self): @@ -188,7 +470,7 @@ class MetaGraphTransformTest(test.TestCase): [10, 20, 30, 40]) meta_graph_def = meta_graph_pb2.MetaGraphDef() - removed_op_names = ['node2', 'node4', 'node5'] + removed_op_names = ['node2', 'node4', 'node5', '/a', '/b/b_1'] meta_graph_transform._add_pruned_collection( base_meta_graph_def, meta_graph_def, collection_name, removed_op_names) @@ -198,36 +480,47 @@ class MetaGraphTransformTest(test.TestCase): self.assertEqual(expected_ints, collection.int64_list.value) def test_add_pruned_collection_proto_in_any_list(self): + # Note: This also tests _is_removed_mentioned(). collection_name = 'proto_collection' base_meta_graph_def = meta_graph_pb2.MetaGraphDef() - base_meta_graph_def.collection_def[collection_name].any_list.value.extend( - [_make_asset_file_def_any('node1'), - _make_asset_file_def_any('node2'), - _make_asset_file_def_any('node3'), - _make_asset_file_def_any('node4')]) + base_meta_graph_def.collection_def[collection_name].any_list.value.extend([ + _make_asset_file_def_any('node1'), + _make_asset_file_def_any('node2'), + _make_asset_file_def_any('node3'), + _make_asset_file_def_any('node4'), + _make_asset_file_def_any('/a/a_1'), + _make_asset_file_def_any('/b/b_1') + ]) meta_graph_def = meta_graph_pb2.MetaGraphDef() - removed_op_names = ['node2', 'node4', 'node5'] + removed_op_names = ['node2', 'node4', 'node5', '/a', '/b/b_1'] meta_graph_transform._add_pruned_collection( base_meta_graph_def, meta_graph_def, collection_name, removed_op_names) collection = meta_graph_def.collection_def[collection_name] - expected_protos = [_make_asset_file_def_any('node1'), - _make_asset_file_def_any('node3')] + expected_protos = [ + _make_asset_file_def_any('node1'), + _make_asset_file_def_any('node3'), + _make_asset_file_def_any('/a/a_1'), + ] self.assertEqual(expected_protos, collection.any_list.value[:]) def test_add_pruned_collection_proto_in_bytes_list(self): + # Note: This also tests _is_removed_mentioned(). collection_name = 'proto_collection' base_meta_graph_def = meta_graph_pb2.MetaGraphDef() base_meta_graph_def.collection_def[collection_name].bytes_list.value.extend( [compat.as_bytes(compat.as_str_any(_make_asset_file_def_any('node1'))), compat.as_bytes(compat.as_str_any(_make_asset_file_def_any('node2'))), compat.as_bytes(compat.as_str_any(_make_asset_file_def_any('node3'))), - compat.as_bytes(compat.as_str_any(_make_asset_file_def_any('node4')))]) + compat.as_bytes(compat.as_str_any(_make_asset_file_def_any('node4'))), + compat.as_bytes(compat.as_str_any(_make_asset_file_def_any('/a/a_1'))), + compat.as_bytes(compat.as_str_any(_make_asset_file_def_any('/b/b_1'))) + ]) meta_graph_def = meta_graph_pb2.MetaGraphDef() - removed_op_names = ['node2', 'node4', 'node5'] + removed_op_names = ['node2', 'node4', 'node5', '/a', '/b/b_1'] meta_graph_transform._add_pruned_collection( base_meta_graph_def, meta_graph_def, collection_name, removed_op_names) @@ -235,7 +528,9 @@ class MetaGraphTransformTest(test.TestCase): expected_values = [ compat.as_bytes(compat.as_str_any(_make_asset_file_def_any('node1'))), - compat.as_bytes(compat.as_str_any(_make_asset_file_def_any('node3')))] + compat.as_bytes(compat.as_str_any(_make_asset_file_def_any('node3'))), + compat.as_bytes(compat.as_str_any(_make_asset_file_def_any('/a/a_1'))), + ] self.assertEqual(expected_values, collection.bytes_list.value[:]) def test_add_pruned_saver(self): |