aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/meta_graph_transform/meta_graph_transform.py
blob: f37a2593e2695ec55a10f69803cc7f446b0862df (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Apply graph_transforms tool to MetaGraphDefs."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


import re as _re

from tensorflow.core.framework import graph_pb2 as _graph_pb2
from tensorflow.core.protobuf import meta_graph_pb2 as _meta_graph_pb2
from tensorflow.python.client import session as _session
from tensorflow.python.framework import graph_util as _graph_util
from tensorflow.python.framework import importer as _importer
from tensorflow.python.framework import ops as _ops
from tensorflow.python.saved_model import constants as _saved_model_constants
from tensorflow.python.training import saver as _saver_lib
from tensorflow.python.util import compat as _compat
from tensorflow.tools import graph_transforms as _graph_transforms


# Names of the two pseudo-transforms implemented by this module itself; every
# other transform name is passed straight through to the Graph Transform Tool.
_FREEZE_GRAPH_TRANSFORM = 'freeze_graph'
_SPARSIFY_GATHER_TRANSFORM = 'sparsify_gather'


def _op_name(tensor_name):
  """Get the op name from a tensor name."""
  # control dependency inputs start with ^
  if tensor_name[0] == '^':
    tensor_name = tensor_name[1:]
  if ':' in tensor_name:
    op_name, _ = tensor_name.split(':')
    return op_name
  return tensor_name


def _get_shared_init_op(initializer_names):
  """Obtain the shared init op name, if one exists.

  Args:
   initializer_names: Dictionary of the "infrastructural" nodes (initializers,
     save and restore ops, etc.). The keys in this dictionary
     indicate the collection where these nodes were obtained from.

  Returns:
    A string with the shared init op name, or None if none exists.
  """
  # The main-op collection takes precedence over the legacy init-op one.
  shared_ops = (
      initializer_names.get(_saved_model_constants.MAIN_OP_KEY) or
      initializer_names.get(_saved_model_constants.LEGACY_INIT_OP_KEY))
  return str(shared_ops[0]) if shared_ops else None


def _gtt_transforms(graph_def, input_names, output_names, initializer_names,
                    transforms):
  """Run a list of Graph Transform Tool transforms on a GraphDef.

  Args:
    graph_def: A GraphDef proto to be transformed.
    input_names: Names of input nodes.
    output_names: Names of output nodes.
    initializer_names: Dictionary of the "infrastructural" nodes (initializers,
      save and restore ops, etc.) that should be retained even if they are not
      transitively reachable from output nodes. The keys in this dictionary
      indicate the collection where these nodes were obtained from.
    transforms: A list of strings naming the graph transforms to be applied in
      order.
  Returns:
    The transformed GraphDef.
  """
  if not transforms:
    # Nothing to apply; hand back a defensive copy of the input graph.
    copied_graph_def = _graph_pb2.GraphDef()
    copied_graph_def.CopyFrom(graph_def)
    return copied_graph_def

  # Flatten the initializer-name lists and present them as extra outputs so
  # the transforms do not prune them away.
  retained_names = sorted(
      name for name_list in initializer_names.values() for name in name_list)
  return _graph_transforms.TransformGraph(
      graph_def, input_names, output_names + retained_names, transforms)


def _freeze_transform(graph_def, output_names, initializer_names, saver_def,
                      checkpoint_path):
  """Handle the freeze transform.

  Determine which initializer nodes should be retained by the freeze transform.
  Retain those nodes and return an updated dictionary containing them.

  Args:
    graph_def: A GraphDef proto to be transformed.
    output_names: Names of output nodes.
    initializer_names: Dictionary of the "infrastructural" nodes (initializers,
      save and restore ops, etc.). The keys in this dictionary
      indicate the collection where these nodes were obtained from.
    saver_def: A SaverDef proto used for restoring a checkpoint during freezing,
      if needed (default None).
    checkpoint_path:  A path to a checkpoint to restore during freezing,
      if needed (default None).

  Returns:
    A tuple containing the GraphDef and a Dict of pruned initializer nodes.
  """
  table_inits = initializer_names.get(_ops.GraphKeys.TABLE_INITIALIZERS, [])
  shared_init = _get_shared_init_op(initializer_names)

  frozen_graph_def = _freeze_graph_with_def_protos(
      graph_def, output_names, table_inits, shared_init, saver_def,
      checkpoint_path)

  # Freezing prunes every initializer and shared-init node that is not
  # explicitly maintained; rebuild the initializer dictionary to match.
  kept_initializer_names = {}
  if table_inits:
    kept_initializer_names[_ops.GraphKeys.TABLE_INITIALIZERS] = table_inits
    # The shared init op (legacy or main) survives only alongside table inits.
    for init_key in (_saved_model_constants.LEGACY_INIT_OP_KEY,
                     _saved_model_constants.MAIN_OP_KEY):
      if init_key in initializer_names:
        kept_initializer_names[init_key] = initializer_names[init_key]
  return (frozen_graph_def, kept_initializer_names)


def _clean_save_and_restore(graph_def, op, removed_op_names):
  """Clean the specified save and restore op.

  Updates the dtypes attribute of the save / restore op and the associated name
  and shape tensors to remove entries for variables that have been removed.

  Args:
    graph_def: A GraphDef proto to be transformed.
    op: The save or restore op to update.
    removed_op_names: List of op names that have been removed.
  """
  # SaveV2/RestoreV2 ops read their tensor names and shape-and-slice specs
  # from companion Const nodes named '<op>/tensor_names' and
  # '<op>/shape_and_slices'; locate those nodes so they can be pruned in sync.
  name = op.name + '/tensor_names'
  shape = op.name + '/shape_and_slices'
  name_op = _find_op(graph_def, name)
  shape_op = _find_op(graph_def, shape)
  name_op_value_tensor = name_op.attr['value'].tensor
  shape_op_value_tensor = shape_op.attr['value'].tensor
  names = []
  shapes = []
  dtypes = []
  # The name, shape spec and dtype lists are parallel (share an index); keep
  # only the entries whose variable still exists in the transformed graph.
  for index, value in enumerate(name_op_value_tensor.string_val):
    if not _is_removed(_compat.as_str(value), removed_op_names):
      names.append(value)
      shapes.append(shape_op_value_tensor.string_val[index])
      dtypes.append(op.attr['dtypes'].list.type[index])
  # Write the filtered lists back and fix up the recorded tensor lengths.
  name_op_value_tensor.string_val[:] = names
  name_op_value_tensor.tensor_shape.dim[0].size = len(names)
  shape_op_value_tensor.string_val[:] = shapes
  shape_op_value_tensor.tensor_shape.dim[0].size = len(shapes)
  op.attr['dtypes'].list.type[:] = dtypes

  # Keep the advertised '_output_shapes' of the companion Const nodes in sync
  # with the pruned value tensors, creating the attr entries if absent.
  if not name_op.attr['_output_shapes'].list.shape:
    name_op.attr['_output_shapes'].list.shape.add()
    name_op.attr['_output_shapes'].list.shape[0].dim.add()
  name_op.attr['_output_shapes'].list.shape[0].dim[0].size = len(names)

  if not shape_op.attr['_output_shapes'].list.shape:
    shape_op.attr['_output_shapes'].list.shape.add()
    shape_op.attr['_output_shapes'].list.shape[0].dim.add()
  shape_op.attr['_output_shapes'].list.shape[0].dim[0].size = len(shapes)


def _sparsify_gather_transform(graph_def, input_names, output_names,
                               initializer_names, checkpoint_path):
  """Handle the sparsify gather transform.

  Provides the transform the checkpoint and keeps track of the newly created
  initializer nodes.

  Args:
    graph_def: A GraphDef proto to be transformed.
    input_names: Names of input nodes.
    output_names: Names of output nodes.
    initializer_names: Dictionary of the "infrastructural" nodes (initializers,
      save and restore ops, etc.). The keys in this dictionary
      indicate the collection where these nodes were obtained from.
    checkpoint_path:  A path to a checkpoint.

  Returns:
    A tuple containing the GraphDef and a Dict of updated initializer nodes.
  Raises:
    ValueError: if the restore_op_name does not have the expected format.
  """
  # Ensure that sparsify_shared_init_op is unique.
  # NOTE(review): the 'sparify' spelling appears to be a typo, but it is an
  # internal node name baked into emitted graphs, so it is preserved as-is.
  sparsify_shared_init_op = 'sparify_gather_init_op'
  while _find_op(graph_def, sparsify_shared_init_op):
    sparsify_shared_init_op += '_1'

  # Pass the checkpoint to the gtt transform so it can read variable values.
  input_flag = ''
  if checkpoint_path:
    input_flag = 'input_checkpoint="%s", ' % checkpoint_path

  sparsify_cmd = [
      'sparsify_gather(%sgroup_init_node="%s")' % (input_flag,
                                                   sparsify_shared_init_op)
  ]

  # Snapshot node names so removed ops can be identified by set difference.
  starting_op_names = [node.name for node in graph_def.node]

  graph_def = _gtt_transforms(graph_def, input_names, output_names,
                              initializer_names, sparsify_cmd)
  ending_op_names = [node.name for node in graph_def.node]
  removed_op_names = list(set(starting_op_names) - set(ending_op_names))
  removed_op_names.sort()

  for op_index, op_name in enumerate(removed_op_names):
    op_name_parts = op_name.rsplit('/', 1)
    # Remove part to get the checkpoint names used by the saver.
    if len(op_name_parts) == 2 and op_name_parts[1].startswith('part_'):
      removed_op_names[op_index] = op_name_parts[0]
    else:
      removed_op_names[op_index] = op_name

  # Obtain newly created table inits from gtt sparsify transform.
  added_table_inits = []
  for index, node in enumerate(graph_def.node):
    if node.name == sparsify_shared_init_op:
      # The transform grouped its new initializers under this node as
      # control-dependency inputs ('^name'); record their bare op names.
      added_table_inits = [n.lstrip('^') for n in node.input]

      table_initializers = initializer_names.get(
          _ops.GraphKeys.TABLE_INITIALIZERS, [])
      table_initializers.extend(added_table_inits)
      initializer_names[_ops.GraphKeys.TABLE_INITIALIZERS] = table_initializers

      # The grouping node itself is replaced by links to the existing shared
      # init op below, so drop it from the graph.
      del graph_def.node[index]
      break

  # Add inits to existing shared init op.
  # NOTE(review): this assumes a shared init op exists whenever table inits
  # were added; _find_op returns None otherwise and the append below would
  # fail with AttributeError — confirm against callers.
  node = _find_op(graph_def, _get_shared_init_op(initializer_names))
  for init in added_table_inits:
    node.input.append('^' + init)

  # Update saver.
  for node in graph_def.node:
    if node.name.endswith('SaveV2'):
      _clean_save_and_restore(graph_def, node, removed_op_names)

  return (graph_def, initializer_names)


def _do_transforms(graph_def,
                   input_names,
                   output_names,
                   initializer_names,
                   transforms,
                   saver_def=None,
                   checkpoint_path=None):
  """Apply requested transforms to a GraphDef, including freezing.

  Args:
    graph_def: A GraphDef proto to be transformed.
    input_names: Names of input nodes.
    output_names: Names of output nodes.
    initializer_names: Dictionary of the "infrastructural" nodes (initializers,
      save and restore ops, etc.) that should be retained even if they are not
      transitively reachable from output nodes. The keys in this dictionary
      indicate the collection where these nodes were obtained from.
    transforms: A list of strings naming the graph transforms to be applied in
      order.  These transform names are exactly those supported by the Graph
      Transform Tool, with the addition of the 'freeze_graph' and
      'sparsify_gather' transforms.
    saver_def: A SaverDef proto used for restoring a checkpoint during freezing,
      if needed (default None).
    checkpoint_path:  A path to a checkpoint to restore during freezing,
      if needed (default None).
  Returns:
     A tuple containing the GraphDef and a Dict of updated initializer nodes.
  """
  result_graph_def = _graph_pb2.GraphDef()
  result_graph_def.CopyFrom(graph_def)
  result_initializer_names = initializer_names.copy()

  if not transforms:
    return result_graph_def, result_initializer_names

  # Plain gtt transforms are batched up and applied together; the two special
  # transforms ('freeze_graph', 'sparsify_gather') each flush the batch first.
  pending_gtt = []
  for transform in transforms:
    if transform not in (_FREEZE_GRAPH_TRANSFORM, _SPARSIFY_GATHER_TRANSFORM):
      pending_gtt.append(transform)
      continue

    result_graph_def = _gtt_transforms(result_graph_def, input_names,
                                       output_names, result_initializer_names,
                                       pending_gtt)
    pending_gtt = []
    if transform == _FREEZE_GRAPH_TRANSFORM:
      # Freezing works on op names, not tensor names.
      output_node_names = [_op_name(name) for name in output_names]
      result_graph_def, result_initializer_names = _freeze_transform(
          result_graph_def, output_node_names, result_initializer_names,
          saver_def, checkpoint_path)
    else:
      result_graph_def, result_initializer_names = (
          _sparsify_gather_transform(
              result_graph_def, input_names, output_names,
              result_initializer_names, checkpoint_path))

  # Flush any trailing plain gtt transforms.
  result_graph_def = _gtt_transforms(result_graph_def, input_names,
                                     output_names, result_initializer_names,
                                     pending_gtt)
  return result_graph_def, result_initializer_names


def _connect_to_shared_init_op(graph_def, shared_init_op_name,
                               nodes_to_connect):
  """Creates a new shared init node that is connected to via control deps.

  Args:
    graph_def: The GraphDef proto to add the shared init node to.
    shared_init_op_name: A string specifying the name of the shared init node to
      create.
    nodes_to_connect: A list of strings specifying the names of nodes to connect
      to the shared node via control dependencies.
  """
  if not nodes_to_connect:
    # No initializers to group; leave the graph untouched.
    return
  shared_init = graph_def.node.add()
  shared_init.name = shared_init_op_name
  shared_init.op = 'NoOp'
  # '^name' marks a control-dependency input.
  shared_init.input.extend(['^' + node_name for node_name in nodes_to_connect])


# forked and modified from freeze_graph.py
def _freeze_graph_with_def_protos(input_graph_def, output_node_names,
                                  initializer_names, shared_init_op_name,
                                  input_saver_def, input_checkpoint):
  """Converts all variables in a graph and checkpoint into constants.

  During this process, we need to retain certain initializer nodes (e.g. table
  initializer nodes). Instead of determining which dependencies
  of the shared initializer node (e.g. group_deps) to keep, we
  reconstruct the connections between the individual initializer nodes and
  the shared node after freezing the graph.

  Args:
    input_graph_def: A GraphDef proto to be frozen.
    output_node_names: Names of output nodes.
    initializer_names: Names of initializer nodes to keep.
    shared_init_op_name: The name of the shared initializer node to connect the
      nodes in initializer names to.
    input_saver_def: A SaverDef proto used for restoring a checkpoint.
    input_checkpoint: A path to a checkpoint to restore.

  Returns:
    A frozen GraphDef.
  """

  # Import into a fresh Graph so the restore below runs in isolation from any
  # default graph the caller may have active.
  with _ops.Graph().as_default():
    _ = _importer.import_graph_def(input_graph_def, name='')

    with _session.Session() as sess:
      # Restore variable values from the checkpoint before converting them.
      saver = _saver_lib.Saver(saver_def=input_saver_def)
      saver.restore(sess, input_checkpoint)
      # Treat the initializer nodes as additional outputs so the conversion
      # does not prune them.
      output_graph_def = _graph_util.convert_variables_to_constants(
          sess, input_graph_def, output_node_names + initializer_names)
      # Freezing drops the original shared init node; rebuild its
      # control-dependency links to the retained initializers.
      _connect_to_shared_init_op(output_graph_def, shared_init_op_name,
                                 initializer_names)
  return output_graph_def


def _find_all_mandatory_retain_ops(base_meta_graph_def):
  """Identify all infrastructural Ops, to ensure that they are retained.

  We need to retain infrastructural Ops (init and saver stuff), in addition
  to the desired outputs.

  For now we retain *all* save and restore ops, variable initializers,
  table initializers, and main init ops.
  This means that strip_unused_nodes will not remove unused variables.

  Args:
    base_meta_graph_def: a GraphDef proto in which to identify nodes to retain.

  Returns:
    A dictionary corresponding to the nodes associated with each collection
    that are to be retained.
  """
  # TODO(b/63447631): implement variable stripping.

  initializer_names = {}

  # Primary SaverDef and SAVERS collection.
  saver_defs = []
  if base_meta_graph_def.HasField('saver_def'):
    saver_defs.append(base_meta_graph_def.saver_def)
  saver_defs.extend(_get_all_protos_from_collection(
      base_meta_graph_def, _ops.GraphKeys.SAVERS))
  if saver_defs:
    saver_tensor_names = []
    for saver_def in saver_defs:
      saver_tensor_names.extend([saver_def.filename_tensor_name,
                                 saver_def.save_tensor_name,
                                 saver_def.restore_op_name])
    initializer_names[_ops.GraphKeys.SAVERS] = saver_tensor_names

  # Variable initializers.
  for var_coll in (_ops.GraphKeys.GLOBAL_VARIABLES,
                   _ops.GraphKeys.TRAINABLE_VARIABLES,
                   _ops.GraphKeys.MOVING_AVERAGE_VARIABLES,
                   _ops.GraphKeys.LOCAL_VARIABLES,
                   _ops.GraphKeys.MODEL_VARIABLES):
    var_init_names = [
        v.initializer_name
        for v in _get_all_protos_from_collection(base_meta_graph_def, var_coll)
    ]
    if var_init_names:
      # Sanity check to ensure we don't overwrite dictionary entries.
      assert var_coll not in initializer_names
      initializer_names[var_coll] = var_init_names

  # Table initializers.
  table_init_names = _get_all_node_names_from_collection(
      base_meta_graph_def, _ops.GraphKeys.TABLE_INITIALIZERS)
  if table_init_names:
    # Sanity check to ensure we don't overwrite dictionary entries.
    assert _ops.GraphKeys.TABLE_INITIALIZERS not in initializer_names
    initializer_names[_ops.GraphKeys.TABLE_INITIALIZERS] = list(
        table_init_names)

  # Various init ops, each stored as a single-element collection.
  for op_coll in (_saved_model_constants.LEGACY_INIT_OP_KEY,
                  _saved_model_constants.MAIN_OP_KEY,
                  _ops.GraphKeys.INIT_OP,
                  _ops.GraphKeys.LOCAL_INIT_OP,
                  _ops.GraphKeys.READY_OP,
                  _ops.GraphKeys.READY_FOR_LOCAL_INIT_OP):
    op_name = _get_single_node_name_from_collection(
        base_meta_graph_def, op_coll)
    if op_name:
      # Sanity check to ensure we don't overwrite dictionary entries.
      assert op_coll not in initializer_names
      initializer_names[op_coll] = [op_name]
  return initializer_names


def _add_pruned_collection(base_meta_graph_def, meta_graph_def,
                           collection_name, removed_op_names):
  """Copy collection to the transformed MetaGraphDef, omitting removed items."""

  base_collection = base_meta_graph_def.collection_def[collection_name]
  collection = meta_graph_def.collection_def[collection_name]

  if base_collection.HasField('any_list'):
    # A serialized Any proto may mention ops anywhere inside it, so just
    # string-search its serialized form for removed names.
    retained_values = [
        any_value for any_value in base_collection.any_list.value
        if not _is_removed_mentioned(any_value.value, removed_op_names)]
    for any_value in retained_values:
      collection.any_list.value.add().CopyFrom(any_value)
  elif base_collection.HasField('bytes_list'):
    collection.bytes_list.value[:] = [
        s for s in base_collection.bytes_list.value
        if not _is_removed_mentioned(s, removed_op_names)]
  elif base_collection.HasField('node_list'):
    # node_list entries are names, so an exact op-name match suffices.
    collection.node_list.value[:] = [
        s for s in base_collection.node_list.value
        if not _is_removed(s, removed_op_names)]
  else:
    collection.CopyFrom(base_collection)


def _add_pruned_saver(base_meta_graph_def, meta_graph_def, removed_op_names):
  """Copy the Saver into the transformed MetaGraphDef, if valid.

  Currently this copies the Saver as is, after verifying that none of the
  referenced Save & Restore ops were removed.  A future version will modify
  the Save and Restore ops themselves as needed to account for removed
  Variables.

  Args:
    base_meta_graph_def: The untransformed MetaGraphDef.
    meta_graph_def: The transformed MetaGraphDef being built.
    removed_op_names: An iterable of names of ops that were removed.

  Raises:
    ValueError: if any tensor referenced by the SaverDef was removed
      (propagated from _check_tensor_not_removed).
  """

  # Note this does surgery on meta_graph_def.graph_def too, so that should have
  # been copied already.
  if base_meta_graph_def.HasField('saver_def'):
    filename_tensor_name = base_meta_graph_def.saver_def.filename_tensor_name
    save_tensor_name = base_meta_graph_def.saver_def.save_tensor_name
    restore_op_name = base_meta_graph_def.saver_def.restore_op_name

    # Each of these raises ValueError if the referenced tensor was removed.
    _check_tensor_not_removed(filename_tensor_name, removed_op_names)
    _check_tensor_not_removed(save_tensor_name, removed_op_names)
    _check_tensor_not_removed(restore_op_name, removed_op_names)

    # TODO(b/63447631): Once we strip unused variables, remove references to
    # them from save and restore ops.  Retain those ops only if they also refer
    # to retained Variables. See if we can use _clean_save_and_restore() for
    # this.
    # The commented-out code below is the intended future implementation,
    # kept as scaffolding for the TODO above.

    # saver_name, restore_all = restore_op_name.rsplit('/', 1)
    # if restore_all != 'restore_all':
    #   raise ValueError(
    #       'SaverDef restore_op_name did not have expected form */restore_all')

    # save_tensor_names_op_name = '{}/SaveV2/tensor_names'.format(saver_name)
    # restore_tensor_names_op_name = (
    #     '{}/RestoreV2/tensor_names'.format(saver_name))

    # save_tensor_names_op = _find_op(meta_graph_def.graph_def,
    #                                 save_tensor_names_op_name)
    # save_tensor_names_value_tensor = save_tensor_names_op.attr['value'].tensor
    # save_tensor_names_value_tensor.string_val[:] = [
    #     s for s in save_tensor_names_value_tensor.string_val
    #     if not _is_removed(s, removed_op_names)]

    # restore_tensor_names_op = _find_op(
    #     meta_graph_def.graph_def, restore_tensor_names_op_name)
    # restore_tensor_names_value_tensor = (
    #     restore_tensor_names_op.attr['value'].tensor)
    # restore_tensor_names_value_tensor.string_val[:] = [
    #     s for s in restore_tensor_names_value_tensor.string_val
    #     if not _is_removed(s, removed_op_names)]

    # if (save_tensor_names_value_tensor.string_val
    #     or restore_tensor_names_value_tensor.string_val):
    meta_graph_def.saver_def.CopyFrom(base_meta_graph_def.saver_def)


def _find_op(graph_def, op_name):
  """Fetch a node from a GraphDef proto by name."""
  for node_def in graph_def.node:
    if node_def.name == op_name:
      return node_def
  return None


def _add_pruned_signature(base_meta_graph_def, meta_graph_def,
                          signature_name, removed_op_names):
  """Copy the named signature into the transformed MetaGraphDef, if valid.

  If any input or output mentioned in the signature was removed by the graph
  transform, the signature is silently omitted from the transformed
  MetaGraphDef.

  Args:
    base_meta_graph_def: The untransformed MetaGraphDef.
    meta_graph_def: The transformed MetaGraphDef being built.
    signature_name: The name of the signature to copy.
    removed_op_names: An iterable of names of ops that were removed.
  """
  base_signature = base_meta_graph_def.signature_def[signature_name]
  try:
    # Verify every tensor the signature mentions still exists.
    for tensor_map in (base_signature.inputs, base_signature.outputs):
      for key in tensor_map:
        _check_tensor_not_removed(tensor_map[key].name, removed_op_names)
  except ValueError:
    # The signature references a removed node; silently drop it.
    return
  meta_graph_def.signature_def[signature_name].CopyFrom(base_signature)


def _get_single_node_name_from_collection(meta_graph_def, collection_key):
  """Obtain a node name that is the single element of a collection."""
  if collection_key not in meta_graph_def.collection_def:
    return None
  collection = meta_graph_def.collection_def[collection_key]
  if not collection.node_list.value:
    raise ValueError(
        'Collection {} is present but type is not node_list.'.format(
            collection_key))
  if len(collection.node_list.value) != 1:
    raise ValueError(
        'Collection {} is has {} elements; expected exactly one.'.format(
            collection_key, collection.bytes_list))
  return collection.node_list.value[0]


def _get_all_node_names_from_collection(meta_graph_def, collection_key):
  """Obtain node names from a collection."""
  if collection_key not in meta_graph_def.collection_def:
    return None
  collection = meta_graph_def.collection_def[collection_key]
  if not collection.node_list.value:
    raise ValueError(
        'Collection {} is present but type is not node_list.'.format(
            collection_key))
  return collection.node_list.value


def _get_all_protos_from_collection(meta_graph_def, collection_key):
  """Deserialize all protos stored in a bytes_list collection.

  Returns an empty list if the collection is absent.  Raises ValueError if the
  collection exists but is not a non-empty bytes_list.
  """
  if collection_key not in meta_graph_def.collection_def:
    return []
  collection = meta_graph_def.collection_def[collection_key]
  if not collection.bytes_list.value:
    raise ValueError(
        'Collection {} is present but type is not bytes_list.'.format(
            collection_key))
  # The proto type is registered per collection key.
  proto_type = _ops.get_collection_proto_type(collection_key)

  def _parse(serialized):
    proto = proto_type()
    proto.ParseFromString(serialized)
    return proto

  return [_parse(value) for value in collection.bytes_list.value]


def _is_removed(tensor_name, removed_op_names):
  """Determine whether the named tensor is an output of a removed op."""
  for removed_op_name in removed_op_names:
    if tensor_name.split(':')[0] == removed_op_name:
      return True
  return False


def _is_removed_mentioned(s, removed_op_names):
  """Determine whether any removed op is mentioned in the given object.

  This relies on the string representation of the object.  This is used for
  proto messages that may mention ops by name in nested fields.  The string
  representation of the proto includes those field values, so this string
  search approach is sufficient.

  Args:
    s: an object to search for removed op names.
    removed_op_names: An iterable of names of ops that were removed.

  Returns:
    True if any removed op is mentioned in the given object, False otherwise.
  """
  # A common approach taken by some of the transforms in gtt is to add new nodes
  # that have the same prefix as the node they are removing. For example, if
  # the original node name was /foo, they may remove that node and add in
  # /foo/bar. This regex ensures that we handle these two nodes
  # as separate entities.  It matches on nodes having names in the form of
  # '/foo/bar_x' as well as nodes having names in the form of 'foo.'
  mentioned_names = _re.findall(r'((?:[\/]?[a-zA-Z0-9\_]*)*)',
                                _compat.as_str_any(s))
  return any(
      mentioned.endswith(removed_op_name)
      for removed_op_name in removed_op_names
      for mentioned in mentioned_names)


def _check_tensor_not_removed(tensor_name, removed_op_names):
  """Verify that the named tensor was not removed.

  Args:
    tensor_name: the name of a tensor to check; must be non-empty.
    removed_op_names: An iterable of names of ops that were removed.

  Raises:
    ValueError: if the tensor name is empty, or if the tensor's producing op
      was removed.
  """
  # Guard clause: an empty name can never be resolved.
  if not tensor_name:
    raise ValueError('Tensor name should not be empty')
  if _is_removed(tensor_name, removed_op_names):
    raise ValueError(
        'Expected Tensor, but it was removed: {}'.format(tensor_name))


def _add_new_inits_to_collection(meta_graph_def, updated_initializer_names):
  """Add newly created initializer nodes to the table-initializers collection.

  Args:
    meta_graph_def: The MetaGraphDef protocol buffer to update.
    updated_initializer_names: Dictionary of the updated "infrastructural"
      nodes (initializers, save and restore ops, etc.). The keys in this
      dictionary indicate the collection where these nodes were obtained from.
  """
  # TODO(dzats): Extend this to support all collections.
  if _ops.GraphKeys.TABLE_INITIALIZERS not in updated_initializer_names:
    return
  existing_table_inits = _get_all_node_names_from_collection(
      meta_graph_def, _ops.GraphKeys.TABLE_INITIALIZERS) or []
  updated_table_inits = updated_initializer_names[
      _ops.GraphKeys.TABLE_INITIALIZERS]
  # Append only the initializers not already present, in sorted order for
  # deterministic output.
  new_table_inits = sorted(
      set(updated_table_inits) - set(existing_table_inits))
  meta_graph_def.collection_def[
      _ops.GraphKeys.TABLE_INITIALIZERS].node_list.value.extend(
          new_table_inits)


def meta_graph_transform(
    base_meta_graph_def, input_names, output_names, transforms, tags,
    checkpoint_path=None):
  """Apply the Graph Transform tool to a MetaGraphDef.

  Args:
    base_meta_graph_def: A MetaGraphDef protocol buffer to transform.
    input_names: Names of input nodes.
    output_names: Names of output nodes.
    transforms: A list of strings naming the graph transforms to be applied in
      order.  These transform names are exactly those supported by the Graph
      Transform Tool, with the addition of the 'freeze_graph' and
      'sparsify_gather' transforms.
    tags: A list of tags with which to annotate the transformed MetaGraphDef.
    checkpoint_path: A path to a checkpoint to restore during freezing,
      if needed (default None).

  Returns:
    A new transformed MetaGraphDef protocol buffer.
  """
  meta_graph_def = _meta_graph_pb2.MetaGraphDef()

  # Infrastructural nodes (savers, initializers, ...) must survive the
  # transforms even when they are unreachable from the outputs.
  initializer_names = _find_all_mandatory_retain_ops(base_meta_graph_def)

  transformed_graph_def, updated_initializer_names = _do_transforms(
      base_meta_graph_def.graph_def, input_names, output_names,
      initializer_names, transforms, base_meta_graph_def.saver_def,
      checkpoint_path)

  meta_graph_def.graph_def.CopyFrom(transformed_graph_def)
  meta_graph_def.meta_info_def.CopyFrom(base_meta_graph_def.meta_info_def)
  # Replace the original tags with the caller-supplied set.
  meta_graph_def.meta_info_def.ClearField('tags')
  for tag in tags:
    meta_graph_def.meta_info_def.tags.append(tag)

  # Diff node names before and after to learn which ops were removed, so
  # collections/signatures referencing them can be pruned below.
  base_op_names = [_compat.as_str(node.name)
                   for node in base_meta_graph_def.graph_def.node]
  retained_op_names = [_compat.as_str(node.name)
                       for node in meta_graph_def.graph_def.node]
  removed_op_names = set(base_op_names) - set(retained_op_names)

  # Copy saver, excluding any pruned nodes if graph was not frozen.
  # TODO(b/63447631): Revisit this once the problem is addressed. Currently
  # _add_pruned_saver assumes that the save and restore nodes have not been
  # removed but freeze_graph (correctly) removes them.
  if _FREEZE_GRAPH_TRANSFORM not in transforms:
    _add_pruned_saver(base_meta_graph_def, meta_graph_def, removed_op_names)

  # Copy collections, excluding any pruned nodes
  for collection_name in base_meta_graph_def.collection_def:
    _add_pruned_collection(
        base_meta_graph_def, meta_graph_def, collection_name,
        removed_op_names)

  # Append newly added initializers to collection.
  _add_new_inits_to_collection(meta_graph_def, updated_initializer_names)

  # Copy signature_defs, excluding any pruned nodes
  for signature_name in base_meta_graph_def.signature_def:
    _add_pruned_signature(
        base_meta_graph_def, meta_graph_def, signature_name,
        removed_op_names)

  return meta_graph_def