diff options
Diffstat (limited to 'tensorflow/contrib/graph_editor/util.py')
-rw-r--r-- | tensorflow/contrib/graph_editor/util.py | 73 |
1 files changed, 73 insertions, 0 deletions
diff --git a/tensorflow/contrib/graph_editor/util.py b/tensorflow/contrib/graph_editor/util.py index 11ee2435c9..d8824f6792 100644 --- a/tensorflow/contrib/graph_editor/util.py +++ b/tensorflow/contrib/graph_editor/util.py @@ -20,6 +20,7 @@ from __future__ import division from __future__ import print_function import collections +import re from six import iteritems from tensorflow.python.framework import ops as tf_ops from tensorflow.python.ops import array_ops as tf_array_ops @@ -465,3 +466,75 @@ def make_placeholder_from_dtype_and_shape(dtype, shape=None, scope=None): """ return tf_array_ops.placeholder( dtype=dtype, shape=shape, name=placeholder_name(scope=scope)) + + +_INTERNAL_VARIABLE_RE = re.compile(r"^__\w+__$") + + +def get_predefined_collection_names(): + """Return all the predefined collection names.""" + return [getattr(tf_ops.GraphKeys, key) for key in dir(tf_ops.GraphKeys) + if not _INTERNAL_VARIABLE_RE.match(key)] + + +def find_corresponding_elem(target, dst_graph, dst_scope="", src_scope=""): + """Find corresponding op/tensor in a different graph. + + Args: + target: A `tf.Tensor` or a `tf.Operation` belonging to the original graph. + dst_graph: The graph in which the corresponding graph element must be found. + dst_scope: A scope which is prepended to the name to look for. + src_scope: A scope which is removed from the original of `target` name. + + Returns: + The corresponding tf.Tensor` or a `tf.Operation`. + + Raises: + ValueError: if `src_name` does not start with `src_scope`. + TypeError: if `target` is not a `tf.Tensor` or a `tf.Operation` + KeyError: If the corresponding graph element cannot be found. + """ + src_name = target.name + if src_scope: + src_scope = scope_finalize(src_scope) + if not src_name.startswidth(src_scope): + raise ValueError("{} does not start with {}".format(src_name, src_scope)) + src_name = src_name[len(src_scope):] + + dst_name = src_name + if dst_scope: + dst_scope = scope_finalize(dst_scope) + dst_name = dst_scope + dst_name + + if isinstance(target, tf_ops.Tensor): + return dst_graph.get_tensor_by_name(dst_name) + if isinstance(target, tf_ops.Operation): + return dst_graph.get_operation_by_name(dst_name) + raise TypeError("Expected tf.Tensor or tf.Operation, got: {}", type(target)) + + +def find_corresponding(targets, dst_graph, dst_scope="", src_scope=""): + """Find corresponding ops/tensors in a different graph. + + `targets` is a Python tree, that is, a nested structure of iterable + (list, tupple, dictionary) whose leaves are instances of + `tf.Tensor` or `tf.Operation` + + Args: + targets: A Python tree containing `tf.Tensor` or `tf.Operation` + belonging to the original graph. + dst_graph: The graph in which the corresponding graph element must be found. + dst_scope: A scope which is prepended to the name to look for. + src_scope: A scope which is removed from the original of `top` name. + + Returns: + A Python tree containin the corresponding tf.Tensor` or a `tf.Operation`. + + Raises: + ValueError: if `src_name` does not start with `src_scope`. + TypeError: if `top` is not a `tf.Tensor` or a `tf.Operation` + KeyError: If the corresponding graph element cannot be found. + """ + def func(top): + return find_corresponding_elem(top, dst_graph, dst_scope, src_scope) + return transform_tree(targets, func) |