diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2016-08-08 01:05:06 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-08-08 02:17:22 -0700 |
commit | 5e2019571a0c09bd627dd4f0abbf6b08e19a65e5 (patch) | |
tree | 3db31d03a2ecb20e04e232e9d2e4df292a1aacb1 /tensorflow/contrib/graph_editor | |
parent | 50c8d6c9803c05234ad98355ef11d8cf98a5e7ae (diff) |
Improved documentation and minor bug fixes.
Change: 129607856
Diffstat (limited to 'tensorflow/contrib/graph_editor')
-rw-r--r-- | tensorflow/contrib/graph_editor/README.md | 15 | ||||
-rw-r--r-- | tensorflow/contrib/graph_editor/__init__.py | 101 | ||||
-rw-r--r-- | tensorflow/contrib/graph_editor/edit.py | 9 | ||||
-rw-r--r-- | tensorflow/contrib/graph_editor/select.py | 3 | ||||
-rw-r--r-- | tensorflow/contrib/graph_editor/subgraph.py | 60 | ||||
-rw-r--r-- | tensorflow/contrib/graph_editor/transform.py | 104 | ||||
-rw-r--r-- | tensorflow/contrib/graph_editor/util.py | 10 |
7 files changed, 226 insertions, 76 deletions
diff --git a/tensorflow/contrib/graph_editor/README.md b/tensorflow/contrib/graph_editor/README.md index e237208752..f6f82668ed 100644 --- a/tensorflow/contrib/graph_editor/README.md +++ b/tensorflow/contrib/graph_editor/README.md @@ -1,15 +1,6 @@ # TensorFlow Graph Editor -The TensorFlow Graph Editor libray which allows for modification of an existing -tf.Graph instance. +The TensorFlow Graph Editor library allows for modification of an existing +tf.Graph instance in-place. -## Overview of the modules - -* util.py: utility functions -* select.py: selection functions, allowing for various selection method of - tensors and operations. -* subgraph.py: the SubGraphView class, allowing to manipulate subgraph of a - TensorFlow tf.Graph instance. -* transform.py: the Transformer class, allowing to tranform a subgraph into - another one. -* edit.py: various editing function operating on subgraph. +The author's github username is [purpledog](https://github.com/purpledog). diff --git a/tensorflow/contrib/graph_editor/__init__.py b/tensorflow/contrib/graph_editor/__init__.py index 0cbfb5420e..7f20bd88dd 100644 --- a/tensorflow/contrib/graph_editor/__init__.py +++ b/tensorflow/contrib/graph_editor/__init__.py @@ -12,7 +12,103 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Graph editor module allows to modify an existing graph in place. +"""# TensorFlow Graph Editor. + +The TensorFlow Graph Editor library allows for modification of an existing +tf.Graph instance in-place. + +The author's github username is [purpledog](https://github.com/purpledog). + +## Library overview + +Appending new nodes is the only graph editing operation allowed by the +TensorFlow core library. The Graph Editor library is an attempt to allow for +other kinds of editing operations, namely, *rerouting* and *transforming*. 
+ +* *rerouting* is a local operation consisting in re-plugging existing tensors + (the edges of the graph). Operations (the nodes) are not modified by this + operation. For example, rerouting can be used to insert an operation adding + noise in place of an existing tensor. +* *transforming* is a global operation consisting in transforming a graph into + another. By default, a transformation is a simple copy but it can be + customized to achieve other goals. For instance, a graph can be transformed + into another one in which noise is added after all the operations of a + specific type. + +**Important: modifying a graph in-place with the Graph Editor must be done +`offline`, that is, without any active sessions.** + +Of course new operations can be appended online but Graph Editor specific +operations like rerouting and transforming can currently only be done offline. + +Here is an example of what you **cannot** do: + +* Build a graph. +* Create a session and run the graph. +* Modify the graph with the Graph Editor. +* Re-run the graph with the `same` previously created session. + +To edit an already running graph, follow these steps: + +* Build a graph. +* Create a session and run the graph. +* Save the graph state and terminate the session. +* Modify the graph with the Graph Editor. +* Create a new session and restore the graph state. +* Re-run the graph with the newly created session. + +Note that this procedure is very costly because a new session must be created +after any modifications. Among other things, it takes time because the entire +graph state must be saved and restored again. + +### Sub-graph + +Most of the functions in the Graph Editor library operate on *sub-graphs*. +More precisely, they take as input arguments instances of the SubGraphView class +(or anything which can be converted to it). Doing so allows the same function +to transparently operate on single operations as well as sub-graphs of any size. 
+ +A subgraph can be created in several ways: + +* using a list of ops: + +```python +my_sgv = ge.sgv(ops) +``` + +* from a name scope: + +```python +my_sgv = ge.sgv_scope("foo/bar", graph=tf.get_default_graph()) +``` + +* using a regular expression: + +```python +my_sgv = ge.sgv("foo/.*/.*read$", graph=tf.get_default_graph()) +``` + +Note the Graph Editor is meant to manipulate several graphs at the same time, +typically during transform or copy operations. For that reason, +to avoid any confusion, the default graph is never used and the graph on +which to operate must always be explicitly given. This is the reason why +*graph=tf.get_default_graph()* is used in the code snippets above. + + +### Modules + +* util: utility functions. +* select: various selection methods of TensorFlow tensors and operations. +* match: TensorFlow graph matching. Think of this as regular expressions for + graphs (but not quite yet). +* reroute: various ways of rerouting tensors to different consuming ops like + *swap* or *reroute_a2b*. +* subgraph: the SubGraphView class, which enables subgraph manipulations in a + TensorFlow tf.Graph. +* edit: various editing functions operating on subgraphs like *detach*, + *connect* or *bypass*. +* transform: the Transformer class, which enables transforming + (or simply copying) a subgraph into another one. 
""" from __future__ import absolute_import @@ -54,6 +150,8 @@ from tensorflow.contrib.graph_editor.subgraph import SubGraphView from tensorflow.contrib.graph_editor.transform import copy from tensorflow.contrib.graph_editor.transform import Transformer +from tensorflow.contrib.graph_editor.util import ControlOutputs + # some useful aliases ph = util.make_placeholder_from_dtype_and_shape @@ -62,4 +160,3 @@ sgv_scope = subgraph.make_view_from_scope ts = select.select_ts ops = select.select_ops matcher = match.OpMatcher - diff --git a/tensorflow/contrib/graph_editor/edit.py b/tensorflow/contrib/graph_editor/edit.py index 9d1518a765..b1c41a2af2 100644 --- a/tensorflow/contrib/graph_editor/edit.py +++ b/tensorflow/contrib/graph_editor/edit.py @@ -68,6 +68,7 @@ def detach_inputs(sgv, control_inputs=False): Returns: A new subgraph view of the detached subgraph. Note that sgv is also modified in place. + A list of the created input placeholders. Raises: StandardError: if sgv cannot be converted to a SubGraphView using the same rules than the function subgraph.make_view. @@ -98,6 +99,7 @@ def detach_outputs(sgv, control_outputs=None): Returns: A new subgraph view of the detached subgraph. Note that sgv is also modified in place. + A list of the created output placeholders. Raises: StandardError: if sgv cannot be converted to a SubGraphView using the same rules than the function subgraph.make_view. @@ -141,6 +143,8 @@ def detach(sgv, control_inputs=False, control_outputs=None, control_ios=None): Returns: A new subgraph view of the detached subgraph. Note that sgv is also modified in place. + A list of the created input placeholders. + A list of the created output placeholders. Raises: StandardError: if sgv cannot be converted to a SubGraphView using the same rules than the function subgraph.make_view. @@ -164,8 +168,8 @@ def connect(sgv0, sgv1, disconnect_first=False): subgraph.make_view. disconnect_first: if True the current outputs of sgv0 are disconnected. 
Returns: - Two new subgraph views (now connected). sgv0 and svg1 are also modified - in place. + The modified sgv0 (now connected to sgv1). + The modified sgv1 (now connected to sgv0). Raises: StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView using the same rules than the function subgraph.make_view. @@ -189,6 +193,7 @@ def bypass(sgv): Returns: A new subgraph view of the bypassed subgraph. Note that sgv is also modified in place. + A list of the created input placeholders. Raises: StandardError: if sgv cannot be converted to a SubGraphView using the same rules than the function subgraph.make_view. diff --git a/tensorflow/contrib/graph_editor/select.py b/tensorflow/contrib/graph_editor/select.py index f673dcec35..5f8a50a7d8 100644 --- a/tensorflow/contrib/graph_editor/select.py +++ b/tensorflow/contrib/graph_editor/select.py @@ -20,7 +20,6 @@ from __future__ import division from __future__ import print_function import re -import types from six import iteritems from six import string_types @@ -299,7 +298,7 @@ def compute_boundary_ts(ops, ambiguous_are_outputs=True): def get_within_boundary_ops(ops, seed_ops, - boundary_ops, + boundary_ops=(), inclusive=True, control_inputs=False, control_outputs=None, diff --git a/tensorflow/contrib/graph_editor/subgraph.py b/tensorflow/contrib/graph_editor/subgraph.py index d8dab07427..ac28a114af 100644 --- a/tensorflow/contrib/graph_editor/subgraph.py +++ b/tensorflow/contrib/graph_editor/subgraph.py @@ -164,25 +164,30 @@ class SubGraphView(object): TypeError: if inside_ops cannot be converted to a list of tf.Operation or if passthrough_ts cannot be converted to a list of tf.Tensor. 
""" + inside_ops = util.make_list_of_op(inside_ops) passthrough_ts = util.make_list_of_t(passthrough_ts) ops_and_ts = inside_ops + passthrough_ts if ops_and_ts: self._graph = util.get_unique_graph(ops_and_ts) - else: - self._graph = None - self._ops = inside_ops + self._ops = inside_ops - # Compute inside and outside tensor - inputs, outputs, insides = select.compute_boundary_ts(inside_ops) + # Compute inside and outside tensor + inputs, outputs, insides = select.compute_boundary_ts(inside_ops) - # Compute passthrough tensors, silently ignoring the non-passthrough ones. - all_tensors = frozenset(inputs + outputs + list(insides)) - self._passthrough_ts = [t for t in passthrough_ts if t not in all_tensors] + # Compute passthrough tensors, silently ignoring the non-passthrough ones. + all_tensors = frozenset(inputs + outputs + list(insides)) + self._passthrough_ts = [t for t in passthrough_ts if t not in all_tensors] - # Set inputs and outputs. - self._input_ts = inputs + self._passthrough_ts - self._output_ts = outputs + self._passthrough_ts + # Set inputs and outputs. + self._input_ts = inputs + self._passthrough_ts + self._output_ts = outputs + self._passthrough_ts + else: + self._graph = None + self._passthrough_ts = [] + self._input_ts = [] + self._output_ts = [] + self._ops = [] def __copy__(self): """Create a copy of this subgraph. 
@@ -412,23 +417,27 @@ class SubGraphView(object): raise AssertionError("More than 1 op named: {}!".format(op_name)) return res[0] - def __getitem__(self, op_name): - return self.find_op_by_name(op_name) - def __str__(self): - res = StringIO() + if not self: + return "SubGraphView: empty" + def op_name(op): + return op.name def tensor_name(t): if t in self._passthrough_ts: return "{} *".format(t.name) else: return t.name - print("SubGraphView:", file=res) - print("** ops:", file=res) - print("\n".join([op.name for op in self._ops]), file=res) - print("** inputs:", file=res) - print("\n".join([tensor_name(t) for t in self._input_ts]), file=res) - print("** outputs:", file=res) - print("\n".join([tensor_name(t) for t in self._output_ts]), file=res) + def print_list(name, iterable, get_name): + if iterable: + print("** {}[{}]:".format(name, len(iterable)), file=res) + print("\n".join([get_name(elem) for elem in iterable]), file=res) + else: + print("** {}: empty".format(name), file=res) + res = StringIO() + print("SubGraphView (graphid={}):".format(id(self.graph)), file=res) + print_list("ops", self._ops, op_name) + print_list("inputs", self._input_ts, tensor_name) + print_list("outputs", self._output_ts, tensor_name) return res.getvalue() @property @@ -466,10 +475,13 @@ class SubGraphView(object): """The passthrough tensors, going straight from input to output.""" return util.ListView(self._passthrough_ts) - def __nonzero__(self): + def __bool__(self): """Allows for implicit boolean conversion.""" return self._graph is not None + # Python 3 wants __bool__, Python 2.7 wants __nonzero__ + __nonzero__ = __bool__ + def op(self, op_id): """Get an op by its index.""" return self._ops[op_id] @@ -556,7 +568,7 @@ def _check_graph(sgv, graph): """ if not isinstance(sgv, SubGraphView): raise TypeError("Expected a SubGraphView, got: {}".format(type(graph))) - if graph is None or sgv.graph is None: + if graph is None or not sgv.graph: return sgv if not isinstance(graph, 
tf_ops.Graph): raise TypeError("Expected a tf.Graph, got: {}".format(type(graph))) diff --git a/tensorflow/contrib/graph_editor/transform.py b/tensorflow/contrib/graph_editor/transform.py index 8214c19674..65e49a3f71 100644 --- a/tensorflow/contrib/graph_editor/transform.py +++ b/tensorflow/contrib/graph_editor/transform.py @@ -31,7 +31,7 @@ from tensorflow.python.framework import ops as tf_ops from tensorflow.python.platform import tf_logging as logging -def transform_tensor_into_placeholder_handler(info, t): +def replace_t_with_placeholder_handler(info, t): """Transform a tensor into a placeholder tensor. This handler is typically used to transform a subgraph input tensor into a @@ -48,7 +48,7 @@ def transform_tensor_into_placeholder_handler(info, t): return t_ -def keep_same_tensor_if_possible_handler(info, t): +def keep_t_if_possible_handler(info, t): """Transform a tensor into itself (identity) if possible. This handler transform a tensor into itself if the source and destination @@ -64,7 +64,7 @@ def keep_same_tensor_if_possible_handler(info, t): if info.graph is info.graph_: return t else: - return transform_tensor_into_placeholder_handler(info, t) + return replace_t_with_placeholder_handler(info, t) def assign_renamed_collections_handler(info, elem, elem_): @@ -83,7 +83,7 @@ def assign_renamed_collections_handler(info, elem, elem_): info.graph_.add_to_collection(collection_name_, elem_) -def transform_optional_op_if_inside_handler(info, op): +def transform_op_if_inside_handler(info, op, keep_if_possible=True): """Transform an optional op only if it is inside the subgraph. This handler is typically use to handle original op: it is fine to keep them @@ -92,14 +92,19 @@ def transform_optional_op_if_inside_handler(info, op): Args: info: Transform._Info instance. op: the optional op to transform (or ignore). + keep_if_possible: re-attach to the original op if possible, that is, + if the source graph and the destination graph are the same. 
Returns: - the transformed op or None if ignored. + the transformed op or None. """ if op is None: return None if op in info.sgv.ops: return info.transformer._transform_op(op) # pylint: disable=protected-access else: - return None + if keep_if_possible and info.graph is info.graph_: + return op + else: + return None def copy_op_handler(info, op): @@ -113,18 +118,22 @@ def copy_op_handler(info, op): """ # pylint: disable=protected-access - # If it has control inputs, call this function recursively on each. - control_inputs_ = [info.transformer._transform_op(control_input) - for control_input in op.control_inputs] + # Transform control inputs: + control_inputs_ = [info.transformer.transform_control_input_handler(info, ci) + for ci in op.control_inputs] + control_inputs_ = [ci for ci in control_inputs_ if ci is not None] - # If it has an original_op parameter, copy it + # Transform it if any: original_op_ = info.transformer.transform_original_op_hanlder(info, op._original_op) - # If it has inputs, call this function recursively on each. + # Transform inputs: inputs_ = [info.transformer._transform_t(t) for t in op.inputs] + # Clone the node def: node_def_ = deepcopy(op._node_def) + + # Transform name: name_ = info.transformer.new_name(op.name) name_ = info.graph_.unique_name(name_) node_def_.name = name_ @@ -182,7 +191,7 @@ class Transformer(object): class _Info(object): """Transformer temporary data. - An instance of this class hold all the information relevant to a call + An instance of this class holds all the information relevant to a call to a transformer instance (that is, a call to __call__). An instance is created for the life-time of the __call__ function and is passed as argument to the handlers. 
@@ -200,6 +209,10 @@ class Transformer(object): self.transformed_ops = {} self.transformed_ts = {} + def create_ops_mapping(self): + """Return the mapping from original ops to transformed ops.""" + return {op.name: op_.name for op, op_ in iteritems(self.transformed_ops)} + def __init__(self): """Transformer constructor. @@ -209,13 +222,14 @@ class Transformer(object): assign_collections_handler: handle the assignment of collections. This handler defaults to assigning new collections created under the given name-scope. - transform_input_handler: handle the transform of the inputs to the given - subgraph. This handler defaults to creating placeholders instead of the - ops just before the input tensors of the subgraph. - transform_hidden_input_handler: handle the transform of the hidden inputs of - the subgraph, that is, the inputs which are not listed in sgv.inputs. - This handler defaults to a transform which keep the same input if the - source and destination graphs are the same, otherwise use placeholders. + transform_external_input_handler: handle the transform of the inputs to + the given subgraph. This handler defaults to creating placeholders + instead of the ops just before the input tensors of the subgraph. + transform_external_hidden_input_handler: handle the transform of the + hidden inputs of the subgraph, that is, the inputs which are not listed + in sgv.inputs. This handler defaults to a transform which keep the same + input if the source and destination graphs are the same, otherwise + use placeholders. transform_original_op_hanlder: handle the transform of original_op. This handler defaults to transforming original_op only if they are in the subgraph, otherwise they are ignored. 
@@ -223,15 +237,17 @@ class Transformer(object): # handlers self.transform_op_handler = copy_op_handler + self.transform_control_input_handler = transform_op_if_inside_handler self.assign_collections_handler = assign_renamed_collections_handler - self.transform_input_handler = transform_tensor_into_placeholder_handler - self.transform_hidden_input_handler = keep_same_tensor_if_possible_handler - self.transform_original_op_hanlder = transform_optional_op_if_inside_handler + self.transform_external_input_handler = replace_t_with_placeholder_handler + self.transform_external_hidden_input_handler = keep_t_if_possible_handler + self.transform_original_op_hanlder = transform_op_if_inside_handler # temporary per-call variable self._info = None - def __call__(self, sgv, dst_graph, dst_scope, src_scope=""): + def __call__(self, sgv, dst_graph, dst_scope, src_scope="", + reuse_dst_scope=False): """Execute the transformation. Args: @@ -242,16 +258,27 @@ class Transformer(object): relative path of the transformed nodes are computed. For instance, if src_scope is a/ and dst_scoped is b/, then the node a/x/y will have a relative path of x/y and will be transformed into b/x/y. + reuse_dst_scope: if True the dst_scope is re-used if it already exists. + Otherwise, the scope is given a unique name based on the one given + by postfixing an underscore followed by a digit (default). Returns: The transformed subgraph view. + A dictionary mapping the name of the original ops to the name of the + transformed ops. Raises: - ValueError: if the argumens are invalid. For instance, if the source and - destination are the same. + ValueError: if the argumens are invalid. 
""" + sgv = subgraph.make_view(sgv) + if not isinstance(dst_graph, tf_ops.Graph): + raise TypeError("Expected a tf.Graph, got: {}".format(type(dst_graph))) + src_scope = util.scope_finalize(src_scope) dst_scope = util.scope_finalize(dst_scope) - sgv = subgraph.make_view(sgv) + # Potentially create new scope if reuse_dst_scope is False + if dst_scope and not reuse_dst_scope: + dst_scope = util.scope_finalize(dst_graph.unique_name(dst_scope[:-1])) + if sgv.graph is dst_graph and not dst_scope: logging.warning("The source and the destination are the same! " "Beware: in-place transormation are currently " @@ -264,8 +291,9 @@ class Transformer(object): sgv_ = self._transform_sgv(sgv) + ops_mapping = self._info.create_ops_mapping() self._info = None - return sgv_ + return sgv_, ops_mapping def _transform_sgv(self, sgv): """Transform a subgraph view. @@ -319,11 +347,16 @@ class Transformer(object): return self._info.transformed_ts[t] op, op_index = t.op, t.value_index + + # If op is not in the subgraph: if op not in self._info.ops: - if t in self._info.sgv.inputs: # t_ is an input of the subgraph - t_ = self.transform_input_handler(self._info, t) - else: # t_ is a hidden input of the subgraph - t_ = self.transform_hidden_input_handler(self._info, t) + # t_ is an input of the subgraph + if t in self._info.sgv.inputs: + t_ = self.transform_external_input_handler(self._info, t) + # t_ is a hidden input of the subgraph + else: + t_ = self.transform_external_hidden_input_handler(self._info, t) + # If op is in the subgraph, just transform it: else: op_ = self._transform_op(op) t_ = op_.outputs[op_index] @@ -384,7 +417,8 @@ class Transformer(object): return name_ -def copy(sgv, dst_graph=None, dst_scope="", src_scope=""): +def copy(sgv, dst_graph=None, dst_scope="", src_scope="", + reuse_dst_scope=False): """Copy a subgraph. Args: @@ -393,6 +427,9 @@ def copy(sgv, dst_graph=None, dst_scope="", src_scope=""): dst_graph: the destination graph. dst_scope: the destination scope. 
src_scope: the source scope. + reuse_dst_scope: if True the dst_scope is re-used if it already exists. + Otherwise, the scope is given a unique name based on the one given + by postfixing an underscore followed by a digit (default). Returns: the subgraph view of the copied subgraph. Raises: @@ -407,4 +444,5 @@ def copy(sgv, dst_graph=None, dst_scope="", src_scope=""): raise TypeError("Expected a tf.Graph, got: {}".format(type(dst_graph))) copier = Transformer() - return copier(sgv, dst_graph, dst_scope, src_scope) + return copier(sgv, dst_graph, dst_scope, src_scope, + reuse_dst_scope=reuse_dst_scope) diff --git a/tensorflow/contrib/graph_editor/util.py b/tensorflow/contrib/graph_editor/util.py index 4fa6cbf374..7d1caa572c 100644 --- a/tensorflow/contrib/graph_editor/util.py +++ b/tensorflow/contrib/graph_editor/util.py @@ -58,6 +58,11 @@ class ListView(object): def __getitem__(self, i): return self._list[i] + def __add__(self, other): + if not isinstance(other, list): + other = list(other) + return list(self) + other + # TODO(fkp): very generic code, it should be moved in a more generic place. def is_iterable(obj): @@ -107,10 +112,13 @@ def get_unique_graph(tops, check_types=None, none_if_empty=False): raise TypeError("{} is not iterable".format(type(tops))) if check_types is None: check_types = (tf_ops.Operation, tf_ops.Tensor) + elif not is_iterable(check_types): + check_types = (check_types,) g = None for op in tops: if not isinstance(op, check_types): - raise TypeError("Expected a tf.Operation, got: {}".format(type(op))) + raise TypeError("Expected a type in ({}), got: {}".format( + ", ".join([str(t) for t in check_types]), type(op))) if g is None: g = op.graph elif g is not op.graph: |