aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/graph_editor
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-01-18 05:53:24 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-01-18 06:06:32 -0800
commite28da73a92273748c768abbb72502f56e93b9ce3 (patch)
tree995ff30543d46c76196f14512af775a585203cc7 /tensorflow/contrib/graph_editor
parent1c7ef3db0840d6530e7dde693b9351b8205f7e84 (diff)
Fix transform for cyclic graph.
Improve collection name handling. Added helper to retrieve corresponding tensor/op. Change: 144824657
Diffstat (limited to 'tensorflow/contrib/graph_editor')
-rw-r--r--tensorflow/contrib/graph_editor/transform.py37
-rw-r--r--tensorflow/contrib/graph_editor/util.py73
2 files changed, 105 insertions, 5 deletions
diff --git a/tensorflow/contrib/graph_editor/transform.py b/tensorflow/contrib/graph_editor/transform.py
index 6fb347c834..832698b8a0 100644
--- a/tensorflow/contrib/graph_editor/transform.py
+++ b/tensorflow/contrib/graph_editor/transform.py
@@ -26,13 +26,13 @@ from six import iteritems
from six import iterkeys
from six import string_types
from six import StringIO
-
from tensorflow.contrib.graph_editor import edit
from tensorflow.contrib.graph_editor import reroute
from tensorflow.contrib.graph_editor import select
from tensorflow.contrib.graph_editor import subgraph
from tensorflow.contrib.graph_editor import util
from tensorflow.python.framework import ops as tf_ops
+from tensorflow.python.platform import tf_logging as logging
__all__ = [
"replace_t_with_placeholder_handler",
@@ -87,17 +87,24 @@ def keep_t_if_possible_handler(info, t):
def assign_renamed_collections_handler(info, elem, elem_):
"""Add the transformed elem to the (renamed) collections of elem.
+ A collection is renamed only if it is not a known key, as described in
+ `tf.GraphKeys`.
+
Args:
info: Transform._Info instance.
elem: the original element (`tf.Tensor` or `tf.Operation`)
elem_: the transformed element
"""
- # TODO(fkp): handle known special cases
+ known_collection_names = util.get_predefined_collection_names()
for name, collection in iteritems(info.collections):
if elem not in collection:
continue
- collection_name_ = info.transformer.new_name(name)
- info.graph_.add_to_collection(collection_name_, elem_)
+
+ if name in known_collection_names:
+ transformed_name = name
+ else:
+ transformed_name = info.transformer.new_name(name)
+ info.graph_.add_to_collection(transformed_name, elem_)
def transform_op_if_inside_handler(info, op, keep_if_possible=True):
@@ -150,6 +157,11 @@ def copy_op_handler(info, op, copy_shape=True):
# Transform inputs:
inputs_ = [info.transformer._transform_t(t) for t in op.inputs]
+ # Leave inputs empty if a graph cycle was found.
+ if None in inputs_:
+ info.cyclic_ops.append(op)
+ inputs_ = []
+
# Clone the node def:
node_def_ = deepcopy(op._node_def)
@@ -239,7 +251,7 @@ class Transformer(object):
self.transformed_ts = {}
self.collections = dict((key, self.graph.get_collection(key))
for key in self.graph.get_all_collection_keys())
-
+ self.cyclic_ops = []
class ResultInfo(object):
""""Contains information about the result of a transform operation."""
@@ -452,6 +464,17 @@ class Transformer(object):
for op in remaining_roots:
self._transform_op(op)
+ # Finalize cyclic ops:
+ for op in self._info.cyclic_ops:
+ logging.debug("Finalizing cyclic op: %s", op.name)
+ op_ = self._info.transformed_ops[op]
+ inputs_ = [self._info.transformed_ts[t] for t in op.inputs]
+ if None in inputs_:
+ raise ValueError("Could not find all the inputs of cyclic op: {}"
+ .format(op_.name))
+ for input_id, t_ in enumerate(inputs_):
+ op_._update_input(input_id, t_) # pylint: disable=protected-access
+
sgv_ = self._transform_sgv(sgv)
res_info = Transformer.ResultInfo(self._info)
@@ -506,9 +529,13 @@ class Transformer(object):
Returns:
The transformed tensor.
"""
+ logging.debug("Transforming tensor: %s", t.name)
if t in self._info.transformed_ts:
return self._info.transformed_ts[t]
+ # Mark as None to detect cycle.
+ self._info.transformed_ts[t] = None
+
op, op_index = t.op, t.value_index
# If op is not in the subgraph:
diff --git a/tensorflow/contrib/graph_editor/util.py b/tensorflow/contrib/graph_editor/util.py
index 11ee2435c9..d8824f6792 100644
--- a/tensorflow/contrib/graph_editor/util.py
+++ b/tensorflow/contrib/graph_editor/util.py
@@ -20,6 +20,7 @@ from __future__ import division
from __future__ import print_function
import collections
+import re
from six import iteritems
from tensorflow.python.framework import ops as tf_ops
from tensorflow.python.ops import array_ops as tf_array_ops
@@ -465,3 +466,75 @@ def make_placeholder_from_dtype_and_shape(dtype, shape=None, scope=None):
"""
return tf_array_ops.placeholder(
dtype=dtype, shape=shape, name=placeholder_name(scope=scope))
+
+
+_INTERNAL_VARIABLE_RE = re.compile(r"^__\w+__$")
+
+
+def get_predefined_collection_names():
+ """Return all the predefined collection names."""
+ return [getattr(tf_ops.GraphKeys, key) for key in dir(tf_ops.GraphKeys)
+ if not _INTERNAL_VARIABLE_RE.match(key)]
+
+
+def find_corresponding_elem(target, dst_graph, dst_scope="", src_scope=""):
+ """Find corresponding op/tensor in a different graph.
+
+ Args:
+ target: A `tf.Tensor` or a `tf.Operation` belonging to the original graph.
+ dst_graph: The graph in which the corresponding graph element must be found.
+ dst_scope: A scope which is prepended to the name to look for.
+ src_scope: A scope which is removed from the original name of `target`.
+
+ Returns:
+ The corresponding `tf.Tensor` or `tf.Operation`.
+
+ Raises:
+ ValueError: if `src_name` does not start with `src_scope`.
+ TypeError: if `target` is not a `tf.Tensor` or a `tf.Operation`
+ KeyError: If the corresponding graph element cannot be found.
+ """
+ src_name = target.name
+ if src_scope:
+ src_scope = scope_finalize(src_scope)
+ if not src_name.startswith(src_scope):
+ raise ValueError("{} does not start with {}".format(src_name, src_scope))
+ src_name = src_name[len(src_scope):]
+
+ dst_name = src_name
+ if dst_scope:
+ dst_scope = scope_finalize(dst_scope)
+ dst_name = dst_scope + dst_name
+
+ if isinstance(target, tf_ops.Tensor):
+ return dst_graph.get_tensor_by_name(dst_name)
+ if isinstance(target, tf_ops.Operation):
+ return dst_graph.get_operation_by_name(dst_name)
+ raise TypeError("Expected tf.Tensor or tf.Operation, got: {}".format(
+ type(target)))
+
+
+def find_corresponding(targets, dst_graph, dst_scope="", src_scope=""):
+ """Find corresponding ops/tensors in a different graph.
+
+ `targets` is a Python tree, that is, a nested structure of iterable
+ (list, tuple, dictionary) whose leaves are instances of
+ `tf.Tensor` or `tf.Operation`
+
+ Args:
+ targets: A Python tree containing `tf.Tensor` or `tf.Operation`
+ belonging to the original graph.
+ dst_graph: The graph in which the corresponding graph element must be found.
+ dst_scope: A scope which is prepended to the name to look for.
+ src_scope: A scope which is removed from the original name of `top`.
+
+ Returns:
+ A Python tree containing the corresponding `tf.Tensor` or `tf.Operation`.
+
+ Raises:
+ ValueError: if `src_name` does not start with `src_scope`.
+ TypeError: if `top` is not a `tf.Tensor` or a `tf.Operation`
+ KeyError: If the corresponding graph element cannot be found.
+ """
+ def func(top):
+ return find_corresponding_elem(top, dst_graph, dst_scope, src_scope)
+ return transform_tree(targets, func)