author     A. Unique TensorFlower <gardener@tensorflow.org>  2017-01-26 01:56:20 -0800
committer  TensorFlower Gardener <gardener@tensorflow.org>   2017-01-26 02:08:08 -0800
commit     9a9002f03ec922bc6b7b452616f2d6df6b0d449c (patch)
tree       3f2998e76d4c6e676c5673b5da87bb2ee426132c /tensorflow/contrib/graph_editor
parent     870b7775b2538d78d9353eba53d7664021405013 (diff)
Fix transform for cyclic graph (second try). Deprecate in-place transform.
Change: 145649225
Diffstat (limited to 'tensorflow/contrib/graph_editor')
-rw-r--r--  tensorflow/contrib/graph_editor/__init__.py                2
-rw-r--r--  tensorflow/contrib/graph_editor/tests/transform_test.py   56
-rw-r--r--  tensorflow/contrib/graph_editor/transform.py             588
3 files changed, 268 insertions, 378 deletions
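
Note: with this change, op handlers follow a new contract: instead of returning a single op, they return an `(op, op_outputs)` pair, and the transformer runs in two passes (copy every op with empty inputs, then connect them). A minimal sketch of a custom handler under the new contract, modeled on the updated test below (the handler name is illustrative):

    from tensorflow.contrib import graph_editor as ge

    def custom_copy_handler(info, op):
      # copy_op_handler now returns the copied op together with its outputs.
      op_, op_outputs_ = ge.transform.copy_op_handler(info, op)
      # Returning the pair lets the transformer record the tensor mapping
      # before inputs are wired up in the second (connect) pass.
      return op_, op_outputs_

    copier = ge.Transformer()
    copier.transform_op_handler = custom_copy_handler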
diff --git a/tensorflow/contrib/graph_editor/__init__.py b/tensorflow/contrib/graph_editor/__init__.py
index c59aa2520c..6ae477c8b5 100644
--- a/tensorflow/contrib/graph_editor/__init__.py
+++ b/tensorflow/contrib/graph_editor/__init__.py
@@ -180,8 +180,8 @@ which to operate must always be given explicitly. This is the reason why
@@assign_renamed_collections_handler
@@transform_op_if_inside_handler
@@copy_op_handler
-@@transform_op_in_place
@@Transformer
+@@TransformerInfo
@@copy
@@copy_with_input_replacements
@@graph_replace
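
With `transform_op_in_place` dropped from the exports, `TransformerInfo` (previously the nested `Transformer.ResultInfo`) becomes the public name for the result object. A small sketch of the exported API (the graph contents are illustrative):

    import tensorflow as tf
    from tensorflow.contrib import graph_editor as ge

    graph = tf.Graph()
    with graph.as_default():
      a = tf.constant(1.0, shape=[2], name="a")
      b = tf.add(a, a, name="b")

    # copy() returns the transformed view plus a TransformerInfo instance.
    sgv_, info = ge.copy(ge.make_view([a.op, b.op]))
    assert isinstance(info, ge.TransformerInfo)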
diff --git a/tensorflow/contrib/graph_editor/tests/transform_test.py b/tensorflow/contrib/graph_editor/tests/transform_test.py
index 2c80c04ce6..99764c4a7e 100644
--- a/tensorflow/contrib/graph_editor/tests/transform_test.py
+++ b/tensorflow/contrib/graph_editor/tests/transform_test.py
@@ -75,7 +75,7 @@ class TransformTest(test.TestCase):
_ = math_ops.add(a, b)
sgv = ge.make_view([assert_op, eq.op, a.op, b.op])
copier = ge.Transformer()
- copied_sgv, info = copier(sgv, sgv.graph, "", "")
+ _, info = copier(sgv, sgv.graph, "", "")
new_assert_op = info.transformed(assert_op)
self.assertIsNotNone(new_assert_op)
@@ -84,18 +84,17 @@ class TransformTest(test.TestCase):
def my_transform_op_handler(info, op):
add_noise = op.name.startswith("Add")
- op_ = ge.transform.copy_op_handler(info, op)
- if add_noise:
- # add some noise to op
- with info.graph_.as_default():
- t_ = math_ops.add(constant_op.constant(
- 1.0, shape=[10], name="Noise"),
- op_.outputs[0],
- name="AddNoise")
- # return the "noisy" op
- return t_.op
- else:
- return op_
+ op_, op_outputs_ = ge.transform.copy_op_handler(info, op)
+ if not add_noise:
+ return op_, op_outputs_
+ # add some noise to op
+ with info.graph_.as_default():
+ t_ = math_ops.add(
+ constant_op.constant(1.0, shape=[10], name="Noise"),
+ op_.outputs[0],
+ name="AddNoise")
+ # return the "noisy" op
+ return op_, [t_]
transformer.transform_op_handler = my_transform_op_handler
@@ -110,37 +109,6 @@ class TransformTest(test.TestCase):
top = ge.select_ops("^AddNoise_2$", graph=graph)[0]
self.assertTrue(matcher2(top))
- def test_transform_in_place(self):
- transformer = ge.Transformer()
-
- def my_transform_op_handler_in_place(info, op):
- add_noise = op.name.startswith("Add")
- op = ge.transform.transform_op_in_place(
- info, op, detach_outputs=add_noise)
- if add_noise:
- # add some noise to op
- with info.graph_.as_default():
- t = math_ops.add(constant_op.constant(
- 1.0, shape=[10], name="Noise"),
- op.outputs[0],
- name="AddNoise")
- # return the "noisy" op
- return t.op
- else:
- return op
-
- transformer.transform_op_handler = my_transform_op_handler_in_place
-
- transformer(self.graph, self.graph, "", "")
- matcher0 = ge.matcher("AddNoise").input_ops(
- "Noise", ge.matcher("Add").input_ops("Const", "Input"))
- matcher1 = ge.matcher("AddNoise_1").input_ops(
- "Noise_1", ge.matcher("Add_1").input_ops("Const_1", matcher0))
- matcher2 = ge.matcher("AddNoise_2").input_ops(
- "Noise_2", ge.matcher("Add_2").input_ops("Const_2", matcher1))
- top = ge.select_ops("^AddNoise_2$", graph=self.graph)[0]
- self.assertTrue(matcher2(top))
-
def test_copy_with_input_replacements(self):
with self.graph.as_default():
ten = constant_op.constant(10.0, shape=[10], name="Input")
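
The two-pass copy/connect scheme is what makes the cyclic-graph fix work: every op is first created with empty inputs and only then wired up, so back edges can no longer dead-lock the graph walk. A hedged sketch of copying a graph that contains a cycle (a `tf.while_loop` creates back edges among the underlying ops; the scope name is illustrative):

    import tensorflow as tf
    from tensorflow.contrib import graph_editor as ge

    graph = tf.Graph()
    with graph.as_default():
      i = tf.constant(0, name="i")
      # tf.while_loop introduces a cycle in the underlying graph.
      r = tf.while_loop(lambda x: x < 10, lambda x: x + 1, [i])

    # Copying the whole graph now succeeds despite the cycle.
    sgv_, info = ge.copy(graph, dst_graph=tf.Graph(), dst_scope="copied")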
diff --git a/tensorflow/contrib/graph_editor/transform.py b/tensorflow/contrib/graph_editor/transform.py
index 832698b8a0..4d272e108f 100644
--- a/tensorflow/contrib/graph_editor/transform.py
+++ b/tensorflow/contrib/graph_editor/transform.py
@@ -21,12 +21,10 @@ from __future__ import print_function
from copy import deepcopy
from functools import partial
-
from six import iteritems
from six import iterkeys
from six import string_types
from six import StringIO
-from tensorflow.contrib.graph_editor import edit
from tensorflow.contrib.graph_editor import reroute
from tensorflow.contrib.graph_editor import select
from tensorflow.contrib.graph_editor import subgraph
@@ -34,14 +32,15 @@ from tensorflow.contrib.graph_editor import util
from tensorflow.python.framework import ops as tf_ops
from tensorflow.python.platform import tf_logging as logging
+
__all__ = [
"replace_t_with_placeholder_handler",
"keep_t_if_possible_handler",
"assign_renamed_collections_handler",
"transform_op_if_inside_handler",
"copy_op_handler",
- "transform_op_in_place",
"Transformer",
+ "TransformerInfo",
"copy",
"copy_with_input_replacements",
"graph_replace",
@@ -55,7 +54,7 @@ def replace_t_with_placeholder_handler(info, t):
placeholder.
Args:
- info: Transform._Info instance.
+ info: Transform._TmpInfo instance.
t: tensor whose input must be transformed into a place holder.
Returns:
The tensor generated by the newly created place holder.
@@ -73,7 +72,7 @@ def keep_t_if_possible_handler(info, t):
This handler is typically used to transform a hidden input tensor.
Args:
- info: Transform._Info instance.
+ info: Transform._TmpInfo instance.
t: tensor whose input must be transformed into a place holder.
Returns:
The tensor generated by the newly created place holder.
@@ -91,7 +90,7 @@ def assign_renamed_collections_handler(info, elem, elem_):
`tf.GraphKeys`.
Args:
- info: Transform._Info instance.
+ info: Transform._TmpInfo instance.
elem: the original element (`tf.Tensor` or `tf.Operation`)
elem_: the transformed element
"""
@@ -103,7 +102,7 @@ def assign_renamed_collections_handler(info, elem, elem_):
if name in known_collection_names:
transformed_name = name
else:
- transformed_name = info.transformer.new_name(name)
+ transformed_name = info.new_name(name)
info.graph_.add_to_collection(transformed_name, elem_)
@@ -114,18 +113,15 @@ def transform_op_if_inside_handler(info, op, keep_if_possible=True):
if they are inside the subgraph, otherwise they are just ignored.
Args:
- info: Transform._Info instance.
+ info: Transform._TmpInfo instance.
op: the optional op to transform (or ignore).
keep_if_possible: re-attach to the original op if possible, that is,
if the source graph and the destination graph are the same.
Returns:
The transformed op or None.
"""
- if op is None:
- return None
if op in info.sgv.ops:
- return info.transformer._transform_op( # pylint: disable=protected-access
- op)
+ return info.transformed_ops[op]
else:
if keep_if_possible and info.graph is info.graph_:
return op
@@ -137,36 +133,19 @@ def copy_op_handler(info, op, copy_shape=True):
"""Copy a `tf.Operation`.
Args:
- info: Transform._Info instance.
+ info: Transform._TmpInfo instance.
op: the `tf.Operation` to be copied.
copy_shape: also copy the shape of the tensor
Returns:
- A copy of op.
+ A `(op, op_outputs)` tuple containing the transformed op and its outputs.
"""
# pylint: disable=protected-access
- # Transform control inputs:
- control_inputs_ = [info.transformer.transform_control_input_handler(info, ci)
- for ci in op.control_inputs]
- control_inputs_ = [ci for ci in control_inputs_ if ci is not None]
-
- # Transform it if any:
- original_op_ = info.transformer.transform_original_op_handler(info,
- op._original_op)
-
- # Transform inputs:
- inputs_ = [info.transformer._transform_t(t) for t in op.inputs]
-
- # Leave inputs empty if a graph cycle was found.
- if None in inputs_:
- info.cyclic_ops.append(op)
- inputs_ = []
-
# Clone the node def:
node_def_ = deepcopy(op._node_def)
# Transform name:
- name_ = info.transformer.new_name(op.name)
+ name_ = info.new_name(op.name)
name_ = info.graph_.unique_name(name_)
node_def_.name = name_
@@ -179,45 +158,196 @@ def copy_op_handler(info, op, copy_shape=True):
op_def_ = deepcopy(op._op_def)
# Initialize a new Operation instance
- op_ = tf_ops.Operation(node_def_, info.graph_, inputs_, output_types_,
- control_inputs_, input_types_, original_op_, op_def_)
+ op_ = tf_ops.Operation(node_def_, info.graph_, [], output_types_,
+ [], input_types_, None, op_def_)
# copy the shape over
if copy_shape:
for t, t_ in zip(op.outputs, op_.outputs):
t_.set_shape(t.get_shape())
+ # Finalize original op.
+ if op._original_op:
+ original_op = info.transform_original_op_handler(info, op._original_op)
+ if original_op is None:
+ logging.info("Could not find original op of: %s", op_.name)
+ else:
+ op_._original_op = original_op
+
# Add op to the graph
info.graph_._add_op(op_)
- # pylint: enable=protected-access
- return op_
+ return op_, op_.outputs
-def transform_op_in_place(info, op, detach_outputs=False):
- """Transform a op in-place - experimental!
+class TransformerInfo(object):
+ """"Contains information about the result of a transform operation."""
- Transform an operation in place. It reconnects the inputs if they have been
- modified. if detach_outputs is True, the outputs of op are also detached.
+ def __init__(self, info):
+ """Constructor.
- Args:
- info: Transform._Info instance.
- op: the op to transform in place.
- detach_outputs: if True, the outputs of op are detached, ready for the user
- to add more operation.
- Returns:
- The transformed op.
+ Args:
+ info: an instance of Transformer._TmpInfo containing various internal
+ information about the transform operation.
+ """
+ self._graph = info.graph
+ self._scope = info.scope
+ self._graph_ = info.graph_
+ self._scope_ = info.scope_
+ self._transformed_ops = info.transformed_ops
+ self._transformed_ts = info.transformed_ts
+
+ def _get_transformed_map(self, top):
+ """Return the correct container depending on the type of `top`."""
+ if isinstance(top, tf_ops.Operation):
+ return self._transformed_ops
+ elif isinstance(top, tf_ops.Tensor):
+ return self._transformed_ts
+ else:
+ raise TypeError(
+ "Expected a tf.Tensor or a tf.Operation, got a {}".format(
+ type(top)))
+
+ def _transformed_elem(self, original_top, missing_fn=None):
+ """Return the transformed op/tensor corresponding to the original one.
+
+ Args:
+ original_top: the original tensor/operation.
+ missing_fn: function handling the case where the counterpart
+ cannot be found. By default, None is returned.
+ Returns:
+ the transformed tensor/operation (or None if no match is found).
+ """
+ transformed_map = self._get_transformed_map(original_top)
+ if isinstance(original_top, string_types):
+ for original, transformed in iteritems(transformed_map):
+ if original.name == original_top:
+ return transformed
+ return None if missing_fn is None else missing_fn(original_top)
+ else:
+ if original_top not in transformed_map:
+ return None if missing_fn is None else missing_fn(original_top)
+ return transformed_map[original_top]
+
+ def _original_elem(self, transformed_top, missing_fn=None):
+ """Return the original op/tensor corresponding to the transformed one.
+
+ Args:
+ transformed_top: the transformed tensor/operation.
+ missing_fn: function handling the case where the counterpart
+ cannot be found. By default, None is returned.
+ Returns:
+ the original tensor/operation (or None if no match is found).
+ """
+ transformed_map = self._get_transformed_map(transformed_top)
+ if isinstance(transformed_top, string_types):
+ finder = lambda transformed: transformed.name == transformed_top
+ else:
+ finder = lambda transformed: transformed == transformed_top
+ for original, transformed in iteritems(transformed_map):
+ if finder(transformed):
+ return original
+ return None if missing_fn is None else missing_fn(transformed_top)
+
+ def transformed(self, original, missing_fn=None):
+ """Return the transformed op/tensor corresponding to the original one.
+
+ Note that the output of this function mimics the hierarchy
+ of its input argument `original`.
+ Given an iterable, it returns a list. Given an operation or a tensor,
+ it will return an operation or a tensor.
+
+ Args:
+ original: the original tensor/operation.
+ missing_fn: function handling the case where the counterpart
+ cannot be found. By default, None is returned.
+ Returns:
+ the transformed tensor/operation (or None if no match is found).
+ """
+ transformed_elem = partial(self._transformed_elem, missing_fn=missing_fn)
+ return util.transform_tree(original, transformed_elem)
+
+ def original(self, transformed, missing_fn=None):
+ """Return the original op/tensor corresponding to the transformed one.
+
+ Note that the output of this function mimics the hierarchy
+ of its input argument `transformed`.
+ Given an iterable, it returns a list. Given an operation or a tensor,
+ it will return an operation or a tensor.
+
+ Args:
+ transformed: the transformed tensor/operation.
+ missing_fn: function handling the case where the counterpart
+ cannot be found. By default, None is returned.
+ Returns:
+ the original tensor/operation (or None if no match is found).
+ """
+ original_elem = partial(self._original_elem, missing_fn=missing_fn)
+ return util.transform_tree(transformed, original_elem)
+
+ def __str__(self):
+ res = StringIO()
+ print("Transform result info:", file=res)
+ if self._graph == self._graph_:
+ in_place_str = "" if self._scope_ else " IN-PLACE"
+ print(" Within graph[{}]{}".format(
+ id(self._graph), in_place_str), file=res)
+ else:
+ print(" graph[{}] => graph[{}]".format(
+ id(self._graph), id(self._graph_)), file=res)
+ if self._scope:
+ print(" Relative to source scope: {}".format(self._scope), file=res)
+ if self._scope_:
+ print(" Scope destination: {}".format(self._scope_), file=res)
+ print("Operations mapping:", file=res)
+ for op, op_ in iteritems(self._transformed_ops):
+ print(" {} => {}".format(op.name, op_.name), file=res)
+ return res.getvalue()
+
+
+class _TmpInfo(object):
+ """Transformer temporary data.
+
+ An instance of this class holds all the information relevant to a call
+ to a transformer instance (that is, a call to __call__). An instance
+ is created for the life-time of the __call__ function and is passed as
+ argument to the handlers.
"""
- # recursive call to the inputs:
- inputs = [info.transformer._transform_t(t) # pylint: disable=protected-access
- for t in op.inputs]
- # re-connect to the inputs if they have changed:
- if inputs != list(op.inputs):
- reroute.reroute_a2b_ts(inputs, op.inputs)
- # detach op from its consumer first ?
- if detach_outputs:
- edit.detach_outputs(op)
- return op
+
+ def __init__(self, sgv, dst_graph, dst_scope, src_scope):
+ self.sgv = sgv
+ self.sgv_inputs_set = frozenset(sgv.inputs)
+ self.ops = frozenset(sgv.ops)
+ self.control_outputs = util.ControlOutputs(sgv.graph)
+ self.graph = sgv.graph
+ self.scope = src_scope
+ self.graph_ = dst_graph
+ self.scope_ = dst_scope
+ self.transformed_ops = {}
+ self.transformed_ts = {}
+ self.collections = dict((key, self.graph.get_collection(key))
+ for key in self.graph.get_all_collection_keys())
+ self.cyclic_ops = []
+ self.transform_original_op_handler = transform_op_if_inside_handler
+
+ def new_name(self, name):
+ """Compute a destination name from a source name.
+
+ Args:
+ name: the name to be "transformed".
+ Returns:
+ The transformed name.
+ Raises:
+ ValueError: if the source scope is used (that is, not an empty string)
+ and the source name does not belong to the source scope.
+ """
+ scope = self.scope
+ if not name.startswith(scope):
+ raise ValueError("{} does not belong to source scope: {}.".format(
+ name, scope))
+ rel_name = name[len(scope):]
+ name_ = self.scope_ + rel_name
+ return name_
class Transformer(object):
@@ -228,155 +358,6 @@ class Transformer(object):
the handlers.
"""
- class _Info(object):
- """Transformer temporary data.
-
- An instance of this class holds all the information relevant to a call
- to a transformer instance (that is, a call to __call__). An instance
- is created for the life-time of the __call__ function and is passed as
- argument to the handlers.
- """
-
- def __init__(self, transformer, sgv, dst_graph, dst_scope, src_scope):
- self.transformer = transformer
- self.sgv = sgv
- self.sgv_inputs_set = frozenset(sgv.inputs)
- self.ops = frozenset(sgv.ops)
- self.control_outputs = util.ControlOutputs(sgv.graph)
- self.graph = sgv.graph
- self.scope = src_scope
- self.graph_ = dst_graph
- self.scope_ = dst_scope
- self.transformed_ops = {}
- self.transformed_ts = {}
- self.collections = dict((key, self.graph.get_collection(key))
- for key in self.graph.get_all_collection_keys())
- self.cyclic_ops = []
-
- class ResultInfo(object):
- """"Contains information about the result of a transform operation."""
-
- def __init__(self, info):
- """Constructor.
-
- Args:
- info: an instance of Transformer._Info containing various internal
- information about the transform operation.
- """
- self._graph = info.graph
- self._scope = info.scope
- self._graph_ = info.graph_
- self._scope_ = info.scope_
- self._transformed_ops = info.transformed_ops
- self._transformed_ts = info.transformed_ts
-
- def _get_transformed_map(self, top):
- """Return the correct container depending on the type of `top`."""
- if isinstance(top, tf_ops.Operation):
- return self._transformed_ops
- elif isinstance(top, tf_ops.Tensor):
- return self._transformed_ts
- else:
- raise TypeError(
- "Expected a tf.Tensor or a tf.Operation, got a {}".format(
- type(top)))
-
- def _transformed_elem(self, original_top, missing_fn=None):
- """Return the transformed op/tensor corresponding to the original one.
-
- Args:
- original_top: the original tensor/operation.
- missing_fn: function handling the case where the counterpart
- cannot be found. By default, None is returned.
- Returns:
- the transformed tensor/operation (or None if no match is found).
- """
- transformed_map = self._get_transformed_map(original_top)
- if isinstance(original_top, string_types):
- for original, transformed in iteritems(transformed_map):
- if original.name == original_top:
- return transformed
- return None if missing_fn is None else missing_fn(original_top)
- else:
- if original_top not in transformed_map:
- return None if missing_fn is None else missing_fn(original_top)
- return transformed_map[original_top]
-
- def _original_elem(self, transformed_top, missing_fn=None):
- """Return the original op/tensor corresponding to the transformed one.
-
- Args:
- transformed_top: the transformed tensor/operation.
- missing_fn: function handling the case where the counterpart
- cannot be found. By default, None is returned.
- Returns:
- the original tensor/operation (or None if no match is found).
- """
- transformed_map = self._get_transformed_map(transformed_top)
- if isinstance(transformed_top, string_types):
- finder = lambda transformed: transformed.name == transformed_top
- else:
- finder = lambda transformed: transformed == transformed_top
- for original, transformed in iteritems(transformed_map):
- if finder(transformed):
- return original
- return None if missing_fn is None else missing_fn(transformed_top)
-
- def transformed(self, original, missing_fn=None):
- """Return the transformed op/tensor corresponding to the original one.
-
- Note that the output of this function mimics the hierarchy
- of its input argument `original`.
- Given an iterable, it returns a list. Given an operation or a tensor,
- it will return an operation or a tensor.
-
- Args:
- original: the original tensor/operation.
- missing_fn: function handling the case where the counterpart
- cannot be found. By default, None is returned.
- Returns:
- the transformed tensor/operation (or None if no match is found).
- """
- transformed_elem = partial(self._transformed_elem, missing_fn=missing_fn)
- return util.transform_tree(original, transformed_elem)
-
- def original(self, transformed, missing_fn=None):
- """Return the original op/tensor corresponding to the transformed one.
-
- Note that the output of this function mimics the hierarchy
- of its input argument `transformed`.
- Given an iterable, it returns a list. Given an operation or a tensor,
- it will return an operation or a tensor.
-
- Args:
- transformed: the transformed tensor/operation.
- missing_fn: function handling the case where the counterpart
- cannot be found. By default, None is returned.
- Returns:
- the original tensor/operation (or None if no match is found).
- """
- original_elem = partial(self._original_elem, missing_fn=missing_fn)
- return util.transform_tree(transformed, original_elem)
-
- def __str__(self):
- res = StringIO()
- print("Transform result info:", file=res)
- if self._graph == self._graph_:
- in_place_str = "" if self._scope_ else " IN-PLACE"
- print(" Within graph[{}]{}".format(
- id(self._graph), in_place_str), file=res)
- else:
- print(" graph[{}] => graph[{}]".format(
- id(self._graph), id(self._graph_)), file=res)
- if self._scope:
- print(" Relative to source scope: {}".format(self._scope), file=res)
- if self._scope_:
- print(" Scope destination: {}".format(self._scope_), file=res)
- print("Operations mapping:", file=res)
- for op, op_ in iteritems(self._transformed_ops):
- print(" {} => {}".format(op.name, op_.name), file=res)
- return res.getvalue()
-
def __init__(self):
"""Transformer constructor.
@@ -407,9 +388,6 @@ class Transformer(object):
self.transform_external_hidden_input_handler = keep_t_if_possible_handler
self.transform_original_op_handler = transform_op_if_inside_handler
- # temporary per-call variable
- self._info = None
-
def __call__(self,
sgv,
dst_graph,
@@ -432,7 +410,7 @@ class Transformer(object):
Returns:
A tuple `(sgv, info)` where:
`sgv` is the transformed subgraph view;
- `info` is an instance of Transformer.ResultInfo containing
+ `info` is an instance of TransformerInfo containing
information about the transform, including mapping between
original and transformed tensors and operations.
Raises:
@@ -450,49 +428,68 @@ class Transformer(object):
dst_scope = util.scope_finalize(dst_graph.unique_name(dst_scope[:-1]))
# Create temporary info used during this transform call
- self._info = Transformer._Info(self, sgv, dst_graph, dst_scope, src_scope)
-
- # Transform the graph starting from the output tensors.
- for output_t in self._info.sgv.outputs:
- self._transform_t(output_t)
-
- # Some ops might have been missed by the previous walk, namely, the roots
- # without any outputs. So the walk is now finalized from those roots.
- remaining_ops = [op for op in self._info.sgv.ops
- if op not in self._info.transformed_ops]
- remaining_roots = [op for op in remaining_ops if not op.outputs]
- for op in remaining_roots:
- self._transform_op(op)
-
- # Finalize cyclic ops:
- for op in self._info.cyclic_ops:
- logging.debug("Finalizing cyclic op: %s", op.name)
- op_ = self._info.transformed_ops[op]
- inputs_ = [self._info.transformed_ts[t] for t in op.inputs]
- if None in inputs_:
- raise ValueError("Could not find all the inputs of cyclic op: {}"
- .format(op_.name))
- for input_id, t_ in enumerate(inputs_):
- op_._update_input(input_id, t_) # pylint: disable=protected-access
-
- sgv_ = self._transform_sgv(sgv)
-
- res_info = Transformer.ResultInfo(self._info)
- self._info = None
+ info = _TmpInfo(sgv, dst_graph, dst_scope, src_scope)
+ info.transform_original_op_handler = self.transform_original_op_handler
+
+ self._copy_ops(info)
+ self._connect_ops(info)
+
+ # Compute information about the transformation
+ res_info = TransformerInfo(info)
+ sgv_ = self._transform_sgv(info, sgv)
return sgv_, res_info
- def _transform_sgv(self, sgv):
+ def _copy_ops(self, info):
+ """Copy ops without connecting them."""
+ for op in info.sgv.ops:
+ logging.info("Copying op: %s", op.name)
+ # TODO(fkp): return a subgraph?
+ op_, op_outputs_ = self.transform_op_handler(info, op)
+ if op is op_:
+ raise ValueError("In-place tranformation not allowed.")
+
+ # Process op.
+ info.transformed_ops[op] = op_
+ self.assign_collections_handler(info, op, op_)
+
+ # Process output tensors.
+ for op_output, op_output_ in zip(op.outputs, op_outputs_):
+ info.transformed_ts[op_output] = op_output_
+ self.assign_collections_handler(info, op_output, op_output_)
+
+ def _connect_ops(self, info):
+ """Connect the previously copied ops."""
+ for op in info.sgv.ops:
+ logging.info("Finalizing op: %s", op.name)
+ op_ = info.transformed_ops[op]
+
+ # pylint: disable=protected-access
+ if op_.inputs:
+ raise ValueError("The newly transformed op should not have "
+ "any inputs yet: {}".format(op_.name))
+ inputs_ = [self._transformed_t(info, t) for t in op.inputs]
+ for t in inputs_:
+ op_._add_input(t)
+
+ # Finalize control inputs:
+ control_inputs_ = [self.transform_control_input_handler(info, ci)
+ for ci in op.control_inputs]
+ control_inputs_ = [ci for ci in control_inputs_ if ci is not None]
+ reroute.add_control_inputs(op_, control_inputs_)
+
+ def _transform_sgv(self, info, sgv):
"""Transform a subgraph view.
For convenience, a transform operation returns a subgraph view of the
transformed graph.
Args:
+ info: Temporary information for this transform call.
sgv: the subgraph to be transformed.
Returns:
The transformed subgraph.
"""
- ops_ = [op_ for _, op_ in iteritems(self._info.transformed_ops)]
+ ops_ = [op_ for _, op_ in iteritems(info.transformed_ops)]
sgv_ = subgraph.SubGraphView(ops_)
sgv_inputs_ = sgv_.inputs
sgv_outputs_ = sgv_.outputs
@@ -500,9 +497,9 @@ class Transformer(object):
# re-order inputs
input_map_ = []
for input_t in sgv.inputs:
- if input_t not in self._info.transformed_ts:
+ if input_t not in info.transformed_ts:
continue
- input_t_ = self._info.transformed_ts[input_t]
+ input_t_ = info.transformed_ts[input_t]
if input_t_ not in sgv_inputs_:
continue
input_t_index_ = sgv_.input_index(input_t_)
@@ -511,9 +508,9 @@ class Transformer(object):
# re-order outputs
output_map_ = []
for output_t in sgv.outputs:
- if output_t not in self._info.transformed_ts:
+ if output_t not in info.transformed_ts:
continue
- output_t_ = self._info.transformed_ts[output_t]
+ output_t_ = info.transformed_ts[output_t]
if output_t_ not in sgv_outputs_:
continue
output_t_index_ = sgv_.output_index(output_t_)
@@ -521,94 +518,19 @@ class Transformer(object):
return sgv_.remap(input_map_, output_map_)
- def _transform_t(self, t):
- """Transform a tf.Tensor.
-
- Args:
- t: the tensor to be transformed.
- Returns:
- The transformed tensor.
- """
- logging.debug("Transforming tensor: %s", t.name)
- if t in self._info.transformed_ts:
- return self._info.transformed_ts[t]
-
- # Mark as None to detect cycle.
- self._info.transformed_ts[t] = None
-
- op, op_index = t.op, t.value_index
-
- # If op is not in the subgraph:
- if op not in self._info.ops:
- # t_ is an input of the subgraph
- if t in self._info.sgv_inputs_set:
- t_ = self.transform_external_input_handler(self._info, t)
- # t_ is a hidden input of the subgraph
+ def _transformed_t(self, info, t):
+ """Return tre transformed tensor of `t`."""
+ if t not in info.transformed_ts:
+ # If op is not in the subgraph.
+ if t in info.sgv_inputs_set:
+ # t is an input of the subgraph.
+ return self.transform_external_input_handler(info, t)
else:
- t_ = self.transform_external_hidden_input_handler(self._info, t)
- # If op is in the subgraph, just transform it:
+ # t is a hidden input of the subgraph.
+ return self.transform_external_hidden_input_handler(info, t)
else:
- op_ = self._transform_op(op)
- t_ = op_.outputs[op_index]
-
- # assign to collection
- if t is not t_:
- self.assign_collections_handler(self._info, t, t_)
-
- self._info.transformed_ts[t] = t_
- return t_
-
- def _transform_op(self, op):
- """Transform a tf.Operation.
-
- Args:
- op: the operation to be transformed.
- Returns:
- The transformed operation.
- """
- if op in self._info.transformed_ops:
- return self._info.transformed_ops[op]
-
- op_ = self.transform_op_handler(self._info, op)
-
- # Add to all the active control dependencies
- # pylint: disable=protected-access
- self._info.graph_._record_op_seen_by_control_dependencies(op_)
-
- # All to all the active devices
- for device_function in reversed(self._info.graph_._device_function_stack):
- if device_function is None:
- break
- op_._set_device(device_function(op_))
- # pylint: enable=protected-access
-
- # TODO(fkp): Establish clear policy about what context managers are allowed.
-
- # assign to collection
- if op is not op_:
- self.assign_collections_handler(self._info, op, op_)
-
- self._info.transformed_ops[op] = op_
- return op_
-
- def new_name(self, name):
- """Compute a destination name from a source name.
-
- Args:
- name: the name to be "transformed".
- Returns:
- The transformed name.
- Raises:
- ValueError: if the source scope is used (that is, not an empty string)
- and the source name does not belong to the source scope.
- """
- scope = self._info.scope
- if not name.startswith(scope):
- raise ValueError("{} does not belong to source scope: {}.".format(name,
- scope))
- rel_name = name[len(scope):]
- name_ = self._info.scope_ + rel_name
- return name_
+ # If op is in the subgraph, just return its transformed counterpart.
+ return info.transformed_ts[t]
def copy(sgv, dst_graph=None, dst_scope="", src_scope="",
@@ -627,7 +549,7 @@ def copy(sgv, dst_graph=None, dst_scope="", src_scope="",
Returns:
A tuple `(sgv, info)` where:
`sgv` is the transformed subgraph view;
- `info` is an instance of Transformer.ResultInfo containing
+ `info` is an instance of TransformerInfo containing
information about the transform, including mapping between
original and transformed tensors and operations.
Raises:
@@ -669,7 +591,7 @@ def copy_with_input_replacements(sgv, replacement_ts,
Returns:
A tuple `(sgv, info)` where:
`sgv` is the transformed subgraph view;
- `info` is an instance of Transformer.ResultInfo containing
+ `info` is an instance of TransformerInfo containing
information about the transform, including mapping between
original and transformed tensors and operations.
Raises: