diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-01-26 01:56:20 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-01-26 02:08:08 -0800 |
commit | 9a9002f03ec922bc6b7b452616f2d6df6b0d449c (patch) | |
tree | 3f2998e76d4c6e676c5673b5da87bb2ee426132c /tensorflow/contrib/graph_editor | |
parent | 870b7775b2538d78d9353eba53d7664021405013 (diff) |
Fix transform for cyclic graph (second try). Deprecate in-place transform.
Change: 145649225
Diffstat (limited to 'tensorflow/contrib/graph_editor')
-rw-r--r-- | tensorflow/contrib/graph_editor/__init__.py | 2 | ||||
-rw-r--r-- | tensorflow/contrib/graph_editor/tests/transform_test.py | 56 | ||||
-rw-r--r-- | tensorflow/contrib/graph_editor/transform.py | 588 |
3 files changed, 268 insertions, 378 deletions
diff --git a/tensorflow/contrib/graph_editor/__init__.py b/tensorflow/contrib/graph_editor/__init__.py index c59aa2520c..6ae477c8b5 100644 --- a/tensorflow/contrib/graph_editor/__init__.py +++ b/tensorflow/contrib/graph_editor/__init__.py @@ -180,8 +180,8 @@ which to operate must always be given explicitly. This is the reason why @@assign_renamed_collections_handler @@transform_op_if_inside_handler @@copy_op_handler -@@transform_op_in_place @@Transformer +@@TransformerInfo @@copy @@copy_with_input_replacements @@graph_replace diff --git a/tensorflow/contrib/graph_editor/tests/transform_test.py b/tensorflow/contrib/graph_editor/tests/transform_test.py index 2c80c04ce6..99764c4a7e 100644 --- a/tensorflow/contrib/graph_editor/tests/transform_test.py +++ b/tensorflow/contrib/graph_editor/tests/transform_test.py @@ -75,7 +75,7 @@ class TransformTest(test.TestCase): _ = math_ops.add(a, b) sgv = ge.make_view([assert_op, eq.op, a.op, b.op]) copier = ge.Transformer() - copied_sgv, info = copier(sgv, sgv.graph, "", "") + _, info = copier(sgv, sgv.graph, "", "") new_assert_op = info.transformed(assert_op) self.assertIsNotNone(new_assert_op) @@ -84,18 +84,17 @@ class TransformTest(test.TestCase): def my_transform_op_handler(info, op): add_noise = op.name.startswith("Add") - op_ = ge.transform.copy_op_handler(info, op) - if add_noise: - # add some noise to op - with info.graph_.as_default(): - t_ = math_ops.add(constant_op.constant( - 1.0, shape=[10], name="Noise"), - op_.outputs[0], - name="AddNoise") - # return the "noisy" op - return t_.op - else: - return op_ + op_, op_outputs_ = ge.transform.copy_op_handler(info, op) + if not add_noise: + return op_, op_outputs_ + # add some noise to op + with info.graph_.as_default(): + t_ = math_ops.add( + constant_op.constant(1.0, shape=[10], name="Noise"), + op_.outputs[0], + name="AddNoise") + # return the "noisy" op + return op_, [t_] transformer.transform_op_handler = my_transform_op_handler @@ -110,37 +109,6 @@ class TransformTest(test.TestCase): top = ge.select_ops("^AddNoise_2$", graph=graph)[0] self.assertTrue(matcher2(top)) - def test_transform_in_place(self): - transformer = ge.Transformer() - - def my_transform_op_handler_in_place(info, op): - add_noise = op.name.startswith("Add") - op = ge.transform.transform_op_in_place( - info, op, detach_outputs=add_noise) - if add_noise: - # add some noise to op - with info.graph_.as_default(): - t = math_ops.add(constant_op.constant( - 1.0, shape=[10], name="Noise"), - op.outputs[0], - name="AddNoise") - # return the "noisy" op - return t.op - else: - return op - - transformer.transform_op_handler = my_transform_op_handler_in_place - - transformer(self.graph, self.graph, "", "") - matcher0 = ge.matcher("AddNoise").input_ops( - "Noise", ge.matcher("Add").input_ops("Const", "Input")) - matcher1 = ge.matcher("AddNoise_1").input_ops( - "Noise_1", ge.matcher("Add_1").input_ops("Const_1", matcher0)) - matcher2 = ge.matcher("AddNoise_2").input_ops( - "Noise_2", ge.matcher("Add_2").input_ops("Const_2", matcher1)) - top = ge.select_ops("^AddNoise_2$", graph=self.graph)[0] - self.assertTrue(matcher2(top)) - def test_copy_with_input_replacements(self): with self.graph.as_default(): ten = constant_op.constant(10.0, shape=[10], name="Input") diff --git a/tensorflow/contrib/graph_editor/transform.py b/tensorflow/contrib/graph_editor/transform.py index 832698b8a0..4d272e108f 100644 --- a/tensorflow/contrib/graph_editor/transform.py +++ b/tensorflow/contrib/graph_editor/transform.py @@ -21,12 +21,10 @@ from __future__ import print_function from copy import deepcopy from functools import partial - from six import iteritems from six import iterkeys from six import string_types from six import StringIO -from tensorflow.contrib.graph_editor import edit from tensorflow.contrib.graph_editor import reroute from tensorflow.contrib.graph_editor import select from tensorflow.contrib.graph_editor import subgraph @@ -34,14 +32,15 @@ from tensorflow.contrib.graph_editor import util from tensorflow.python.framework import ops as tf_ops from tensorflow.python.platform import tf_logging as logging + __all__ = [ "replace_t_with_placeholder_handler", "keep_t_if_possible_handler", "assign_renamed_collections_handler", "transform_op_if_inside_handler", "copy_op_handler", - "transform_op_in_place", "Transformer", + "TransformerInfo", "copy", "copy_with_input_replacements", "graph_replace", @@ -55,7 +54,7 @@ def replace_t_with_placeholder_handler(info, t): placeholder. Args: - info: Transform._Info instance. + info: Transform._TmpInfo instance. t: tensor whose input must be transformed into a place holder. Returns: The tensor generated by the newly created place holder. @@ -73,7 +72,7 @@ def keep_t_if_possible_handler(info, t): This handler is typically used to transform a hidden input tensors. Args: - info: Transform._Info instance. + info: Transform._TmpInfo instance. t: tensor whose input must be transformed into a place holder. Returns: The tensor generated by the newly created place holder. @@ -91,7 +90,7 @@ def assign_renamed_collections_handler(info, elem, elem_): `tf.GraphKeys`. Args: - info: Transform._Info instance. + info: Transform._TmpInfo instance. elem: the original element (`tf.Tensor` or `tf.Operation`) elem_: the transformed element """ @@ -103,7 +102,7 @@ def assign_renamed_collections_handler(info, elem, elem_): if name in known_collection_names: transformed_name = name else: - transformed_name = info.transformer.new_name(name) + transformed_name = info.new_name(name) info.graph_.add_to_collection(transformed_name, elem_) @@ -114,18 +113,15 @@ def transform_op_if_inside_handler(info, op, keep_if_possible=True): if they are inside the subgraph, otherwise they are just ignored. Args: - info: Transform._Info instance. + info: Transform._TmpInfo instance. op: the optional op to transform (or ignore). keep_if_possible: re-attach to the original op if possible, that is, if the source graph and the destination graph are the same. Returns: The transformed op or None. """ - if op is None: - return None if op in info.sgv.ops: - return info.transformer._transform_op( # pylint: disable=protected-access - op) + return info.transformed_ops[op] else: if keep_if_possible and info.graph is info.graph_: return op @@ -137,36 +133,19 @@ def copy_op_handler(info, op, copy_shape=True): """Copy a `tf.Operation`. Args: - info: Transform._Info instance. + info: Transform._TmpInfo instance. op: the `tf.Operation` to be copied. copy_shape: also copy the shape of the tensor Returns: - A copy of op. + A `(op, op_outputs)` tuple containgin the transformed op and its outputs. """ # pylint: disable=protected-access - # Transform control inputs: - control_inputs_ = [info.transformer.transform_control_input_handler(info, ci) - for ci in op.control_inputs] - control_inputs_ = [ci for ci in control_inputs_ if ci is not None] - - # Transform it if any: - original_op_ = info.transformer.transform_original_op_handler(info, - op._original_op) - - # Transform inputs: - inputs_ = [info.transformer._transform_t(t) for t in op.inputs] - - # Leave inputs empty if a graph cycle was found. - if None in inputs_: - info.cyclic_ops.append(op) - inputs_ = [] - # Clone the node def: node_def_ = deepcopy(op._node_def) # Transform name: - name_ = info.transformer.new_name(op.name) + name_ = info.new_name(op.name) name_ = info.graph_.unique_name(name_) node_def_.name = name_ @@ -179,45 +158,196 @@ def copy_op_handler(info, op, copy_shape=True): op_def_ = deepcopy(op._op_def) # Initialize a new Operation instance - op_ = tf_ops.Operation(node_def_, info.graph_, inputs_, output_types_, - control_inputs_, input_types_, original_op_, op_def_) + op_ = tf_ops.Operation(node_def_, info.graph_, [], output_types_, + [], input_types_, None, op_def_) # copy the shape over if copy_shape: for t, t_ in zip(op.outputs, op_.outputs): t_.set_shape(t.get_shape()) + # Finalize original op. + if op._original_op: + original_op = info.transform_original_op_handler(info, op._original_op) + if original_op is None: + logging.info("Could not find original op of: %s", op_.name) + else: + op_._original_op = original_op + # Add op to the graph info.graph_._add_op(op_) - # pylint: enable=protected-access - return op_ + return op_, op_.outputs -def transform_op_in_place(info, op, detach_outputs=False): - """Transform a op in-place - experimental! +class TransformerInfo(object): + """"Contains information about the result of a transform operation.""" - Transform an operation in place. It reconnects the inputs if they have been - modified. if detach_outputs is True, the outputs of op are also detached. + def __init__(self, info): + """Constructor. - Args: - info: Transform._Info instance. - op: the op to transform in place. - detach_outputs: if True, the outputs of op are detached, ready for the user - to add more operation. - Returns: - The transformed op. + Args: + info: an instance of Transformer._TmpInfo containing various internal + information about the transform operation. + """ + self._graph = info.graph + self._scope = info.scope + self._graph_ = info.graph_ + self._scope_ = info.scope_ + self._transformed_ops = info.transformed_ops + self._transformed_ts = info.transformed_ts + + def _get_transformed_map(self, top): + """Return the correct container depending on the type of `top`.""" + if isinstance(top, tf_ops.Operation): + return self._transformed_ops + elif isinstance(top, tf_ops.Tensor): + return self._transformed_ts + else: + raise TypeError( + "Expected a tf.Tensor or a tf.Operation, got a {}".format( + type(top))) + + def _transformed_elem(self, original_top, missing_fn=None): + """Return the transformed op/tensor corresponding to the original one. + + Args: + original_top: the original tensor/operation. + missing_fn: function handling the case where the counterpart + cannot be found. By default, None is returned. + Returns: + the transformed tensor/operation (or None if no match is found). + """ + transformed_map = self._get_transformed_map(original_top) + if isinstance(original_top, string_types): + for original, transformed in iteritems(transformed_map): + if original.name == original_top: + return transformed + return None if missing_fn is None else missing_fn(original_top) + else: + if original_top not in transformed_map: + return None if missing_fn is None else missing_fn(original_top) + return transformed_map[original_top] + + def _original_elem(self, transformed_top, missing_fn=None): + """Return the original op/tensor corresponding to the transformed one. + + Args: + transformed_top: the transformed tensor/operation. + missing_fn: function handling the case where the counterpart + cannot be found. By default, None is returned. + Returns: + the original tensor/operation (or None if no match is found). + """ + transformed_map = self._get_transformed_map(transformed_top) + if isinstance(transformed_top, string_types): + finder = lambda transformed: transformed.name == transformed_top + else: + finder = lambda transformed: transformed == transformed_top + for original, transformed in iteritems(transformed_map): + if finder(transformed): + return original + return None if missing_fn is None else missing_fn(transformed_top) + + def transformed(self, original, missing_fn=None): + """Return the transformed op/tensor corresponding to the original one. + + Note that the output of this function mimics the hierarchy + of its input argument `original`. + Given an iterable, it returns a list. Given an operation or a tensor, + it will return an operation or a tensor. + + Args: + original: the original tensor/operation. + missing_fn: function handling the case where the counterpart + cannot be found. By default, None is returned. + Returns: + the transformed tensor/operation (or None if no match is found). + """ + transformed_elem = partial(self._transformed_elem, missing_fn=missing_fn) + return util.transform_tree(original, transformed_elem) + + def original(self, transformed, missing_fn=None): + """Return the original op/tensor corresponding to the transformed one. + + Note that the output of this function mimics the hierarchy + of its input argument `transformed`. + Given an iterable, it returns a list. Given an operation or a tensor, + it will return an operation or a tensor. + + Args: + transformed: the transformed tensor/operation. + missing_fn: function handling the case where the counterpart + cannot be found. By default, None is returned. + Returns: + the original tensor/operation (or None if no match is found). + """ + original_elem = partial(self._original_elem, missing_fn=missing_fn) + return util.transform_tree(transformed, original_elem) + + def __str__(self): + res = StringIO() + print("Transform result info:", file=res) + if self._graph == self._graph_: + in_place_str = "" if self._scope_ else " IN-PLACE" + print(" Within graph[{}]{}".format( + id(self._graph), in_place_str), file=res) + else: + print(" graph[{}] => graph[{}]".format( + id(self._graph), id(self._graph_)), file=res) + if self._scope: + print(" Relative to source scope: {}".format(self._scope), file=res) + if self._scope_: + print(" Scope destination: {}".format(self._scope_), file=res) + print("Operations mapping:", file=res) + for op, op_ in iteritems(self._transformed_ops): + print(" {} => {}".format(op.name, op_.name), file=res) + return res.getvalue() + + +class _TmpInfo(object): + """Transformer temporary data. + + An instance of this class holds all the information relevant to a call + to a transformer instance (that is, a call to __call__). An instance + is created for the life-time of the __call__ function and is passed as + argument to the handlers. """ - # recursive call to the inputs: - inputs = [info.transformer._transform_t(t) # pylint: disable=protected-access - for t in op.inputs] - # re-connect to the inputs if they have changed: - if inputs != list(op.inputs): - reroute.reroute_a2b_ts(inputs, op.inputs) - # detach op from its consumer first ? - if detach_outputs: - edit.detach_outputs(op) - return op + + def __init__(self, sgv, dst_graph, dst_scope, src_scope): + self.sgv = sgv + self.sgv_inputs_set = frozenset(sgv.inputs) + self.ops = frozenset(sgv.ops) + self.control_outputs = util.ControlOutputs(sgv.graph) + self.graph = sgv.graph + self.scope = src_scope + self.graph_ = dst_graph + self.scope_ = dst_scope + self.transformed_ops = {} + self.transformed_ts = {} + self.collections = dict((key, self.graph.get_collection(key)) + for key in self.graph.get_all_collection_keys()) + self.cyclic_ops = [] + self.transform_original_op_handler = transform_op_if_inside_handler + + def new_name(self, name): + """Compute a destination name from a source name. + + Args: + name: the name to be "transformed". + Returns: + The transformed name. + Raises: + ValueError: if the source scope is used (that is, not an empty string) + and the source name does not belong to the source scope. + """ + scope = self.scope + if not name.startswith(scope): + raise ValueError("{} does not belong to source scope: {}.".format( + name, scope)) + rel_name = name[len(scope):] + name_ = self.scope_ + rel_name + return name_ class Transformer(object): @@ -228,155 +358,6 @@ class Transformer(object): the handlers. """ - class _Info(object): - """Transformer temporary data. - - An instance of this class holds all the information relevant to a call - to a transformer instance (that is, a call to __call__). An instance - is created for the life-time of the __call__ function and is passed as - argument to the handlers. - """ - - def __init__(self, transformer, sgv, dst_graph, dst_scope, src_scope): - self.transformer = transformer - self.sgv = sgv - self.sgv_inputs_set = frozenset(sgv.inputs) - self.ops = frozenset(sgv.ops) - self.control_outputs = util.ControlOutputs(sgv.graph) - self.graph = sgv.graph - self.scope = src_scope - self.graph_ = dst_graph - self.scope_ = dst_scope - self.transformed_ops = {} - self.transformed_ts = {} - self.collections = dict((key, self.graph.get_collection(key)) - for key in self.graph.get_all_collection_keys()) - self.cyclic_ops = [] - - class ResultInfo(object): - """"Contains information about the result of a transform operation.""" - - def __init__(self, info): - """Constructor. - - Args: - info: an instance of Transformer._Info containing various internal - information about the transform operation. - """ - self._graph = info.graph - self._scope = info.scope - self._graph_ = info.graph_ - self._scope_ = info.scope_ - self._transformed_ops = info.transformed_ops - self._transformed_ts = info.transformed_ts - - def _get_transformed_map(self, top): - """Return the correct container depending on the type of `top`.""" - if isinstance(top, tf_ops.Operation): - return self._transformed_ops - elif isinstance(top, tf_ops.Tensor): - return self._transformed_ts - else: - raise TypeError( - "Expected a tf.Tensor or a tf.Operation, got a {}".format( - type(top))) - - def _transformed_elem(self, original_top, missing_fn=None): - """Return the transformed op/tensor corresponding to the original one. - - Args: - original_top: the original tensor/operation. - missing_fn: function handling the case where the counterpart - cannot be found. By default, None is returned. - Returns: - the transformed tensor/operation (or None if no match is found). - """ - transformed_map = self._get_transformed_map(original_top) - if isinstance(original_top, string_types): - for original, transformed in iteritems(transformed_map): - if original.name == original_top: - return transformed - return None if missing_fn is None else missing_fn(original_top) - else: - if original_top not in transformed_map: - return None if missing_fn is None else missing_fn(original_top) - return transformed_map[original_top] - - def _original_elem(self, transformed_top, missing_fn=None): - """Return the original op/tensor corresponding to the transformed one. - - Args: - transformed_top: the transformed tensor/operation. - missing_fn: function handling the case where the counterpart - cannot be found. By default, None is returned. - Returns: - the original tensor/operation (or None if no match is found). - """ - transformed_map = self._get_transformed_map(transformed_top) - if isinstance(transformed_top, string_types): - finder = lambda transformed: transformed.name == transformed_top - else: - finder = lambda transformed: transformed == transformed_top - for original, transformed in iteritems(transformed_map): - if finder(transformed): - return original - return None if missing_fn is None else missing_fn(transformed_top) - - def transformed(self, original, missing_fn=None): - """Return the transformed op/tensor corresponding to the original one. - - Note that the output of this function mimics the hierarchy - of its input argument `original`. - Given an iterable, it returns a list. Given an operation or a tensor, - it will return an operation or a tensor. - - Args: - original: the original tensor/operation. - missing_fn: function handling the case where the counterpart - cannot be found. By default, None is returned. - Returns: - the transformed tensor/operation (or None if no match is found). - """ - transformed_elem = partial(self._transformed_elem, missing_fn=missing_fn) - return util.transform_tree(original, transformed_elem) - - def original(self, transformed, missing_fn=None): - """Return the original op/tensor corresponding to the transformed one. - - Note that the output of this function mimics the hierarchy - of its input argument `transformed`. - Given an iterable, it returns a list. Given an operation or a tensor, - it will return an operation or a tensor. - - Args: - transformed: the transformed tensor/operation. - missing_fn: function handling the case where the counterpart - cannot be found. By default, None is returned. - Returns: - the original tensor/operation (or None if no match is found). - """ - original_elem = partial(self._original_elem, missing_fn=missing_fn) - return util.transform_tree(transformed, original_elem) - - def __str__(self): - res = StringIO() - print("Transform result info:", file=res) - if self._graph == self._graph_: - in_place_str = "" if self._scope_ else " IN-PLACE" - print(" Within graph[{}]{}".format( - id(self._graph), in_place_str), file=res) - else: - print(" graph[{}] => graph[{}]".format( - id(self._graph), id(self._graph_)), file=res) - if self._scope: - print(" Relative to source scope: {}".format(self._scope), file=res) - if self._scope_: - print(" Scope destination: {}".format(self._scope_), file=res) - print("Operations mapping:", file=res) - for op, op_ in iteritems(self._transformed_ops): - print(" {} => {}".format(op.name, op_.name), file=res) - return res.getvalue() - def __init__(self): """Transformer constructor. @@ -407,9 +388,6 @@ class Transformer(object): self.transform_external_hidden_input_handler = keep_t_if_possible_handler self.transform_original_op_handler = transform_op_if_inside_handler - # temporary per-call variable - self._info = None - def __call__(self, sgv, dst_graph, @@ -432,7 +410,7 @@ class Transformer(object): Returns: A tuple `(sgv, info)` where: `sgv` is the transformed subgraph view; - `info` is an instance of Transformer.ResultInfo containing + `info` is an instance of TransformerInfo containing information about the transform, including mapping between original and transformed tensors and operations. Raises: @@ -450,49 +428,68 @@ class Transformer(object): dst_scope = util.scope_finalize(dst_graph.unique_name(dst_scope[:-1])) # Create temporary info used during this transform call - self._info = Transformer._Info(self, sgv, dst_graph, dst_scope, src_scope) - - # Transform the graph starting from the output tensors. - for output_t in self._info.sgv.outputs: - self._transform_t(output_t) - - # Some ops might have been missed by the previous walk, namely, the roots - # without any outputs. So the walk is now finalized from those roots. - remaining_ops = [op for op in self._info.sgv.ops - if op not in self._info.transformed_ops] - remaining_roots = [op for op in remaining_ops if not op.outputs] - for op in remaining_roots: - self._transform_op(op) - - # Finalize cyclic ops: - for op in self._info.cyclic_ops: - logging.debug("Finalizing cyclic op: %s", op.name) - op_ = self._info.transformed_ops[op] - inputs_ = [self._info.transformed_ts[t] for t in op.inputs] - if None in inputs_: - raise ValueError("Could not find all the inputs of cyclic op: {}" - .format(op_.name)) - for input_id, t_ in enumerate(inputs_): - op_._update_input(input_id, t_) # pylint: disable=protected-access - - sgv_ = self._transform_sgv(sgv) - - res_info = Transformer.ResultInfo(self._info) - self._info = None + info = _TmpInfo(sgv, dst_graph, dst_scope, src_scope) + info.transform_original_op_handler = self.transform_original_op_handler + + self._copy_ops(info) + self._connect_ops(info) + + # Compute information about the transformation + res_info = TransformerInfo(info) + sgv_ = self._transform_sgv(info, sgv) return sgv_, res_info - def _transform_sgv(self, sgv): + def _copy_ops(self, info): + """Copy ops without connecting them.""" + for op in info.sgv.ops: + logging.info("Copying op: %s", op.name) + # TODO(fkp): return a subgraph? + op_, op_outputs_ = self.transform_op_handler(info, op) + if op is op_: + raise ValueError("In-place tranformation not allowed.") + + # Process op. + info.transformed_ops[op] = op_ + self.assign_collections_handler(info, op, op_) + + # Process output tensors. + for op_output, op_output_ in zip(op.outputs, op_outputs_): + info.transformed_ts[op_output] = op_output_ + self.assign_collections_handler(info, op_output, op_output_) + + def _connect_ops(self, info): + """Connect the previously copied ops.""" + for op in info.sgv.ops: + logging.info("Finalizing op: %s", op.name) + op_ = info.transformed_ops[op] + + # pylint: disable=protected-access + if op_.inputs: + raise ValueError("The newly transformed op should not have " + "any inputs yet: {}".format(op_.name)) + inputs_ = [self._transformed_t(info, t) for t in op.inputs] + for t in inputs_: + op_._add_input(t) + + # Finalize control inputs: + control_inputs_ = [self.transform_control_input_handler(info, ci) + for ci in op.control_inputs] + control_inputs_ = [ci for ci in control_inputs_ if ci is not None] + reroute.add_control_inputs(op_, control_inputs_) + + def _transform_sgv(self, info, sgv): """Transform a subgraph view. For convenience, a transform operation returns a subgraph view of the transformed graph. Args: + info: Temporary information for this transorfm call. sgv: the subgraph to be transformed. Returns: The transformed subgraph. """ - ops_ = [op_ for _, op_ in iteritems(self._info.transformed_ops)] + ops_ = [op_ for _, op_ in iteritems(info.transformed_ops)] sgv_ = subgraph.SubGraphView(ops_) sgv_inputs_ = sgv_.inputs sgv_outputs_ = sgv_.outputs @@ -500,9 +497,9 @@ class Transformer(object): # re-order inputs input_map_ = [] for input_t in sgv.inputs: - if input_t not in self._info.transformed_ts: + if input_t not in info.transformed_ts: continue - input_t_ = self._info.transformed_ts[input_t] + input_t_ = info.transformed_ts[input_t] if input_t_ not in sgv_inputs_: continue input_t_index_ = sgv_.input_index(input_t_) @@ -511,9 +508,9 @@ class Transformer(object): # re-order outputs output_map_ = [] for output_t in sgv.outputs: - if output_t not in self._info.transformed_ts: + if output_t not in info.transformed_ts: continue - output_t_ = self._info.transformed_ts[output_t] + output_t_ = info.transformed_ts[output_t] if output_t_ not in sgv_outputs_: continue output_t_index_ = sgv_.output_index(output_t_) @@ -521,94 +518,19 @@ class Transformer(object): return sgv_.remap(input_map_, output_map_) - def _transform_t(self, t): - """Transform a tf.Tensor. - - Args: - t: the tensor to be transformed. - Returns: - The transformed tensor. - """ - logging.debug("Transforming tensor: %s", t.name) - if t in self._info.transformed_ts: - return self._info.transformed_ts[t] - - # Mark as None to detect cycle. - self._info.transformed_ts[t] = None - - op, op_index = t.op, t.value_index - - # If op is not in the subgraph: - if op not in self._info.ops: - # t_ is an input of the subgraph - if t in self._info.sgv_inputs_set: - t_ = self.transform_external_input_handler(self._info, t) - # t_ is a hidden input of the subgraph + def _transformed_t(self, info, t): + """Return tre transformed tensor of `t`.""" + if t not in info.transformed_ts: + # If op is not in the subgraph. + if t in info.sgv_inputs_set: + # t is an input of the subgraph. + return self.transform_external_input_handler(info, t) else: - t_ = self.transform_external_hidden_input_handler(self._info, t) - # If op is in the subgraph, just transform it: + # t is a hidden input of the subgraph. + return self.transform_external_hidden_input_handler(info, t) else: - op_ = self._transform_op(op) - t_ = op_.outputs[op_index] - - # assign to collection - if t is not t_: - self.assign_collections_handler(self._info, t, t_) - - self._info.transformed_ts[t] = t_ - return t_ - - def _transform_op(self, op): - """Transform a tf.Operation. - - Args: - op: the operation to be transformed. - Returns: - The transformed operation. - """ - if op in self._info.transformed_ops: - return self._info.transformed_ops[op] - - op_ = self.transform_op_handler(self._info, op) - - # Add to all the active control dependencies - # pylint: disable=protected-access - self._info.graph_._record_op_seen_by_control_dependencies(op_) - - # All to all the active devices - for device_function in reversed(self._info.graph_._device_function_stack): - if device_function is None: - break - op_._set_device(device_function(op_)) - # pylint: enable=protected-access - - # TODO(fkp): Establish clear policy about what context managers are allowed. - - # assign to collection - if op is not op_: - self.assign_collections_handler(self._info, op, op_) - - self._info.transformed_ops[op] = op_ - return op_ - - def new_name(self, name): - """Compute a destination name from a source name. - - Args: - name: the name to be "transformed". - Returns: - The transformed name. - Raises: - ValueError: if the source scope is used (that is, not an empty string) - and the source name does not belong to the source scope. - """ - scope = self._info.scope - if not name.startswith(scope): - raise ValueError("{} does not belong to source scope: {}.".format(name, - scope)) - rel_name = name[len(scope):] - name_ = self._info.scope_ + rel_name - return name_ + # If op is in the subgraph, just return its transformed. + return info.transformed_ts[t] def copy(sgv, dst_graph=None, dst_scope="", src_scope="", @@ -627,7 +549,7 @@ def copy(sgv, dst_graph=None, dst_scope="", src_scope="", Returns: A tuple `(sgv, info)` where: `sgv` is the transformed subgraph view; - `info` is an instance of Transformer.ResultInfo containing + `info` is an instance of TransformerInfo containing information about the transform, including mapping between original and transformed tensors and operations. Raises: @@ -669,7 +591,7 @@ def copy_with_input_replacements(sgv, replacement_ts, Returns: A tuple `(sgv, info)` where: `sgv` is the transformed subgraph view; - `info` is an instance of Transformer.ResultInfo containing + `info` is an instance of TransformerInfo containing information about the transform, including mapping between original and transformed tensors and operations. Raises: |