aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/graph_editor/util.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/graph_editor/util.py')
-rw-r--r--tensorflow/contrib/graph_editor/util.py73
1 files changed, 73 insertions, 0 deletions
diff --git a/tensorflow/contrib/graph_editor/util.py b/tensorflow/contrib/graph_editor/util.py
index 11ee2435c9..d8824f6792 100644
--- a/tensorflow/contrib/graph_editor/util.py
+++ b/tensorflow/contrib/graph_editor/util.py
@@ -20,6 +20,7 @@ from __future__ import division
from __future__ import print_function
import collections
+import re
from six import iteritems
from tensorflow.python.framework import ops as tf_ops
from tensorflow.python.ops import array_ops as tf_array_ops
@@ -465,3 +466,75 @@ def make_placeholder_from_dtype_and_shape(dtype, shape=None, scope=None):
"""
return tf_array_ops.placeholder(
dtype=dtype, shape=shape, name=placeholder_name(scope=scope))
+
+
+_INTERNAL_VARIABLE_RE = re.compile(r"^__\w+__$")
+
+
+def get_predefined_collection_names():
+ """Return all the predefined collection names."""
+ return [getattr(tf_ops.GraphKeys, key) for key in dir(tf_ops.GraphKeys)
+ if not _INTERNAL_VARIABLE_RE.match(key)]
+
+
+def find_corresponding_elem(target, dst_graph, dst_scope="", src_scope=""):
+ """Find corresponding op/tensor in a different graph.
+
+ Args:
+ target: A `tf.Tensor` or a `tf.Operation` belonging to the original graph.
+ dst_graph: The graph in which the corresponding graph element must be found.
+ dst_scope: A scope which is prepended to the name to look for.
+ src_scope: A scope which is removed from the original of `target` name.
+
+ Returns:
+ The corresponding tf.Tensor` or a `tf.Operation`.
+
+ Raises:
+ ValueError: if `src_name` does not start with `src_scope`.
+ TypeError: if `target` is not a `tf.Tensor` or a `tf.Operation`
+ KeyError: If the corresponding graph element cannot be found.
+ """
+ src_name = target.name
+ if src_scope:
+ src_scope = scope_finalize(src_scope)
+ if not src_name.startswidth(src_scope):
+ raise ValueError("{} does not start with {}".format(src_name, src_scope))
+ src_name = src_name[len(src_scope):]
+
+ dst_name = src_name
+ if dst_scope:
+ dst_scope = scope_finalize(dst_scope)
+ dst_name = dst_scope + dst_name
+
+ if isinstance(target, tf_ops.Tensor):
+ return dst_graph.get_tensor_by_name(dst_name)
+ if isinstance(target, tf_ops.Operation):
+ return dst_graph.get_operation_by_name(dst_name)
+ raise TypeError("Expected tf.Tensor or tf.Operation, got: {}", type(target))
+
+
+def find_corresponding(targets, dst_graph, dst_scope="", src_scope=""):
+ """Find corresponding ops/tensors in a different graph.
+
+ `targets` is a Python tree, that is, a nested structure of iterable
+ (list, tupple, dictionary) whose leaves are instances of
+ `tf.Tensor` or `tf.Operation`
+
+ Args:
+ targets: A Python tree containing `tf.Tensor` or `tf.Operation`
+ belonging to the original graph.
+ dst_graph: The graph in which the corresponding graph element must be found.
+ dst_scope: A scope which is prepended to the name to look for.
+ src_scope: A scope which is removed from the original of `top` name.
+
+ Returns:
+ A Python tree containin the corresponding tf.Tensor` or a `tf.Operation`.
+
+ Raises:
+ ValueError: if `src_name` does not start with `src_scope`.
+ TypeError: if `top` is not a `tf.Tensor` or a `tf.Operation`
+ KeyError: If the corresponding graph element cannot be found.
+ """
+ def func(top):
+ return find_corresponding_elem(top, dst_graph, dst_scope, src_scope)
+ return transform_tree(targets, func)