diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-01-18 05:53:24 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-01-18 06:06:32 -0800 |
commit | e28da73a92273748c768abbb72502f56e93b9ce3 (patch) | |
tree | 995ff30543d46c76196f14512af775a585203cc7 /tensorflow/contrib/graph_editor | |
parent | 1c7ef3db0840d6530e7dde693b9351b8205f7e84 (diff) |
Fix transform for cyclic graph.
Improve collection name handling.
Added helper to retrieve corresponding tensor/op.
Change: 144824657
Diffstat (limited to 'tensorflow/contrib/graph_editor')
-rw-r--r-- | tensorflow/contrib/graph_editor/transform.py | 37 | ||||
-rw-r--r-- | tensorflow/contrib/graph_editor/util.py | 73 |
2 files changed, 105 insertions, 5 deletions
diff --git a/tensorflow/contrib/graph_editor/transform.py b/tensorflow/contrib/graph_editor/transform.py index 6fb347c834..832698b8a0 100644 --- a/tensorflow/contrib/graph_editor/transform.py +++ b/tensorflow/contrib/graph_editor/transform.py @@ -26,13 +26,13 @@ from six import iteritems from six import iterkeys from six import string_types from six import StringIO - from tensorflow.contrib.graph_editor import edit from tensorflow.contrib.graph_editor import reroute from tensorflow.contrib.graph_editor import select from tensorflow.contrib.graph_editor import subgraph from tensorflow.contrib.graph_editor import util from tensorflow.python.framework import ops as tf_ops +from tensorflow.python.platform import tf_logging as logging __all__ = [ "replace_t_with_placeholder_handler", @@ -87,17 +87,24 @@ def keep_t_if_possible_handler(info, t): def assign_renamed_collections_handler(info, elem, elem_): """Add the transformed elem to the (renamed) collections of elem. + A collection is renamed only if is not a known key, as described in + `tf.GraphKeys`. + Args: info: Transform._Info instance. elem: the original element (`tf.Tensor` or `tf.Operation`) elem_: the transformed element """ - # TODO(fkp): handle known special cases + known_collection_names = util.get_predefined_collection_names() for name, collection in iteritems(info.collections): if elem not in collection: continue - collection_name_ = info.transformer.new_name(name) - info.graph_.add_to_collection(collection_name_, elem_) + + if name in known_collection_names: + transformed_name = name + else: + transformed_name = info.transformer.new_name(name) + info.graph_.add_to_collection(transformed_name, elem_) def transform_op_if_inside_handler(info, op, keep_if_possible=True): @@ -150,6 +157,11 @@ def copy_op_handler(info, op, copy_shape=True): # Transform inputs: inputs_ = [info.transformer._transform_t(t) for t in op.inputs] + # Leave inputs empty if a graph cycle was found. + if None in inputs_: + info.cyclic_ops.append(op) + inputs_ = [] + # Clone the node def: node_def_ = deepcopy(op._node_def) @@ -239,7 +251,7 @@ class Transformer(object): self.transformed_ts = {} self.collections = dict((key, self.graph.get_collection(key)) for key in self.graph.get_all_collection_keys()) - + self.cyclic_ops = [] class ResultInfo(object): """"Contains information about the result of a transform operation.""" @@ -452,6 +464,17 @@ class Transformer(object): for op in remaining_roots: self._transform_op(op) + # Finalize cyclic ops: + for op in self._info.cyclic_ops: + logging.debug("Finalizing cyclic op: %s", op.name) + op_ = self._info.transformed_ops[op] + inputs_ = [self._info.transformed_ts[t] for t in op.inputs] + if None in inputs_: + raise ValueError("Could not find all the inputs of cyclic op: {}" + .format(op_.name)) + for input_id, t_ in enumerate(inputs_): + op_._update_input(input_id, t_) # pylint: disable=protected-access + sgv_ = self._transform_sgv(sgv) res_info = Transformer.ResultInfo(self._info) @@ -506,9 +529,13 @@ class Transformer(object): Returns: The transformed tensor. """ + logging.debug("Transforming tensor: %s", t.name) if t in self._info.transformed_ts: return self._info.transformed_ts[t] + # Mark as None to detect cycle. + self._info.transformed_ts[t] = None + op, op_index = t.op, t.value_index # If op is not in the subgraph: diff --git a/tensorflow/contrib/graph_editor/util.py b/tensorflow/contrib/graph_editor/util.py index 11ee2435c9..d8824f6792 100644 --- a/tensorflow/contrib/graph_editor/util.py +++ b/tensorflow/contrib/graph_editor/util.py @@ -20,6 +20,7 @@ from __future__ import division from __future__ import print_function import collections +import re from six import iteritems from tensorflow.python.framework import ops as tf_ops from tensorflow.python.ops import array_ops as tf_array_ops @@ -465,3 +466,75 @@ def make_placeholder_from_dtype_and_shape(dtype, shape=None, scope=None): """ return tf_array_ops.placeholder( dtype=dtype, shape=shape, name=placeholder_name(scope=scope)) + + +_INTERNAL_VARIABLE_RE = re.compile(r"^__\w+__$") + + +def get_predefined_collection_names(): + """Return all the predefined collection names.""" + return [getattr(tf_ops.GraphKeys, key) for key in dir(tf_ops.GraphKeys) + if not _INTERNAL_VARIABLE_RE.match(key)] + + +def find_corresponding_elem(target, dst_graph, dst_scope="", src_scope=""): + """Find corresponding op/tensor in a different graph. + + Args: + target: A `tf.Tensor` or a `tf.Operation` belonging to the original graph. + dst_graph: The graph in which the corresponding graph element must be found. + dst_scope: A scope which is prepended to the name to look for. + src_scope: A scope which is removed from the original of `target` name. + + Returns: + The corresponding tf.Tensor` or a `tf.Operation`. + + Raises: + ValueError: if `src_name` does not start with `src_scope`. + TypeError: if `target` is not a `tf.Tensor` or a `tf.Operation` + KeyError: If the corresponding graph element cannot be found. + """ + src_name = target.name + if src_scope: + src_scope = scope_finalize(src_scope) + if not src_name.startswidth(src_scope): + raise ValueError("{} does not start with {}".format(src_name, src_scope)) + src_name = src_name[len(src_scope):] + + dst_name = src_name + if dst_scope: + dst_scope = scope_finalize(dst_scope) + dst_name = dst_scope + dst_name + + if isinstance(target, tf_ops.Tensor): + return dst_graph.get_tensor_by_name(dst_name) + if isinstance(target, tf_ops.Operation): + return dst_graph.get_operation_by_name(dst_name) + raise TypeError("Expected tf.Tensor or tf.Operation, got: {}", type(target)) + + +def find_corresponding(targets, dst_graph, dst_scope="", src_scope=""): + """Find corresponding ops/tensors in a different graph. + + `targets` is a Python tree, that is, a nested structure of iterable + (list, tupple, dictionary) whose leaves are instances of + `tf.Tensor` or `tf.Operation` + + Args: + targets: A Python tree containing `tf.Tensor` or `tf.Operation` + belonging to the original graph. + dst_graph: The graph in which the corresponding graph element must be found. + dst_scope: A scope which is prepended to the name to look for. + src_scope: A scope which is removed from the original of `top` name. + + Returns: + A Python tree containin the corresponding tf.Tensor` or a `tf.Operation`. + + Raises: + ValueError: if `src_name` does not start with `src_scope`. + TypeError: if `top` is not a `tf.Tensor` or a `tf.Operation` + KeyError: If the corresponding graph element cannot be found. + """ + def func(top): + return find_corresponding_elem(top, dst_graph, dst_scope, src_scope) + return transform_tree(targets, func) |