diff options
Diffstat (limited to 'tensorflow/contrib/graph_editor/transform.py')
-rw-r--r-- | tensorflow/contrib/graph_editor/transform.py | 37 |
1 files changed, 32 insertions, 5 deletions
diff --git a/tensorflow/contrib/graph_editor/transform.py b/tensorflow/contrib/graph_editor/transform.py index 6fb347c834..832698b8a0 100644 --- a/tensorflow/contrib/graph_editor/transform.py +++ b/tensorflow/contrib/graph_editor/transform.py @@ -26,13 +26,13 @@ from six import iteritems from six import iterkeys from six import string_types from six import StringIO - from tensorflow.contrib.graph_editor import edit from tensorflow.contrib.graph_editor import reroute from tensorflow.contrib.graph_editor import select from tensorflow.contrib.graph_editor import subgraph from tensorflow.contrib.graph_editor import util from tensorflow.python.framework import ops as tf_ops +from tensorflow.python.platform import tf_logging as logging __all__ = [ "replace_t_with_placeholder_handler", @@ -87,17 +87,24 @@ def keep_t_if_possible_handler(info, t): def assign_renamed_collections_handler(info, elem, elem_): """Add the transformed elem to the (renamed) collections of elem. + A collection is renamed only if is not a known key, as described in + `tf.GraphKeys`. + Args: info: Transform._Info instance. elem: the original element (`tf.Tensor` or `tf.Operation`) elem_: the transformed element """ - # TODO(fkp): handle known special cases + known_collection_names = util.get_predefined_collection_names() for name, collection in iteritems(info.collections): if elem not in collection: continue - collection_name_ = info.transformer.new_name(name) - info.graph_.add_to_collection(collection_name_, elem_) + + if name in known_collection_names: + transformed_name = name + else: + transformed_name = info.transformer.new_name(name) + info.graph_.add_to_collection(transformed_name, elem_) def transform_op_if_inside_handler(info, op, keep_if_possible=True): @@ -150,6 +157,11 @@ def copy_op_handler(info, op, copy_shape=True): # Transform inputs: inputs_ = [info.transformer._transform_t(t) for t in op.inputs] + # Leave inputs empty if a graph cycle was found. + if None in inputs_: + info.cyclic_ops.append(op) + inputs_ = [] + # Clone the node def: node_def_ = deepcopy(op._node_def) @@ -239,7 +251,7 @@ class Transformer(object): self.transformed_ts = {} self.collections = dict((key, self.graph.get_collection(key)) for key in self.graph.get_all_collection_keys()) - + self.cyclic_ops = [] class ResultInfo(object): """"Contains information about the result of a transform operation.""" @@ -452,6 +464,17 @@ class Transformer(object): for op in remaining_roots: self._transform_op(op) + # Finalize cyclic ops: + for op in self._info.cyclic_ops: + logging.debug("Finalizing cyclic op: %s", op.name) + op_ = self._info.transformed_ops[op] + inputs_ = [self._info.transformed_ts[t] for t in op.inputs] + if None in inputs_: + raise ValueError("Could not find all the inputs of cyclic op: {}" + .format(op_.name)) + for input_id, t_ in enumerate(inputs_): + op_._update_input(input_id, t_) # pylint: disable=protected-access + sgv_ = self._transform_sgv(sgv) res_info = Transformer.ResultInfo(self._info) @@ -506,9 +529,13 @@ class Transformer(object): Returns: The transformed tensor. """ + logging.debug("Transforming tensor: %s", t.name) if t in self._info.transformed_ts: return self._info.transformed_ts[t] + # Mark as None to detect cycle. + self._info.transformed_ts[t] = None + op, op_index = t.op, t.value_index # If op is not in the subgraph: |