author    A. Unique TensorFlower <gardener@tensorflow.org>  2016-08-08 01:05:06 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>   2016-08-08 02:17:22 -0700
commit5e2019571a0c09bd627dd4f0abbf6b08e19a65e5 (patch)
tree3db31d03a2ecb20e04e232e9d2e4df292a1aacb1 /tensorflow/contrib/graph_editor
parent50c8d6c9803c05234ad98355ef11d8cf98a5e7ae (diff)
Improved documentation and minor bug fixes.
Change: 129607856
Diffstat (limited to 'tensorflow/contrib/graph_editor')
-rw-r--r--  tensorflow/contrib/graph_editor/README.md    |  15
-rw-r--r--  tensorflow/contrib/graph_editor/__init__.py  | 101
-rw-r--r--  tensorflow/contrib/graph_editor/edit.py      |   9
-rw-r--r--  tensorflow/contrib/graph_editor/select.py    |   3
-rw-r--r--  tensorflow/contrib/graph_editor/subgraph.py  |  60
-rw-r--r--  tensorflow/contrib/graph_editor/transform.py | 104
-rw-r--r--  tensorflow/contrib/graph_editor/util.py      |  10
7 files changed, 226 insertions, 76 deletions
diff --git a/tensorflow/contrib/graph_editor/README.md b/tensorflow/contrib/graph_editor/README.md
index e237208752..f6f82668ed 100644
--- a/tensorflow/contrib/graph_editor/README.md
+++ b/tensorflow/contrib/graph_editor/README.md
@@ -1,15 +1,6 @@
# TensorFlow Graph Editor
-The TensorFlow Graph Editor libray which allows for modification of an existing
-tf.Graph instance.
+The TensorFlow Graph Editor library allows for modification of an existing
+tf.Graph instance in-place.
-## Overview of the modules
-
-* util.py: utility functions
-* select.py: selection functions, allowing for various selection method of
- tensors and operations.
-* subgraph.py: the SubGraphView class, allowing to manipulate subgraph of a
- TensorFlow tf.Graph instance.
-* transform.py: the Transformer class, allowing to tranform a subgraph into
- another one.
-* edit.py: various editing function operating on subgraph.
+The author's github username is [purpledog](https://github.com/purpledog).
diff --git a/tensorflow/contrib/graph_editor/__init__.py b/tensorflow/contrib/graph_editor/__init__.py
index 0cbfb5420e..7f20bd88dd 100644
--- a/tensorflow/contrib/graph_editor/__init__.py
+++ b/tensorflow/contrib/graph_editor/__init__.py
@@ -12,7 +12,103 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Graph editor module allows to modify an existing graph in place.
+"""# TensorFlow Graph Editor.
+
+The TensorFlow Graph Editor library allows for modification of an existing
+tf.Graph instance in-place.
+
+The author's github username is [purpledog](https://github.com/purpledog).
+
+## Library overview
+
+Appending new nodes is the only graph editing operation allowed by the
+TensorFlow core library. The Graph Editor library is an attempt to allow for
+other kinds of editing operations, namely, *rerouting* and *transforming*.
+
+* *rerouting* is a local operation consisting of re-plugging existing tensors
+ (the edges of the graph). Operations (the nodes) are not modified by this
+ operation. For example, rerouting can be used to insert an operation adding
+ noise in place of an existing tensor.
+* *transforming* is a global operation consisting of transforming a graph into
+  another. By default, a transformation is a simple copy, but it can be
+  customized to achieve other goals. For instance, a graph can be transformed
+ into another one in which noise is added after all the operations of a
+ specific type.
+
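The two operations can be pictured on a toy graph structure. The sketch below uses a plain Python dict mapping each op name to the list of tensor names it consumes; this is only an analogy for illustration, not the Graph Editor API:

```python
# Toy stand-in for a tf.Graph: each "op" maps to the list of tensor
# names it consumes. This is only an analogy, not the Graph Editor API.
graph = {
    "add": ["weights:0", "inputs:0"],
    "loss": ["add:0", "labels:0"],
}

def reroute(g, old_t, new_t):
    """Rerouting: re-plug existing edges in place; no op is rebuilt."""
    for op in g:
        g[op] = [new_t if t == old_t else t for t in g[op]]

# Insert a noise-adding op in place of the existing tensor "add:0":
graph["add_noise"] = ["add:0"]
reroute(graph, "add:0", "add_noise:0")
graph["add_noise"] = ["add:0"]  # the noise op itself keeps the original input

def transform(g, suffix="_copy"):
    """Transforming: build a whole new graph; the original is untouched."""
    return {op + suffix: [t + suffix for t in inputs]
            for op, inputs in g.items()}

copied = transform(graph)
```

After the reroute, every former consumer of `add:0` now reads from `add_noise:0`, while the ops themselves are unchanged; the transform produces an independent copy.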
+**Important: modifying a graph in-place with the Graph Editor must be done
+`offline`, that is, without any active sessions.**
+
+Of course, new operations can be appended online, but Graph Editor-specific
+operations like rerouting and transforming can currently only be done offline.
+
+Here is an example of what you **cannot** do:
+
+* Build a graph.
+* Create a session and run the graph.
+* Modify the graph with the Graph Editor.
+* Re-run the graph with the `same` previously created session.
+
+To edit an already running graph, follow these steps:
+
+* Build a graph.
+* Create a session and run the graph.
+* Save the graph state and terminate the session.
+* Modify the graph with the Graph Editor.
+* Create a new session and restore the graph state.
+* Re-run the graph with the newly created session.
+
+Note that this procedure is very costly because a new session must be created
+after any modifications. Among other things, it takes time because the entire
+graph state must be saved and restored again.
+
+### Sub-graph
+
+Most of the functions in the Graph Editor library operate on *sub-graphs*.
+More precisely, they take as input arguments instances of the SubGraphView class
+(or anything which can be converted to it). Doing so allows the same function
+to transparently operate on single operations as well as sub-graphs of any size.
+
+A subgraph can be created in several ways:
+
+* using a list of ops:
+
+```python
+my_sgv = ge.sgv(ops)
+```
+
+* from a name scope:
+
+```python
+my_sgv = ge.sgv_scope("foo/bar", graph=tf.get_default_graph())
+```
+
+* using regular expression:
+
+```python
+my_sgv = ge.sgv("foo/.*/.*read$", graph=tf.get_default_graph())
+```
+
+Note that the Graph Editor is meant to manipulate several graphs at the same
+time, typically during transform or copy operations. For that reason, to avoid
+any confusion, the default graph is never used and the graph on which to
+operate must always be given explicitly. This is why
+*graph=tf.get_default_graph()* is used in the code snippets above.
+
+
+### Modules
+
+* util: utility functions.
+* select: various selection methods of TensorFlow tensors and operations.
+* match: TensorFlow graph matching. Think of this as regular expressions for
+ graphs (but not quite yet).
+* reroute: various ways of rerouting tensors to different consuming ops like
+ *swap* or *reroute_a2b*.
+* subgraph: the SubGraphView class, which enables subgraph manipulations in a
+ TensorFlow tf.Graph.
+* edit: various editing functions operating on subgraphs like *detach*,
+ *connect* or *bypass*.
+* transform: the Transformer class, which enables transforming
+ (or simply copying) a subgraph into another one.
"""
from __future__ import absolute_import
@@ -54,6 +150,8 @@ from tensorflow.contrib.graph_editor.subgraph import SubGraphView
from tensorflow.contrib.graph_editor.transform import copy
from tensorflow.contrib.graph_editor.transform import Transformer
+from tensorflow.contrib.graph_editor.util import ControlOutputs
+
# some useful aliases
ph = util.make_placeholder_from_dtype_and_shape
@@ -62,4 +160,3 @@ sgv_scope = subgraph.make_view_from_scope
ts = select.select_ts
ops = select.select_ops
matcher = match.OpMatcher
-
diff --git a/tensorflow/contrib/graph_editor/edit.py b/tensorflow/contrib/graph_editor/edit.py
index 9d1518a765..b1c41a2af2 100644
--- a/tensorflow/contrib/graph_editor/edit.py
+++ b/tensorflow/contrib/graph_editor/edit.py
@@ -68,6 +68,7 @@ def detach_inputs(sgv, control_inputs=False):
Returns:
A new subgraph view of the detached subgraph.
Note that sgv is also modified in place.
+ A list of the created input placeholders.
Raises:
StandardError: if sgv cannot be converted to a SubGraphView using
the same rules than the function subgraph.make_view.
@@ -98,6 +99,7 @@ def detach_outputs(sgv, control_outputs=None):
Returns:
A new subgraph view of the detached subgraph.
Note that sgv is also modified in place.
+ A list of the created output placeholders.
Raises:
StandardError: if sgv cannot be converted to a SubGraphView using
the same rules than the function subgraph.make_view.
@@ -141,6 +143,8 @@ def detach(sgv, control_inputs=False, control_outputs=None, control_ios=None):
Returns:
A new subgraph view of the detached subgraph.
Note that sgv is also modified in place.
+ A list of the created input placeholders.
+ A list of the created output placeholders.
Raises:
StandardError: if sgv cannot be converted to a SubGraphView using
the same rules than the function subgraph.make_view.
@@ -164,8 +168,8 @@ def connect(sgv0, sgv1, disconnect_first=False):
subgraph.make_view.
disconnect_first: if True the current outputs of sgv0 are disconnected.
Returns:
- Two new subgraph views (now connected). sgv0 and svg1 are also modified
- in place.
+ The modified sgv0 (now connected to sgv1).
+ The modified sgv1 (now connected to sgv0).
Raises:
StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView using
the same rules than the function subgraph.make_view.
@@ -189,6 +193,7 @@ def bypass(sgv):
Returns:
A new subgraph view of the bypassed subgraph.
Note that sgv is also modified in place.
+ A list of the created input placeholders.
Raises:
StandardError: if sgv cannot be converted to a SubGraphView using
the same rules than the function subgraph.make_view.
diff --git a/tensorflow/contrib/graph_editor/select.py b/tensorflow/contrib/graph_editor/select.py
index f673dcec35..5f8a50a7d8 100644
--- a/tensorflow/contrib/graph_editor/select.py
+++ b/tensorflow/contrib/graph_editor/select.py
@@ -20,7 +20,6 @@ from __future__ import division
from __future__ import print_function
import re
-import types
from six import iteritems
from six import string_types
@@ -299,7 +298,7 @@ def compute_boundary_ts(ops, ambiguous_are_outputs=True):
def get_within_boundary_ops(ops,
seed_ops,
- boundary_ops,
+ boundary_ops=(),
inclusive=True,
control_inputs=False,
control_outputs=None,
diff --git a/tensorflow/contrib/graph_editor/subgraph.py b/tensorflow/contrib/graph_editor/subgraph.py
index d8dab07427..ac28a114af 100644
--- a/tensorflow/contrib/graph_editor/subgraph.py
+++ b/tensorflow/contrib/graph_editor/subgraph.py
@@ -164,25 +164,30 @@ class SubGraphView(object):
TypeError: if inside_ops cannot be converted to a list of tf.Operation or
if passthrough_ts cannot be converted to a list of tf.Tensor.
"""
+
inside_ops = util.make_list_of_op(inside_ops)
passthrough_ts = util.make_list_of_t(passthrough_ts)
ops_and_ts = inside_ops + passthrough_ts
if ops_and_ts:
self._graph = util.get_unique_graph(ops_and_ts)
- else:
- self._graph = None
- self._ops = inside_ops
+ self._ops = inside_ops
- # Compute inside and outside tensor
- inputs, outputs, insides = select.compute_boundary_ts(inside_ops)
+ # Compute inside and outside tensor
+ inputs, outputs, insides = select.compute_boundary_ts(inside_ops)
- # Compute passthrough tensors, silently ignoring the non-passthrough ones.
- all_tensors = frozenset(inputs + outputs + list(insides))
- self._passthrough_ts = [t for t in passthrough_ts if t not in all_tensors]
+ # Compute passthrough tensors, silently ignoring the non-passthrough ones.
+ all_tensors = frozenset(inputs + outputs + list(insides))
+ self._passthrough_ts = [t for t in passthrough_ts if t not in all_tensors]
- # Set inputs and outputs.
- self._input_ts = inputs + self._passthrough_ts
- self._output_ts = outputs + self._passthrough_ts
+ # Set inputs and outputs.
+ self._input_ts = inputs + self._passthrough_ts
+ self._output_ts = outputs + self._passthrough_ts
+ else:
+ self._graph = None
+ self._passthrough_ts = []
+ self._input_ts = []
+ self._output_ts = []
+ self._ops = []
def __copy__(self):
"""Create a copy of this subgraph.
@@ -412,23 +417,27 @@ class SubGraphView(object):
raise AssertionError("More than 1 op named: {}!".format(op_name))
return res[0]
- def __getitem__(self, op_name):
- return self.find_op_by_name(op_name)
-
def __str__(self):
- res = StringIO()
+ if not self:
+ return "SubGraphView: empty"
+ def op_name(op):
+ return op.name
def tensor_name(t):
if t in self._passthrough_ts:
return "{} *".format(t.name)
else:
return t.name
- print("SubGraphView:", file=res)
- print("** ops:", file=res)
- print("\n".join([op.name for op in self._ops]), file=res)
- print("** inputs:", file=res)
- print("\n".join([tensor_name(t) for t in self._input_ts]), file=res)
- print("** outputs:", file=res)
- print("\n".join([tensor_name(t) for t in self._output_ts]), file=res)
+ def print_list(name, iterable, get_name):
+ if iterable:
+ print("** {}[{}]:".format(name, len(iterable)), file=res)
+ print("\n".join([get_name(elem) for elem in iterable]), file=res)
+ else:
+ print("** {}: empty".format(name), file=res)
+ res = StringIO()
+ print("SubGraphView (graphid={}):".format(id(self.graph)), file=res)
+ print_list("ops", self._ops, op_name)
+ print_list("inputs", self._input_ts, tensor_name)
+ print_list("outputs", self._output_ts, tensor_name)
return res.getvalue()
@property
@@ -466,10 +475,13 @@ class SubGraphView(object):
"""The passthrough tensors, going straight from input to output."""
return util.ListView(self._passthrough_ts)
- def __nonzero__(self):
+ def __bool__(self):
"""Allows for implicit boolean conversion."""
return self._graph is not None
+ # Python 3 wants __bool__, Python 2.7 wants __nonzero__
+ __nonzero__ = __bool__
+
def op(self, op_id):
"""Get an op by its index."""
return self._ops[op_id]
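The `__bool__`/`__nonzero__` pairing added in the hunk above is the standard way to make truthiness work on both Python 3 and Python 2.7. A minimal self-contained sketch of the pattern (the class and attribute names mirror SubGraphView but are simplified):

```python
class View(object):
    """Minimal sketch of the SubGraphView truthiness pattern."""

    def __init__(self, graph=None):
        self._graph = graph

    def __bool__(self):  # Python 3 calls __bool__ for truth testing
        """Allows for implicit boolean conversion."""
        return self._graph is not None

    # Python 2.7 calls __nonzero__; aliasing keeps one implementation.
    __nonzero__ = __bool__

empty = View()
nonempty = View(graph=object())
```

This is why the later `_check_graph` change can write `not sgv.graph` and `not self` instead of comparing against `None` explicitly.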
@@ -556,7 +568,7 @@ def _check_graph(sgv, graph):
"""
if not isinstance(sgv, SubGraphView):
raise TypeError("Expected a SubGraphView, got: {}".format(type(graph)))
- if graph is None or sgv.graph is None:
+ if graph is None or not sgv.graph:
return sgv
if not isinstance(graph, tf_ops.Graph):
raise TypeError("Expected a tf.Graph, got: {}".format(type(graph)))
diff --git a/tensorflow/contrib/graph_editor/transform.py b/tensorflow/contrib/graph_editor/transform.py
index 8214c19674..65e49a3f71 100644
--- a/tensorflow/contrib/graph_editor/transform.py
+++ b/tensorflow/contrib/graph_editor/transform.py
@@ -31,7 +31,7 @@ from tensorflow.python.framework import ops as tf_ops
from tensorflow.python.platform import tf_logging as logging
-def transform_tensor_into_placeholder_handler(info, t):
+def replace_t_with_placeholder_handler(info, t):
"""Transform a tensor into a placeholder tensor.
This handler is typically used to transform a subgraph input tensor into a
@@ -48,7 +48,7 @@ def transform_tensor_into_placeholder_handler(info, t):
return t_
-def keep_same_tensor_if_possible_handler(info, t):
+def keep_t_if_possible_handler(info, t):
"""Transform a tensor into itself (identity) if possible.
This handler transform a tensor into itself if the source and destination
@@ -64,7 +64,7 @@ def keep_same_tensor_if_possible_handler(info, t):
if info.graph is info.graph_:
return t
else:
- return transform_tensor_into_placeholder_handler(info, t)
+ return replace_t_with_placeholder_handler(info, t)
def assign_renamed_collections_handler(info, elem, elem_):
@@ -83,7 +83,7 @@ def assign_renamed_collections_handler(info, elem, elem_):
info.graph_.add_to_collection(collection_name_, elem_)
-def transform_optional_op_if_inside_handler(info, op):
+def transform_op_if_inside_handler(info, op, keep_if_possible=True):
"""Transform an optional op only if it is inside the subgraph.
This handler is typically use to handle original op: it is fine to keep them
@@ -92,14 +92,19 @@ def transform_optional_op_if_inside_handler(info, op):
Args:
info: Transform._Info instance.
op: the optional op to transform (or ignore).
+ keep_if_possible: re-attach to the original op if possible, that is,
+ if the source graph and the destination graph are the same.
Returns:
- the transformed op or None if ignored.
+ the transformed op or None.
"""
if op is None: return None
if op in info.sgv.ops:
return info.transformer._transform_op(op) # pylint: disable=protected-access
else:
- return None
+ if keep_if_possible and info.graph is info.graph_:
+ return op
+ else:
+ return None
def copy_op_handler(info, op):
@@ -113,18 +118,22 @@ def copy_op_handler(info, op):
"""
# pylint: disable=protected-access
- # If it has control inputs, call this function recursively on each.
- control_inputs_ = [info.transformer._transform_op(control_input)
- for control_input in op.control_inputs]
+ # Transform control inputs:
+ control_inputs_ = [info.transformer.transform_control_input_handler(info, ci)
+ for ci in op.control_inputs]
+ control_inputs_ = [ci for ci in control_inputs_ if ci is not None]
- # If it has an original_op parameter, copy it
+ # Transform it if any:
original_op_ = info.transformer.transform_original_op_hanlder(info,
op._original_op)
- # If it has inputs, call this function recursively on each.
+ # Transform inputs:
inputs_ = [info.transformer._transform_t(t) for t in op.inputs]
+ # Clone the node def:
node_def_ = deepcopy(op._node_def)
+
+ # Transform name:
name_ = info.transformer.new_name(op.name)
name_ = info.graph_.unique_name(name_)
node_def_.name = name_
@@ -182,7 +191,7 @@ class Transformer(object):
class _Info(object):
"""Transformer temporary data.
- An instance of this class hold all the information relevant to a call
+ An instance of this class holds all the information relevant to a call
to a transformer instance (that is, a call to __call__). An instance
is created for the life-time of the __call__ function and is passed as
argument to the handlers.
@@ -200,6 +209,10 @@ class Transformer(object):
self.transformed_ops = {}
self.transformed_ts = {}
+ def create_ops_mapping(self):
+ """Return the mapping from original ops to transformed ops."""
+ return {op.name: op_.name for op, op_ in iteritems(self.transformed_ops)}
+
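The `create_ops_mapping` method added above is a plain dict comprehension over the `transformed_ops` dictionary (via `six.iteritems` for Python 2/3 compatibility). The same pattern in isolation, with trivial stand-ins for `tf.Operation` (these class and op names are illustrative only):

```python
class FakeOp(object):
    """Stand-in for tf.Operation; only the .name attribute matters here."""
    def __init__(self, name):
        self.name = name

# Mapping from original ops to their transformed counterparts,
# as accumulated by the transformer during a call.
transformed_ops = {FakeOp("a/x"): FakeOp("b/x"),
                   FakeOp("a/y"): FakeOp("b/y")}

# Same shape as _Info.create_ops_mapping: original name -> transformed name.
# (.items() here plays the role of six.iteritems in the real code.)
ops_mapping = {op.name: op_.name for op, op_ in transformed_ops.items()}
```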
def __init__(self):
"""Transformer constructor.
@@ -209,13 +222,14 @@ class Transformer(object):
assign_collections_handler: handle the assignment of collections.
This handler defaults to assigning new collections created under the
given name-scope.
- transform_input_handler: handle the transform of the inputs to the given
- subgraph. This handler defaults to creating placeholders instead of the
- ops just before the input tensors of the subgraph.
- transform_hidden_input_handler: handle the transform of the hidden inputs of
- the subgraph, that is, the inputs which are not listed in sgv.inputs.
- This handler defaults to a transform which keep the same input if the
- source and destination graphs are the same, otherwise use placeholders.
+ transform_external_input_handler: handle the transform of the inputs to
+ the given subgraph. This handler defaults to creating placeholders
+ instead of the ops just before the input tensors of the subgraph.
+ transform_external_hidden_input_handler: handle the transform of the
+ hidden inputs of the subgraph, that is, the inputs which are not listed
+      in sgv.inputs. This handler defaults to a transform which keeps the same
+      input if the source and destination graphs are the same, otherwise
+      uses placeholders.
transform_original_op_hanlder: handle the transform of original_op. This
handler defaults to transforming original_op only if they are in the
subgraph, otherwise they are ignored.
@@ -223,15 +237,17 @@ class Transformer(object):
# handlers
self.transform_op_handler = copy_op_handler
+ self.transform_control_input_handler = transform_op_if_inside_handler
self.assign_collections_handler = assign_renamed_collections_handler
- self.transform_input_handler = transform_tensor_into_placeholder_handler
- self.transform_hidden_input_handler = keep_same_tensor_if_possible_handler
- self.transform_original_op_hanlder = transform_optional_op_if_inside_handler
+ self.transform_external_input_handler = replace_t_with_placeholder_handler
+ self.transform_external_hidden_input_handler = keep_t_if_possible_handler
+ self.transform_original_op_hanlder = transform_op_if_inside_handler
# temporary per-call variable
self._info = None
- def __call__(self, sgv, dst_graph, dst_scope, src_scope=""):
+ def __call__(self, sgv, dst_graph, dst_scope, src_scope="",
+ reuse_dst_scope=False):
"""Execute the transformation.
Args:
@@ -242,16 +258,27 @@ class Transformer(object):
relative path of the transformed nodes are computed. For instance, if
src_scope is a/ and dst_scoped is b/, then the node a/x/y will have a
relative path of x/y and will be transformed into b/x/y.
+ reuse_dst_scope: if True the dst_scope is re-used if it already exists.
+ Otherwise, the scope is given a unique name based on the one given
+ by postfixing an underscore followed by a digit (default).
Returns:
The transformed subgraph view.
+ A dictionary mapping the name of the original ops to the name of the
+ transformed ops.
Raises:
- ValueError: if the argumens are invalid. For instance, if the source and
- destination are the same.
+      ValueError: if the arguments are invalid.
"""
+ sgv = subgraph.make_view(sgv)
+ if not isinstance(dst_graph, tf_ops.Graph):
+ raise TypeError("Expected a tf.Graph, got: {}".format(type(dst_graph)))
+
src_scope = util.scope_finalize(src_scope)
dst_scope = util.scope_finalize(dst_scope)
- sgv = subgraph.make_view(sgv)
+ # Potentially create new scope if reuse_dst_scope is False
+ if dst_scope and not reuse_dst_scope:
+ dst_scope = util.scope_finalize(dst_graph.unique_name(dst_scope[:-1]))
+
if sgv.graph is dst_graph and not dst_scope:
logging.warning("The source and the destination are the same! "
"Beware: in-place transormation are currently "
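The `reuse_dst_scope=False` behaviour above relies on `Graph.unique_name`, which postfixes an underscore and a digit when a scope name is already taken. A rough pure-Python sketch of that naming scheme (the real method lives on `tf.Graph`, and its exact numbering details may differ):

```python
def make_unique_name(name, taken):
    """Sketch of unique-name generation: postfix _1, _2, ... if taken."""
    if name not in taken:
        return name
    i = 1
    while "{}_{}".format(name, i) in taken:
        i += 1
    return "{}_{}".format(name, i)

taken = {"foo", "foo_1"}
fresh = make_unique_name("bar", taken)   # "bar": unused, kept as-is
bumped = make_unique_name("foo", taken)  # "foo" and "foo_1" taken -> "foo_2"
```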
@@ -264,8 +291,9 @@ class Transformer(object):
sgv_ = self._transform_sgv(sgv)
+ ops_mapping = self._info.create_ops_mapping()
self._info = None
- return sgv_
+ return sgv_, ops_mapping
def _transform_sgv(self, sgv):
"""Transform a subgraph view.
@@ -319,11 +347,16 @@ class Transformer(object):
return self._info.transformed_ts[t]
op, op_index = t.op, t.value_index
+
+ # If op is not in the subgraph:
if op not in self._info.ops:
- if t in self._info.sgv.inputs: # t_ is an input of the subgraph
- t_ = self.transform_input_handler(self._info, t)
- else: # t_ is a hidden input of the subgraph
- t_ = self.transform_hidden_input_handler(self._info, t)
+ # t_ is an input of the subgraph
+ if t in self._info.sgv.inputs:
+ t_ = self.transform_external_input_handler(self._info, t)
+ # t_ is a hidden input of the subgraph
+ else:
+ t_ = self.transform_external_hidden_input_handler(self._info, t)
+ # If op is in the subgraph, just transform it:
else:
op_ = self._transform_op(op)
t_ = op_.outputs[op_index]
@@ -384,7 +417,8 @@ class Transformer(object):
return name_
-def copy(sgv, dst_graph=None, dst_scope="", src_scope=""):
+def copy(sgv, dst_graph=None, dst_scope="", src_scope="",
+ reuse_dst_scope=False):
"""Copy a subgraph.
Args:
@@ -393,6 +427,9 @@ def copy(sgv, dst_graph=None, dst_scope="", src_scope=""):
dst_graph: the destination graph.
dst_scope: the destination scope.
src_scope: the source scope.
+ reuse_dst_scope: if True the dst_scope is re-used if it already exists.
+ Otherwise, the scope is given a unique name based on the one given
+ by postfixing an underscore followed by a digit (default).
Returns:
the subgraph view of the copied subgraph.
Raises:
@@ -407,4 +444,5 @@ def copy(sgv, dst_graph=None, dst_scope="", src_scope=""):
raise TypeError("Expected a tf.Graph, got: {}".format(type(dst_graph)))
copier = Transformer()
- return copier(sgv, dst_graph, dst_scope, src_scope)
+ return copier(sgv, dst_graph, dst_scope, src_scope,
+ reuse_dst_scope=reuse_dst_scope)
diff --git a/tensorflow/contrib/graph_editor/util.py b/tensorflow/contrib/graph_editor/util.py
index 4fa6cbf374..7d1caa572c 100644
--- a/tensorflow/contrib/graph_editor/util.py
+++ b/tensorflow/contrib/graph_editor/util.py
@@ -58,6 +58,11 @@ class ListView(object):
def __getitem__(self, i):
return self._list[i]
+ def __add__(self, other):
+ if not isinstance(other, list):
+ other = list(other)
+ return list(self) + other
+
# TODO(fkp): very generic code, it should be moved in a more generic place.
def is_iterable(obj):
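The `__add__` overload added to `ListView` above lets a view be concatenated with a plain list (or any iterable), returning a new list rather than a view. A self-contained sketch of the resulting class (reduced to the methods shown in this diff):

```python
class ListView(object):
    """Immutable view over a list (sketch of graph_editor.util.ListView)."""

    def __init__(self, list_):
        self._list = list_

    def __iter__(self):
        return iter(self._list)

    def __getitem__(self, i):
        return self._list[i]

    def __add__(self, other):
        # Coerce any iterable right-hand side, then concatenate as lists.
        if not isinstance(other, list):
            other = list(other)
        return list(self) + other

v = ListView([1, 2])
combined = v + (3, 4)  # tuple is coerced; result is a plain list
```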
@@ -107,10 +112,13 @@ def get_unique_graph(tops, check_types=None, none_if_empty=False):
raise TypeError("{} is not iterable".format(type(tops)))
if check_types is None:
check_types = (tf_ops.Operation, tf_ops.Tensor)
+ elif not is_iterable(check_types):
+ check_types = (check_types,)
g = None
for op in tops:
if not isinstance(op, check_types):
- raise TypeError("Expected a tf.Operation, got: {}".format(type(op)))
+ raise TypeError("Expected a type in ({}), got: {}".format(
+ ", ".join([str(t) for t in check_types]), type(op)))
if g is None:
g = op.graph
elif g is not op.graph:
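The `check_types` fix above normalizes a single type into a one-element tuple so that `isinstance` accepts it either way. The pattern in isolation (default types chosen arbitrarily for the sketch; the real code defaults to `(tf_ops.Operation, tf_ops.Tensor)` and uses its own `is_iterable` helper):

```python
def normalize_check_types(check_types, default=(int, float)):
    """Accept None, a single type, or an iterable of types."""
    if check_types is None:
        return default
    try:
        # An iterable of types: materialize it as a tuple.
        return tuple(check_types)
    except TypeError:
        # A single (non-iterable) type: wrap it.
        return (check_types,)
```

With this, callers can pass `check_types=tf_ops.Operation` as well as `check_types=(tf_ops.Operation, tf_ops.Tensor)` without triggering the spurious `isinstance` failure the old code had.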