# Copyright 2015 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Utility functions for the graph_editor. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function import collections import re from six import iteritems from tensorflow.python.framework import ops as tf_ops from tensorflow.python.ops import array_ops as tf_array_ops __all__ = [ "make_list_of_op", "get_tensors", "make_list_of_t", "get_generating_ops", "get_consuming_ops", "ControlOutputs", "placeholder_name", "make_placeholder_from_tensor", "make_placeholder_from_dtype_and_shape", ] # The graph editor sometimes need to create placeholders, they are named # "geph_*". "geph" stands for Graph-Editor PlaceHolder. _DEFAULT_PLACEHOLDER_PREFIX = "geph" def concatenate_unique(la, lb): """Add all the elements of `lb` to `la` if they are not there already. The elements added to `la` maintain ordering with respect to `lb`. Args: la: List of Python objects. lb: List of Python objects. Returns: `la`: The list `la` with missing elements from `lb`. """ la_set = set(la) for l in lb: if l not in la_set: la.append(l) la_set.add(l) return la # TODO(fkp): very generic code, it should be moved in a more generic place. class ListView(object): """Immutable list wrapper. This class is strongly inspired by the one in tf.Operation. """ def __init__(self, list_): if not isinstance(list_, list): raise TypeError("Expected a list, got: {}.".format(type(list_))) self._list = list_ def __iter__(self): return iter(self._list) def __len__(self): return len(self._list) def __bool__(self): return bool(self._list) # Python 3 wants __bool__, Python 2.7 wants __nonzero__ __nonzero__ = __bool__ def __getitem__(self, i): return self._list[i] def __add__(self, other): if not isinstance(other, list): other = list(other) return list(self) + other # TODO(fkp): very generic code, it should be moved in a more generic place. def is_iterable(obj): """Return true if the object is iterable.""" if isinstance(obj, tf_ops.Tensor): return False try: _ = iter(obj) except Exception: # pylint: disable=broad-except return False return True def flatten_tree(tree, leaves=None): """Flatten a tree into a list. Args: tree: iterable or not. If iterable, its elements (child) can also be iterable or not. leaves: list to which the tree leaves are appended (None by default). Returns: A list of all the leaves in the tree. """ if leaves is None: leaves = [] if isinstance(tree, dict): for _, child in iteritems(tree): flatten_tree(child, leaves) elif is_iterable(tree): for child in tree: flatten_tree(child, leaves) else: leaves.append(tree) return leaves def transform_tree(tree, fn, iterable_type=tuple): """Transform all the nodes of a tree. Args: tree: iterable or not. If iterable, its elements (child) can also be iterable or not. fn: function to apply to each leaves. iterable_type: type use to construct the resulting tree for unknown iterable, typically `list` or `tuple`. Returns: A tree whose leaves has been transformed by `fn`. The hierarchy of the output tree mimics the one of the input tree. """ if is_iterable(tree): if isinstance(tree, dict): res = tree.__new__(type(tree)) res.__init__( (k, transform_tree(child, fn)) for k, child in iteritems(tree)) return res elif isinstance(tree, tuple): # NamedTuple? if hasattr(tree, "_asdict"): res = tree.__new__(type(tree), **transform_tree(tree._asdict(), fn)) else: res = tree.__new__(type(tree), (transform_tree(child, fn) for child in tree)) return res elif isinstance(tree, collections.Sequence): res = tree.__new__(type(tree)) res.__init__(transform_tree(child, fn) for child in tree) return res else: return iterable_type(transform_tree(child, fn) for child in tree) else: return fn(tree) def check_graphs(*args): """Check that all the element in args belong to the same graph. Args: *args: a list of object with a obj.graph property. Raises: ValueError: if all the elements do not belong to the same graph. """ graph = None for i, sgv in enumerate(args): if graph is None and sgv.graph is not None: graph = sgv.graph elif sgv.graph is not None and sgv.graph is not graph: raise ValueError("Argument[{}]: Wrong graph!".format(i)) def get_unique_graph(tops, check_types=None, none_if_empty=False): """Return the unique graph used by the all the elements in tops. Args: tops: list of elements to check (usually a list of tf.Operation and/or tf.Tensor). Or a tf.Graph. check_types: check that the element in tops are of given type(s). If None, the types (tf.Operation, tf.Tensor) are used. none_if_empty: don't raise an error if tops is an empty list, just return None. Returns: The unique graph used by all the tops. Raises: TypeError: if tops is not a iterable of tf.Operation. ValueError: if the graph is not unique. """ if isinstance(tops, tf_ops.Graph): return tops if not is_iterable(tops): raise TypeError("{} is not iterable".format(type(tops))) if check_types is None: check_types = (tf_ops.Operation, tf_ops.Tensor) elif not is_iterable(check_types): check_types = (check_types,) g = None for op in tops: if not isinstance(op, check_types): raise TypeError("Expected a type in ({}), got: {}".format(", ".join([str( t) for t in check_types]), type(op))) if g is None: g = op.graph elif g is not op.graph: raise ValueError("Operation {} does not belong to given graph".format(op)) if g is None and not none_if_empty: raise ValueError("Can't find the unique graph of an empty list") return g def make_list_of_op(ops, check_graph=True, allow_graph=True, ignore_ts=False): """Convert ops to a list of `tf.Operation`. Args: ops: can be an iterable of `tf.Operation`, a `tf.Graph` or a single operation. check_graph: if `True` check if all the operations belong to the same graph. allow_graph: if `False` a `tf.Graph` cannot be converted. ignore_ts: if True, silently ignore `tf.Tensor`. Returns: A newly created list of `tf.Operation`. Raises: TypeError: if ops cannot be converted to a list of `tf.Operation` or, if `check_graph` is `True`, if all the ops do not belong to the same graph. """ if isinstance(ops, tf_ops.Graph): if allow_graph: return ops.get_operations() else: raise TypeError("allow_graph is False: cannot convert a tf.Graph.") else: if not is_iterable(ops): ops = [ops] if not ops: return [] if check_graph: check_types = None if ignore_ts else tf_ops.Operation get_unique_graph(ops, check_types=check_types) return [op for op in ops if isinstance(op, tf_ops.Operation)] # TODO(fkp): move this function in tf.Graph? def get_tensors(graph): """get all the tensors which are input or output of an op in the graph. Args: graph: a `tf.Graph`. Returns: A list of `tf.Tensor`. Raises: TypeError: if graph is not a `tf.Graph`. """ if not isinstance(graph, tf_ops.Graph): raise TypeError("Expected a graph, got: {}".format(type(graph))) ts = [] for op in graph.get_operations(): ts += op.outputs return ts def make_list_of_t(ts, check_graph=True, allow_graph=True, ignore_ops=False): """Convert ts to a list of `tf.Tensor`. Args: ts: can be an iterable of `tf.Tensor`, a `tf.Graph` or a single tensor. check_graph: if `True` check if all the tensors belong to the same graph. allow_graph: if `False` a `tf.Graph` cannot be converted. ignore_ops: if `True`, silently ignore `tf.Operation`. Returns: A newly created list of `tf.Tensor`. Raises: TypeError: if `ts` cannot be converted to a list of `tf.Tensor` or, if `check_graph` is `True`, if all the ops do not belong to the same graph. """ if isinstance(ts, tf_ops.Graph): if allow_graph: return get_tensors(ts) else: raise TypeError("allow_graph is False: cannot convert a tf.Graph.") else: if not is_iterable(ts): ts = [ts] if not ts: return [] if check_graph: check_types = None if ignore_ops else tf_ops.Tensor get_unique_graph(ts, check_types=check_types) return [t for t in ts if isinstance(t, tf_ops.Tensor)] def get_generating_ops(ts): """Return all the generating ops of the tensors in `ts`. Args: ts: a list of `tf.Tensor` Returns: A list of all the generating `tf.Operation` of the tensors in `ts`. Raises: TypeError: if `ts` cannot be converted to a list of `tf.Tensor`. """ ts = make_list_of_t(ts, allow_graph=False) return [t.op for t in ts] def get_consuming_ops(ts): """Return all the consuming ops of the tensors in ts. Args: ts: a list of `tf.Tensor` Returns: A list of all the consuming `tf.Operation` of the tensors in `ts`. Raises: TypeError: if ts cannot be converted to a list of `tf.Tensor`. """ ts = make_list_of_t(ts, allow_graph=False) ops = [] for t in ts: for op in t.consumers(): if op not in ops: ops.append(op) return ops class ControlOutputs(object): """The control outputs topology.""" def __init__(self, graph): """Create a dictionary of control-output dependencies. Args: graph: a `tf.Graph`. Returns: A dictionary where a key is a `tf.Operation` instance and the corresponding value is a list of all the ops which have the key as one of their control-input dependencies. Raises: TypeError: graph is not a `tf.Graph`. """ if not isinstance(graph, tf_ops.Graph): raise TypeError("Expected a tf.Graph, got: {}".format(type(graph))) self._control_outputs = {} self._graph = graph self._version = None self._build() def update(self): """Update the control outputs if the graph has changed.""" if self._version != self._graph.version: self._build() return self def _build(self): """Build the control outputs dictionary.""" self._control_outputs.clear() ops = self._graph.get_operations() for op in ops: for control_input in op.control_inputs: if control_input not in self._control_outputs: self._control_outputs[control_input] = [] if op not in self._control_outputs[control_input]: self._control_outputs[control_input].append(op) self._version = self._graph.version def get_all(self): return self._control_outputs def get(self, op): """return the control outputs of op.""" if op in self._control_outputs: return self._control_outputs[op] else: return () @property def graph(self): return self._graph def scope_finalize(scope): if scope and scope[-1] != "/": scope += "/" return scope def scope_dirname(scope): slash = scope.rfind("/") if slash == -1: return "" return scope[:slash + 1] def scope_basename(scope): slash = scope.rfind("/") if slash == -1: return scope return scope[slash + 1:] def placeholder_name(t=None, scope=None, prefix=_DEFAULT_PLACEHOLDER_PREFIX): """Create placeholder name for the graph editor. Args: t: optional tensor on which the placeholder operation's name will be based on scope: absolute scope with which to prefix the placeholder's name. None means that the scope of t is preserved. "" means the root scope. prefix: placeholder name prefix. Returns: A new placeholder name prefixed by "geph". Note that "geph" stands for Graph Editor PlaceHolder. This convention allows to quickly identify the placeholder generated by the Graph Editor. Raises: TypeError: if t is not None or a tf.Tensor. """ if scope is not None: scope = scope_finalize(scope) if t is not None: if not isinstance(t, tf_ops.Tensor): raise TypeError("Expected a tf.Tenfor, got: {}".format(type(t))) op_dirname = scope_dirname(t.op.name) op_basename = scope_basename(t.op.name) if scope is None: scope = op_dirname if op_basename.startswith("{}__".format(prefix)): ph_name = op_basename else: ph_name = "{}__{}_{}".format(prefix, op_basename, t.value_index) return scope + ph_name else: if scope is None: scope = "" return "{}{}".format(scope, prefix) def make_placeholder_from_tensor(t, scope=None, prefix=_DEFAULT_PLACEHOLDER_PREFIX): """Create a `tf.placeholder` for the Graph Editor. Note that the correct graph scope must be set by the calling function. Args: t: a `tf.Tensor` whose name will be used to create the placeholder (see function placeholder_name). scope: absolute scope within which to create the placeholder. None means that the scope of `t` is preserved. `""` means the root scope. prefix: placeholder name prefix. Returns: A newly created `tf.placeholder`. Raises: TypeError: if `t` is not `None` or a `tf.Tensor`. """ return tf_array_ops.placeholder( dtype=t.dtype, shape=t.get_shape(), name=placeholder_name(t, scope=scope, prefix=prefix)) def make_placeholder_from_dtype_and_shape(dtype, shape=None, scope=None, prefix=_DEFAULT_PLACEHOLDER_PREFIX): """Create a tf.placeholder for the Graph Editor. Note that the correct graph scope must be set by the calling function. The placeholder is named using the function placeholder_name (with no tensor argument). Args: dtype: the tensor type. shape: the tensor shape (optional). scope: absolute scope within which to create the placeholder. None means that the scope of t is preserved. "" means the root scope. prefix: placeholder name prefix. Returns: A newly created tf.placeholder. """ return tf_array_ops.placeholder( dtype=dtype, shape=shape, name=placeholder_name(scope=scope, prefix=prefix)) _INTERNAL_VARIABLE_RE = re.compile(r"^__\w+__$") def get_predefined_collection_names(): """Return all the predefined collection names.""" return [getattr(tf_ops.GraphKeys, key) for key in dir(tf_ops.GraphKeys) if not _INTERNAL_VARIABLE_RE.match(key)] def find_corresponding_elem(target, dst_graph, dst_scope="", src_scope=""): """Find corresponding op/tensor in a different graph. Args: target: A `tf.Tensor` or a `tf.Operation` belonging to the original graph. dst_graph: The graph in which the corresponding graph element must be found. dst_scope: A scope which is prepended to the name to look for. src_scope: A scope which is removed from the original of `target` name. Returns: The corresponding tf.Tensor` or a `tf.Operation`. Raises: ValueError: if `src_name` does not start with `src_scope`. TypeError: if `target` is not a `tf.Tensor` or a `tf.Operation` KeyError: If the corresponding graph element cannot be found. """ src_name = target.name if src_scope: src_scope = scope_finalize(src_scope) if not src_name.startswidth(src_scope): raise ValueError("{} does not start with {}".format(src_name, src_scope)) src_name = src_name[len(src_scope):] dst_name = src_name if dst_scope: dst_scope = scope_finalize(dst_scope) dst_name = dst_scope + dst_name if isinstance(target, tf_ops.Tensor): return dst_graph.get_tensor_by_name(dst_name) if isinstance(target, tf_ops.Operation): return dst_graph.get_operation_by_name(dst_name) raise TypeError("Expected tf.Tensor or tf.Operation, got: {}", type(target)) def find_corresponding(targets, dst_graph, dst_scope="", src_scope=""): """Find corresponding ops/tensors in a different graph. `targets` is a Python tree, that is, a nested structure of iterable (list, tupple, dictionary) whose leaves are instances of `tf.Tensor` or `tf.Operation` Args: targets: A Python tree containing `tf.Tensor` or `tf.Operation` belonging to the original graph. dst_graph: The graph in which the corresponding graph element must be found. dst_scope: A scope which is prepended to the name to look for. src_scope: A scope which is removed from the original of `top` name. Returns: A Python tree containin the corresponding tf.Tensor` or a `tf.Operation`. Raises: ValueError: if `src_name` does not start with `src_scope`. TypeError: if `top` is not a `tf.Tensor` or a `tf.Operation` KeyError: If the corresponding graph element cannot be found. """ def func(top): return find_corresponding_elem(top, dst_graph, dst_scope, src_scope) return transform_tree(targets, func)