aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/graph_editor/transform.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-09-07 08:02:28 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-09-07 09:17:51 -0700
commit558b6febb8ae66872510429330ef7df76eb03775 (patch)
tree0218971fb06f070d643edfbfa6c9fa9988c612b2 /tensorflow/contrib/graph_editor/transform.py
parent9ed3b81c8898cd86a8611c589353dbfb0e8f51d7 (diff)
graph_replace: return original tensor if not transformed.
Improved general performance to make working with big graphs faster. Change: 132442208
Diffstat (limited to 'tensorflow/contrib/graph_editor/transform.py')
-rw-r--r--tensorflow/contrib/graph_editor/transform.py37
1 files changed, 26 insertions, 11 deletions
diff --git a/tensorflow/contrib/graph_editor/transform.py b/tensorflow/contrib/graph_editor/transform.py
index 584318891c..0b8928010b 100644
--- a/tensorflow/contrib/graph_editor/transform.py
+++ b/tensorflow/contrib/graph_editor/transform.py
@@ -20,6 +20,7 @@ from __future__ import division
from __future__ import print_function
from copy import deepcopy
+from functools import partial
from six import iteritems
from six import iterkeys
@@ -266,11 +267,13 @@ class Transformer(object):
"Expected a tf.Tensor or a tf.Operation, got a {}".format(
type(top)))
- def _transformed_elem(self, original_top):
+ def _transformed_elem(self, original_top, missing_fn=None):
"""Return the transformed op/tensor corresponding to the original one.
Args:
original_top: the original tensor/operation.
+ missing_fn: function handling the case where the counterpart
+ cannot be found. By default, None is returned.
Returns:
the transformed tensor/operation (or None if no match is found).
"""
@@ -279,17 +282,19 @@ class Transformer(object):
for original, transformed in iteritems(transformed_map):
if original.name == original_top:
return transformed
- return None
+ return None if missing_fn is None else missing_fn(original_top)
else:
if original_top not in transformed_map:
- return None
+ return None if missing_fn is None else missing_fn(original_top)
return transformed_map[original_top]
- def _original_elem(self, transformed_top):
+ def _original_elem(self, transformed_top, missing_fn=None):
"""Return the original op/tensor corresponding to the transformed one.
Args:
transformed_top: the transformed tensor/operation.
+ missing_fn: function handling the case where the counterpart
+ cannot be found. By default, None is returned.
Returns:
the original tensor/operation (or None if no match is found).
"""
@@ -301,9 +306,9 @@ class Transformer(object):
for original, transformed in iteritems(transformed_map):
if finder(transformed):
return original
- return None
+ return None if missing_fn is None else missing_fn(transformed_top)
- def transformed(self, original):
+ def transformed(self, original, missing_fn=None):
"""Return the transformed op/tensor corresponding to the original one.
Note that the output of this function mimics the hierarchy
@@ -313,12 +318,15 @@ class Transformer(object):
Args:
original: the original tensor/operation.
+ missing_fn: function handling the case where the counterpart
+ cannot be found. By default, None is returned.
Returns:
the transformed tensor/operation (or None if no match is found).
"""
- return util.transform_tree(original, self._transformed_elem)
+ transformed_elem = partial(self._transformed_elem, missing_fn=missing_fn)
+ return util.transform_tree(original, transformed_elem)
- def original(self, transformed):
+ def original(self, transformed, missing_fn=None):
"""Return the original op/tensor corresponding to the transformed one.
Note that the output of this function mimics the hierarchy
@@ -328,10 +336,13 @@ class Transformer(object):
Args:
transformed: the transformed tensor/operation.
+ missing_fn: function handling the case where the counterpart
+ cannot be found. By default, None is returned.
Returns:
the original tensor/operation (or None if no match is found).
"""
- return util.transform_tree(transformed, self._original_elem)
+ original_elem = partial(self._original_elem, missing_fn=missing_fn)
+ return util.transform_tree(transformed, original_elem)
def __str__(self):
res = StringIO()
@@ -543,6 +554,8 @@ class Transformer(object):
# All to all the active devices
for device_function in reversed(self._info.graph_._device_function_stack):
+ if device_function is None:
+ break
op_._set_device(device_function(op_))
# pylint: enable=protected-access
@@ -695,5 +708,7 @@ def graph_replace(target_ts, replacement_ts, dst_scope="",
# Create a copy of the relevant subgraph
_, info = copy_with_input_replacements(
ops, replacement_ts, None, dst_scope, src_scope, reuse_dst_scope)
- # Return the transformed targets
- return info.transformed(target_ts)
+ # Return the transformed targets but keep the original if the transformed
+ # counterpart cannot be found
+ missing_fn = lambda original_t: original_t
+ return info.transformed(target_ts, missing_fn)