aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/graph_editor/transform.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/graph_editor/transform.py')
-rw-r--r--tensorflow/contrib/graph_editor/transform.py37
1 files changed, 32 insertions, 5 deletions
diff --git a/tensorflow/contrib/graph_editor/transform.py b/tensorflow/contrib/graph_editor/transform.py
index 6fb347c834..832698b8a0 100644
--- a/tensorflow/contrib/graph_editor/transform.py
+++ b/tensorflow/contrib/graph_editor/transform.py
@@ -26,13 +26,13 @@ from six import iteritems
from six import iterkeys
from six import string_types
from six import StringIO
-
from tensorflow.contrib.graph_editor import edit
from tensorflow.contrib.graph_editor import reroute
from tensorflow.contrib.graph_editor import select
from tensorflow.contrib.graph_editor import subgraph
from tensorflow.contrib.graph_editor import util
from tensorflow.python.framework import ops as tf_ops
+from tensorflow.python.platform import tf_logging as logging
__all__ = [
"replace_t_with_placeholder_handler",
@@ -87,17 +87,24 @@ def keep_t_if_possible_handler(info, t):
def assign_renamed_collections_handler(info, elem, elem_):
"""Add the transformed elem to the (renamed) collections of elem.
+ A collection is renamed only if is not a known key, as described in
+ `tf.GraphKeys`.
+
Args:
info: Transform._Info instance.
elem: the original element (`tf.Tensor` or `tf.Operation`)
elem_: the transformed element
"""
- # TODO(fkp): handle known special cases
+ known_collection_names = util.get_predefined_collection_names()
for name, collection in iteritems(info.collections):
if elem not in collection:
continue
- collection_name_ = info.transformer.new_name(name)
- info.graph_.add_to_collection(collection_name_, elem_)
+
+ if name in known_collection_names:
+ transformed_name = name
+ else:
+ transformed_name = info.transformer.new_name(name)
+ info.graph_.add_to_collection(transformed_name, elem_)
def transform_op_if_inside_handler(info, op, keep_if_possible=True):
@@ -150,6 +157,11 @@ def copy_op_handler(info, op, copy_shape=True):
# Transform inputs:
inputs_ = [info.transformer._transform_t(t) for t in op.inputs]
+ # Leave inputs empty if a graph cycle was found.
+ if None in inputs_:
+ info.cyclic_ops.append(op)
+ inputs_ = []
+
# Clone the node def:
node_def_ = deepcopy(op._node_def)
@@ -239,7 +251,7 @@ class Transformer(object):
self.transformed_ts = {}
self.collections = dict((key, self.graph.get_collection(key))
for key in self.graph.get_all_collection_keys())
-
+ self.cyclic_ops = []
class ResultInfo(object):
""""Contains information about the result of a transform operation."""
@@ -452,6 +464,17 @@ class Transformer(object):
for op in remaining_roots:
self._transform_op(op)
+ # Finalize cyclic ops:
+ for op in self._info.cyclic_ops:
+ logging.debug("Finalizing cyclic op: %s", op.name)
+ op_ = self._info.transformed_ops[op]
+ inputs_ = [self._info.transformed_ts[t] for t in op.inputs]
+ if None in inputs_:
+ raise ValueError("Could not find all the inputs of cyclic op: {}"
+ .format(op_.name))
+ for input_id, t_ in enumerate(inputs_):
+ op_._update_input(input_id, t_) # pylint: disable=protected-access
+
sgv_ = self._transform_sgv(sgv)
res_info = Transformer.ResultInfo(self._info)
@@ -506,9 +529,13 @@ class Transformer(object):
Returns:
The transformed tensor.
"""
+ logging.debug("Transforming tensor: %s", t.name)
if t in self._info.transformed_ts:
return self._info.transformed_ts[t]
+ # Mark as None to detect cycle.
+ self._info.transformed_ts[t] = None
+
op, op_index = t.op, t.value_index
# If op is not in the subgraph: