aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/graph_editor/transform.py
blob: 2234400fdcbf6989a2d4a74543f380a35cecef31 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Class to transform a subgraph into another.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from copy import deepcopy
from functools import partial
from six import iteritems
from six import iterkeys
from six import string_types
from six import StringIO
from tensorflow.contrib.graph_editor import reroute
from tensorflow.contrib.graph_editor import select
from tensorflow.contrib.graph_editor import subgraph
from tensorflow.contrib.graph_editor import util
from tensorflow.python.framework import ops as tf_ops
from tensorflow.python.platform import tf_logging as logging


# Public API of this module: the default transform handlers, the
# `Transformer` class with its `TransformerInfo` result type, and the
# high-level helpers `copy`, `copy_with_input_replacements` and
# `graph_replace`.
__all__ = [
    "replace_t_with_placeholder_handler",
    "keep_t_if_possible_handler",
    "assign_renamed_collections_handler",
    "transform_op_if_inside_handler",
    "copy_op_handler",
    "Transformer",
    "TransformerInfo",
    "copy",
    "copy_with_input_replacements",
    "graph_replace",
]


def replace_t_with_placeholder_handler(info, t):
  """Substitute a freshly created placeholder for the tensor `t`.

  Typically installed as the handler for the input tensors of the subgraph
  being transformed: every such input becomes a placeholder in the
  destination graph.

  Args:
    info: the `_TmpInfo` instance of the running transform.
    t: the tensor to stand in for with a placeholder.
  Returns:
    The tensor produced by the new placeholder, which is created inside
    `info.graph_` under the destination scope `info.scope_`.
  """
  # The placeholder must live in the destination graph, so make it the
  # default graph while it is being built.
  with info.graph_.as_default():
    return util.make_placeholder_from_tensor(t, scope=info.scope_)


def keep_t_if_possible_handler(info, t):
  """Map a tensor onto itself when possible, otherwise onto a placeholder.

  When the source and destination graphs are the very same object, `t` can
  be reused unchanged. Otherwise a placeholder is created in the destination
  graph through `replace_t_with_placeholder_handler`. This handler is
  typically used for the hidden inputs of a subgraph.

  Args:
    info: the `_TmpInfo` instance of the running transform.
    t: the tensor to transform.
  Returns:
    `t` itself when `info.graph is info.graph_`, otherwise the tensor produced
    by a newly created placeholder.
  """
  if info.graph is not info.graph_:
    # Crossing graphs: `t` cannot be referenced from the destination graph.
    return replace_t_with_placeholder_handler(info, t)
  return t


def assign_renamed_collections_handler(info, elem, elem_):
  """Register `elem_` in the (possibly renamed) collections holding `elem`.

  Collection names predefined by TensorFlow (as listed in `tf.GraphKeys`)
  are reused verbatim. Any other collection name is mapped through
  `info.new_name`, i.e. moved from the source scope to the destination scope.

  Args:
    info: the `_TmpInfo` instance of the running transform.
    elem: the original element (`tf.Tensor` or `tf.Operation`).
    elem_: the transformed counterpart of `elem`.
  Raises:
    ValueError: propagated from `info.new_name` when a non-predefined
      collection name does not belong to the source scope.
  """
  predefined_keys = util.get_predefined_collection_names()
  for key, members in iteritems(info.collections):
    if elem in members:
      dst_key = key if key in predefined_keys else info.new_name(key)
      info.graph_.add_to_collection(dst_key, elem_)


def transform_op_if_inside_handler(info, op, keep_if_possible=True):
  """Transform an optional op only if it is inside the subgraph.

  This handler is typically used to handle optional ops (original ops,
  control inputs). An op inside the subgraph is mapped to its transformed
  copy. An op outside the subgraph is kept as-is when the source and
  destination graphs are the same object and `keep_if_possible` is True;
  otherwise it is dropped (None).

  Args:
    info: `_TmpInfo` instance of the running transform.
    op: the optional op to transform (or ignore).
    keep_if_possible: re-attach to the original op if possible, that is,
      if the source graph and the destination graph are the same.
  Returns:
    The transformed op, the original op itself, or None.
  Raises:
    KeyError: if `op` is inside the subgraph but has no entry in
      `info.transformed_ops` yet.
  """
  # `info.ops` is the frozenset of `info.sgv.ops` that `_TmpInfo` builds once
  # per transform. It gives an O(1) membership test instead of scanning the op
  # list, which matters because this handler runs for every control input of
  # every transformed op.
  if op in info.ops:
    return info.transformed_ops[op]
  if keep_if_possible and info.graph is info.graph_:
    return op
  return None


def copy_op_handler(info, op, copy_shape=True):
  """Copy a `tf.Operation` into the destination graph, without its inputs.

  The copy is renamed with `info.new_name` and then made unique in
  `info.graph_`. It reuses deep copies of the node def and op def of `op`
  and copies of its input/output type lists. It is created with empty input
  and control-input lists; the transformer attaches those afterwards (see
  `Transformer._connect_ops`).

  Args:
    info: `_TmpInfo` instance of the running transform.
    op: the `tf.Operation` to be copied.
    copy_shape: if True, also copy the static shape of each output tensor
      onto the corresponding output of the copy.
  Returns:
    A `(op, op_outputs)` tuple containing the transformed op and its outputs.
  Raises:
    ValueError: propagated from `info.new_name` (with `_TmpInfo`) if the name
      of `op` does not belong to the source scope.
  """
  # pylint: disable=protected-access
  # NOTE(review): this relies on private TensorFlow internals (`_node_def`,
  # `_output_types`, `_input_types`, `_op_def`, `_original_op`, `_add_op`)
  # and on the positional parameter order of `tf_ops.Operation`; confirm
  # against the TensorFlow version in use.

  # Clone the node def, so that renaming below leaves the original op intact:
  node_def_ = deepcopy(op._node_def)

  # Transform name: move it into the destination scope, then make it unique
  # in the destination graph.
  name_ = info.new_name(op.name)
  name_ = info.graph_.unique_name(name_)
  node_def_.name = name_

  # Copy the other inputs needed for initialization
  output_types_ = op._output_types[:]
  input_types_ = op._input_types[:]

  # Make a copy of the op_def too.
  # It's unique to every _type_ of Operation.
  op_def_ = deepcopy(op._op_def)

  # Initialize a new Operation instance. The two `[]` arguments are,
  # presumably, the (empty) data inputs and control inputs, and `None` the
  # original op -- assumes the TF 1.x signature
  # (node_def, g, inputs, output_types, control_inputs, input_types,
  #  original_op, op_def); TODO confirm.
  op_ = tf_ops.Operation(node_def_, info.graph_, [], output_types_,
                         [], input_types_, None, op_def_)

  # copy the shape over
  if copy_shape:
    for t, t_ in zip(op.outputs, op_.outputs):
      t_.set_shape(t.get_shape())

  # Finalize original op: map it through the configured handler; when the
  # handler returns None the copy is simply left without an original op.
  if op._original_op:
    original_op = info.transform_original_op_handler(info, op._original_op)
    if original_op is None:
      logging.debug("Could not find original op of: %s", op_.name)
    else:
      op_._original_op = original_op

  # Add op to the graph
  info.graph_._add_op(op_)

  return op_, op_.outputs


class TransformerInfo(object):
  """Contains information about the result of a transform operation."""

  def __init__(self, info):
    """Constructor.

    Args:
      info: a `_TmpInfo` instance containing various internal information
        about the transform operation.
    """
    self._graph = info.graph
    self._scope = info.scope
    self._graph_ = info.graph_
    self._scope_ = info.scope_
    self._transformed_ops = info.transformed_ops
    self._transformed_ts = info.transformed_ts

  def _get_transformed_map(self, top):
    """Return the correct container depending on the type of `top`.

    Args:
      top: a `tf.Operation`, a `tf.Tensor`, or the name of one of them.
        Following TensorFlow's naming convention (see
        `tf.Graph.as_graph_element`), a name containing ":" such as
        "scope/op:0" designates a tensor; any other name designates an
        operation.
    Returns:
      The dictionary mapping original elements of that kind to their
      transformed counterparts.
    Raises:
      TypeError: if `top` is neither a string, a `tf.Operation` nor a
        `tf.Tensor`.
    """
    # Names must be dispatched here. Otherwise they would reach the TypeError
    # below, which would make the name-based lookups of `_transformed_elem`
    # and `_original_elem` unreachable.
    if isinstance(top, string_types):
      return self._transformed_ts if ":" in top else self._transformed_ops
    if isinstance(top, tf_ops.Operation):
      return self._transformed_ops
    if isinstance(top, tf_ops.Tensor):
      return self._transformed_ts
    raise TypeError(
        "Expected a tf.Tensor or a tf.Operation, got a {}".format(
            type(top)))

  def _transformed_elem(self, original_top, missing_fn=None):
    """Return the transformed op/tensor corresponding to the original one.

    Args:
      original_top: the original tensor/operation, or its name.
      missing_fn: function handling the case where the counterpart
        cannot be found. By default, None is returned.
    Returns:
      the transformed tensor/operation (or None if no match is found).
    Raises:
      TypeError: if `original_top` is of an unsupported type.
    """
    transformed_map = self._get_transformed_map(original_top)
    if isinstance(original_top, string_types):
      for original, transformed in iteritems(transformed_map):
        if original.name == original_top:
          return transformed
    elif original_top in transformed_map:
      return transformed_map[original_top]
    return None if missing_fn is None else missing_fn(original_top)

  def _original_elem(self, transformed_top, missing_fn=None):
    """Return the original op/tensor corresponding to the transformed one.

    Args:
      transformed_top: the transformed tensor/operation, or its name.
      missing_fn: function handling the case where the counterpart
        cannot be found. By default, None is returned.
    Returns:
      the original tensor/operation (or None if no match is found).
    Raises:
      TypeError: if `transformed_top` is of an unsupported type.
    """
    transformed_map = self._get_transformed_map(transformed_top)
    if isinstance(transformed_top, string_types):
      finder = lambda transformed: transformed.name == transformed_top
    else:
      # Graph elements are matched by identity. In some TensorFlow versions
      # `tf.Tensor.__eq__` is element-wise rather than an identity test.
      finder = lambda transformed: transformed is transformed_top
    for original, transformed in iteritems(transformed_map):
      if finder(transformed):
        return original
    return None if missing_fn is None else missing_fn(transformed_top)

  def transformed(self, original, missing_fn=None):
    """Return the transformed op/tensor corresponding to the original one.

    Note that the output of this function mimics the hierarchy
    of its input argument `original`.
    Given an iterable, it returns a list. Given an operation or a tensor,
    it will return an operation or a tensor.

    Args:
      original: the original tensor/operation.
      missing_fn: function handling the case where the counterpart
        cannot be found. By default, None is returned.
    Returns:
      the transformed tensor/operation (or None if no match is found).
    """
    transformed_elem = partial(self._transformed_elem, missing_fn=missing_fn)
    return util.transform_tree(original, transformed_elem)

  def original(self, transformed, missing_fn=None):
    """Return the original op/tensor corresponding to the transformed one.

    Note that the output of this function mimics the hierarchy
    of its input argument `transformed`.
    Given an iterable, it returns a list. Given an operation or a tensor,
    it will return an operation or a tensor.

    Args:
      transformed: the transformed tensor/operation.
      missing_fn: function handling the case where the counterpart
        cannot be found. By default, None is returned.
    Returns:
      the original tensor/operation (or None if no match is found).
    """
    original_elem = partial(self._original_elem, missing_fn=missing_fn)
    return util.transform_tree(transformed, original_elem)

  def __str__(self):
    res = StringIO()
    print("Transform result info:", file=res)
    if self._graph == self._graph_:
      in_place_str = "" if self._scope_ else " IN-PLACE"
      print("  Within graph[{}]{}".format(
          id(self._graph), in_place_str), file=res)
    else:
      print("  graph[{}] => graph[{}]".format(
          id(self._graph), id(self._graph_)), file=res)
    if self._scope:
      print("  Relative to source scope: {}".format(self._scope), file=res)
    if self._scope_:
      print("  Scope destination: {}".format(self._scope_), file=res)
    print("Operations mapping:", file=res)
    for op, op_ in iteritems(self._transformed_ops):
      print("  {} => {}".format(op.name, op_.name), file=res)
    return res.getvalue()


class _TmpInfo(object):
  """Scratch state shared by the handlers during a single transform.

  A new instance is built at the start of every call to a transformer
  (`Transformer.__call__`), passed to every handler as its `info` argument,
  and dropped when that call returns.
  """

  def __init__(self, sgv, dst_graph, dst_scope, src_scope):
    # Source side of the transform.
    self.sgv = sgv
    self.sgv_inputs_set = frozenset(sgv.inputs)
    self.ops = frozenset(sgv.ops)
    self.control_outputs = util.ControlOutputs(sgv.graph)
    self.graph = sgv.graph
    self.scope = src_scope
    # Destination side of the transform.
    self.graph_ = dst_graph
    self.scope_ = dst_scope
    # Original -> transformed mappings, filled while the transform runs.
    self.transformed_ops = {}
    self.transformed_ts = {}
    # Contents of every collection of the source graph, keyed by name.
    src_graph = self.graph
    self.collections = {
        key: src_graph.get_collection(key)
        for key in src_graph.get_all_collection_keys()
    }
    self.cyclic_ops = []
    self.transform_original_op_handler = transform_op_if_inside_handler

  def new_name(self, name):
    """Map `name` from the source scope into the destination scope.

    The source-scope prefix is stripped and replaced by the destination
    scope: with scope "a/" and scope_ "b/", "a/x/y" becomes "b/x/y".

    Args:
      name: the source name.
    Returns:
      The destination name.
    Raises:
      ValueError: if `name` does not start with the source scope (which can
        never happen when the source scope is the empty string).
    """
    prefix = self.scope
    if name.startswith(prefix):
      return self.scope_ + name[len(prefix):]
    raise ValueError("{} does not belong to source scope: {}.".format(
        name, prefix))


class Transformer(object):
  """Transform a subgraph into another one.

  By default, the constructor creates a transform which copies a subgraph and
  replaces its inputs with placeholders. This behavior can be modified by
  changing the handlers.
  """

  def __init__(self):
    """Transformer constructor.

    The following members can be modified:
    transform_op_handler: handle the transformation of a `tf.Operation`.
      This handler defaults to a simple copy.
    transform_control_input_handler: handle the transformation of the
      control inputs of every transformed op. This handler defaults to
      transforming a control input only if it is inside the subgraph; a
      control input outside of it is kept if the source and destination
      graphs are the same, and dropped otherwise.
    assign_collections_handler: handle the assignment of collections.
      This handler defaults to adding each transformed element to the
      collections of its original, renaming non-predefined collection names
      into the destination scope.
    transform_external_input_handler: handle the transform of the inputs to
      the given subgraph. This handler defaults to creating placeholders
      instead of the ops just before the input tensors of the subgraph.
    transform_external_hidden_input_handler: handle the transform of the
      hidden inputs of the subgraph, that is, the inputs which are not listed
      in sgv.inputs. This handler defaults to a transform which keeps the same
      input if the source and destination graphs are the same, otherwise
      uses placeholders.
    transform_original_op_handler: handle the transform of original_op. This
      handler defaults to transforming original_op only if it is in the
      subgraph; otherwise it is kept if the source and destination graphs
      are the same, and ignored otherwise.
    """

    # handlers
    self.transform_op_handler = copy_op_handler
    self.transform_control_input_handler = transform_op_if_inside_handler
    self.assign_collections_handler = assign_renamed_collections_handler
    self.transform_external_input_handler = replace_t_with_placeholder_handler
    self.transform_external_hidden_input_handler = keep_t_if_possible_handler
    self.transform_original_op_handler = transform_op_if_inside_handler

  def __call__(self,
               sgv,
               dst_graph,
               dst_scope,
               src_scope="",
               reuse_dst_scope=False):
    """Execute the transformation.

    Args:
      sgv: the source subgraph-view.
      dst_graph: the destination graph.
      dst_scope: the destination scope.
      src_scope: the source scope, which specifies the path from which the
        relative path of the transformed nodes are computed. For instance, if
        src_scope is a/ and dst_scope is b/, then the node a/x/y will have a
        relative path of x/y and will be transformed into b/x/y.
      reuse_dst_scope: if True the dst_scope is re-used if it already exists.
        Otherwise, the scope is given a unique name based on the one given
        by appending an underscore followed by a digit (default).
    Returns:
      A tuple `(sgv, info)` where:
        `sgv` is the transformed subgraph view;
        `info` is an instance of TransformerInfo containing
        information about the transform, including mapping between
        original and transformed tensors and operations.
    Raises:
      TypeError: if `dst_graph` is not a `tf.Graph`.
      ValueError: if `transform_op_handler` returns the original op
        (in-place transformation), if a newly transformed op already has
        inputs, or (with the default handlers) if a name does not belong to
        the source scope.
    """
    sgv = subgraph.make_view(sgv)
    if not isinstance(dst_graph, tf_ops.Graph):
      raise TypeError("Expected a tf.Graph, got: {}".format(type(dst_graph)))

    src_scope = util.scope_finalize(src_scope)
    dst_scope = util.scope_finalize(dst_scope)

    # Potentially create new scope if reuse_dst_scope is False. The last
    # character is dropped before uniquifying (presumably the trailing "/"
    # added by `scope_finalize`) and then finalized again.
    if dst_scope and not reuse_dst_scope:
      dst_scope = util.scope_finalize(dst_graph.unique_name(dst_scope[:-1]))

    # Create temporary info used during this transform call
    info = _TmpInfo(sgv, dst_graph, dst_scope, src_scope)
    info.transform_original_op_handler = self.transform_original_op_handler

    # Two passes: first create every op, then wire inputs, so that an op can
    # be connected to the copy of any other op of the subgraph.
    self._copy_ops(info)
    self._connect_ops(info)

    # Compute information about the transformation
    res_info = TransformerInfo(info)
    sgv_ = self._transform_sgv(info, sgv)
    return sgv_, res_info

  def _copy_ops(self, info):
    """Copy ops without connecting them.

    Fills `info.transformed_ops` and `info.transformed_ts`, and assigns the
    collections of every copied op and of its output tensors.

    Args:
      info: Temporary information for this transform call.
    Raises:
      ValueError: if `transform_op_handler` returns the original op itself.
    """
    for op in info.sgv.ops:
      logging.debug("Copying op: %s", op.name)
      # TODO(fkp): return a subgraph?
      op_, op_outputs_ = self.transform_op_handler(info, op)
      if op is op_:
        raise ValueError("In-place transformation not allowed.")

      # Process op.
      info.transformed_ops[op] = op_
      self.assign_collections_handler(info, op, op_)

      # Process output tensors.
      for op_output, op_output_ in zip(op.outputs, op_outputs_):
        info.transformed_ts[op_output] = op_output_
        self.assign_collections_handler(info, op_output, op_output_)

  def _connect_ops(self, info):
    """Connect the previously copied ops.

    Attaches to every copied op its transformed data inputs and the control
    inputs returned (non-None) by `transform_control_input_handler`.

    Args:
      info: Temporary information for this transform call.
    Raises:
      ValueError: if a newly transformed op already has inputs.
    """
    for op in info.sgv.ops:
      logging.debug("Finalizing op: %s", op.name)
      op_ = info.transformed_ops[op]

      # pylint: disable=protected-access
      if op_.inputs:
        raise ValueError("The newly transformed op should not have "
                         "any inputs yet: {}".format(op_.name))
      inputs_ = [self._transformed_t(info, t) for t in op.inputs]
      for t in inputs_:
        op_._add_input(t)

      # Finalize control inputs; the handler returns None for those to drop.
      control_inputs_ = [self.transform_control_input_handler(info, ci)
                         for ci in op.control_inputs]
      control_inputs_ = [ci for ci in control_inputs_ if ci is not None]
      reroute.add_control_inputs(op_, control_inputs_)

  def _transform_sgv(self, info, sgv):
    """Transform a subgraph view.

    For convenience, a transform operation returns a subgraph view of the
    transformed graph. Its inputs and outputs are ordered like those of
    `sgv`; original inputs/outputs without a counterpart among the inputs or
    outputs of the new view are dropped.

    Args:
      info: Temporary information for this transform call.
      sgv: the subgraph to be transformed.
    Returns:
      The transformed subgraph.
    """
    ops_ = list(info.transformed_ops.values())
    sgv_ = subgraph.SubGraphView(ops_)

    def _remap_indices(original_ts, view_ts_, index_fn_):
      """Indices, in the new view, of the transformed `original_ts`."""
      indices_ = []
      for t in original_ts:
        if t not in info.transformed_ts:
          continue
        t_ = info.transformed_ts[t]
        if t_ not in view_ts_:
          continue
        indices_.append(index_fn_(t_))
      return indices_

    # re-order inputs and outputs
    input_map_ = _remap_indices(sgv.inputs, sgv_.inputs, sgv_.input_index)
    output_map_ = _remap_indices(sgv.outputs, sgv_.outputs,
                                 sgv_.output_index)
    return sgv_.remap(input_map_, output_map_)

  def _transformed_t(self, info, t):
    """Return the transformed tensor of `t`.

    Tensors produced inside the subgraph map to their copies. Tensors coming
    from outside are delegated to `transform_external_input_handler` when
    listed in the subgraph inputs, and to
    `transform_external_hidden_input_handler` otherwise.
    """
    if t in info.transformed_ts:
      # t is produced inside the subgraph: return its transformed copy.
      return info.transformed_ts[t]
    if t in info.sgv_inputs_set:
      # t is an input of the subgraph.
      return self.transform_external_input_handler(info, t)
    # t is a hidden input of the subgraph.
    return self.transform_external_hidden_input_handler(info, t)


def copy(sgv, dst_graph=None, dst_scope="", src_scope="",
         reuse_dst_scope=False):
  """Copy a subgraph with a default `Transformer`.

  Args:
    sgv: the source subgraph-view. It is converted to a subgraph with the
      same rules as the function subgraph.make_view.
    dst_graph: the destination graph; defaults to the graph of `sgv`.
    dst_scope: the destination scope.
    src_scope: the source scope.
    reuse_dst_scope: if True the dst_scope is re-used if it already exists.
      Otherwise, the scope is given a unique name based on the one given
      by appending an underscore followed by a digit (default).
  Returns:
    A tuple `(sgv, info)` where:
      `sgv` is the transformed subgraph view;
      `info` is an instance of TransformerInfo containing
      information about the transform, including mapping between
      original and transformed tensors and operations.
  Raises:
    TypeError: if `dst_graph` is not a `tf.Graph`.
    StandardError: whatever subgraph.make_view raises when sgv cannot be
      converted to a SubGraphView.
  """
  sgv = subgraph.make_view(sgv)
  dst_graph = sgv.graph if dst_graph is None else dst_graph
  if not isinstance(dst_graph, tf_ops.Graph):
    raise TypeError("Expected a tf.Graph, got: {}".format(type(dst_graph)))
  return Transformer()(
      sgv, dst_graph, dst_scope, src_scope, reuse_dst_scope=reuse_dst_scope)


def copy_with_input_replacements(sgv, replacement_ts,
                                 dst_graph=None, dst_scope="", src_scope="",
                                 reuse_dst_scope=False):
  """Copy a subgraph, substituting some of its input tensors.

  A substitution only applies to tensors that are inputs of the given
  subgraph (see sgv.inputs). Subgraph inputs without an entry in
  `replacement_ts` go through `keep_t_if_possible_handler`: they are reused
  as-is when copying within one graph, and replaced by placeholders
  otherwise.

  Args:
    sgv: the source subgraph-view. It is converted to a subgraph with the
      same rules as the function subgraph.make_view.
    replacement_ts: dictionary mapping original tensors to their
      replacements.
    dst_graph: the destination graph; defaults to the graph of `sgv`.
    dst_scope: the destination scope.
    src_scope: the source scope.
    reuse_dst_scope: if True the dst_scope is re-used if it already exists.
      Otherwise, the scope is given a unique name based on the one given
      by appending an underscore followed by a digit (default).
  Returns:
    A tuple `(sgv, info)` where:
      `sgv` is the transformed subgraph view;
      `info` is an instance of TransformerInfo containing
      information about the transform, including mapping between
      original and transformed tensors and operations.
  Raises:
    TypeError: if dst_graph is not a tf.Graph.
    StandardError: whatever subgraph.make_view raises when sgv cannot be
      converted to a SubGraphView.
  """
  sgv = subgraph.make_view(sgv)
  dst_graph = sgv.graph if dst_graph is None else dst_graph
  if not isinstance(dst_graph, tf_ops.Graph):
    raise TypeError("Expected a tf.Graph, got: {}".format(type(dst_graph)))

  def _replace_or_keep(info, t):
    # A user-supplied replacement wins; otherwise keep `t` when possible.
    if t not in replacement_ts:
      return keep_t_if_possible_handler(info, t)
    return replacement_ts[t]

  copier = Transformer()
  copier.transform_external_input_handler = _replace_or_keep
  return copier(
      sgv, dst_graph, dst_scope, src_scope, reuse_dst_scope=reuse_dst_scope)


def graph_replace(target_ts, replacement_ts, dst_scope="",
                  src_scope="", reuse_dst_scope=False):
  """Recompute the targets from the replaced Tensors.

  The ops lying between the tensors to replace and the targets are copied
  (into the same graph as the originals) with their inputs rewired to the
  replacements, and the copies of the targets are returned.

  Args:
    target_ts: a single tf.Tensor or an iterable of tf.Tensor.
    replacement_ts: dictionary mapping from original tensors to replaced tensors
    dst_scope: the destination scope.
    src_scope: the source scope.
    reuse_dst_scope: if True the dst_scope is re-used if it already exists.
      Otherwise, the scope is given a unique name based on the one given
      by appending an underscore followed by a digit (default).
  Returns:
    A single tf.Tensor or a list of target tf.Tensor, depending on
    the type of the input argument `target_ts`.
    The returned tensors are recomputed using the tensors from replacement_ts.
    A target with no transformed counterpart is returned unchanged.
  Raises:
    ValueError: if the targets are not connected to replacement_ts.
  """
  # Identify operations in the graph that will change.
  # Start forward walk at Tensors that will be replaced, and
  # backward walk at the target output Tensors.
  flatten_target_ts = util.flatten_tree(target_ts)
  # Construct the forward control dependencies edges so that
  # the get_walks_intersection_ops can also traverse the
  # control dependencies.
  # `check_types` is a real 1-tuple: `(tf_ops.Tensor)` would just be
  # `tf_ops.Tensor` wrapped in redundant parentheses.
  graph = util.get_unique_graph(flatten_target_ts,
                                check_types=(tf_ops.Tensor,))
  control_ios = util.ControlOutputs(graph)
  ops = select.get_walks_intersection_ops(list(iterkeys(replacement_ts)),
                                          flatten_target_ts,
                                          control_ios=control_ios)
  if not ops:
    raise ValueError("Targets and replacements are not connected!")
  # Create a copy of the relevant subgraph (dst_graph=None: same graph).
  _, info = copy_with_input_replacements(
      ops, replacement_ts, None, dst_scope, src_scope, reuse_dst_scope)
  # Return the transformed targets but keep the original if the transformed
  # counterpart cannot be found
  missing_fn = lambda original_t: original_t
  return info.transformed(target_ts, missing_fn)