diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2016-09-07 08:02:28 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-09-07 09:17:51 -0700 |
commit | 558b6febb8ae66872510429330ef7df76eb03775 (patch) | |
tree | 0218971fb06f070d643edfbfa6c9fa9988c612b2 /tensorflow/contrib/graph_editor/transform.py | |
parent | 9ed3b81c8898cd86a8611c589353dbfb0e8f51d7 (diff) |
graph_replace: return original tensor if not transformed.
Improved general performance to make working with big graphs faster.
Change: 132442208
Diffstat (limited to 'tensorflow/contrib/graph_editor/transform.py')
-rw-r--r-- | tensorflow/contrib/graph_editor/transform.py | 37 |
1 files changed, 26 insertions, 11 deletions
diff --git a/tensorflow/contrib/graph_editor/transform.py b/tensorflow/contrib/graph_editor/transform.py index 584318891c..0b8928010b 100644 --- a/tensorflow/contrib/graph_editor/transform.py +++ b/tensorflow/contrib/graph_editor/transform.py @@ -20,6 +20,7 @@ from __future__ import division from __future__ import print_function from copy import deepcopy +from functools import partial from six import iteritems from six import iterkeys @@ -266,11 +267,13 @@ class Transformer(object): "Expected a tf.Tensor or a tf.Operation, got a {}".format( type(top))) - def _transformed_elem(self, original_top): + def _transformed_elem(self, original_top, missing_fn=None): """Return the transformed op/tensor corresponding to the original one. Args: original_top: the original tensor/operation. + missing_fn: function handling the case where the counterpart + cannot be found. By default, None is returned. Returns: the transformed tensor/operation (or None if no match is found). """ @@ -279,17 +282,19 @@ class Transformer(object): for original, transformed in iteritems(transformed_map): if original.name == original_top: return transformed - return None + return None if missing_fn is None else missing_fn(original_top) else: if original_top not in transformed_map: - return None + return None if missing_fn is None else missing_fn(original_top) return transformed_map[original_top] - def _original_elem(self, transformed_top): + def _original_elem(self, transformed_top, missing_fn=None): """Return the original op/tensor corresponding to the transformed one. Args: transformed_top: the transformed tensor/operation. + missing_fn: function handling the case where the counterpart + cannot be found. By default, None is returned. Returns: the original tensor/operation (or None if no match is found). """ @@ -301,9 +306,9 @@ class Transformer(object): for original, transformed in iteritems(transformed_map): if finder(transformed): return original - return None + return None if missing_fn is None else missing_fn(transformed_top) - def transformed(self, original): + def transformed(self, original, missing_fn=None): """Return the transformed op/tensor corresponding to the original one. Note that the output of this function mimics the hierarchy @@ -313,12 +318,15 @@ class Transformer(object): Args: original: the original tensor/operation. + missing_fn: function handling the case where the counterpart + cannot be found. By default, None is returned. Returns: the transformed tensor/operation (or None if no match is found). """ - return util.transform_tree(original, self._transformed_elem) + transformed_elem = partial(self._transformed_elem, missing_fn=missing_fn) + return util.transform_tree(original, transformed_elem) - def original(self, transformed): + def original(self, transformed, missing_fn=None): """Return the original op/tensor corresponding to the transformed one. Note that the output of this function mimics the hierarchy @@ -328,10 +336,13 @@ class Transformer(object): Args: transformed: the transformed tensor/operation. + missing_fn: function handling the case where the counterpart + cannot be found. By default, None is returned. Returns: the original tensor/operation (or None if no match is found). """ - return util.transform_tree(transformed, self._original_elem) + original_elem = partial(self._original_elem, missing_fn=missing_fn) + return util.transform_tree(transformed, original_elem) def __str__(self): res = StringIO() @@ -543,6 +554,8 @@ class Transformer(object): # All to all the active devices for device_function in reversed(self._info.graph_._device_function_stack): + if device_function is None: + break op_._set_device(device_function(op_)) # pylint: enable=protected-access @@ -695,5 +708,7 @@ def graph_replace(target_ts, replacement_ts, dst_scope="", # Create a copy of the relevant subgraph _, info = copy_with_input_replacements( ops, replacement_ts, None, dst_scope, src_scope, reuse_dst_scope) - # Return the transformed targets - return info.transformed(target_ts) + # Return the transformed targets but keep the original if the transformed + # counterpart cannot be found + missing_fn = lambda original_t: original_t + return info.transformed(target_ts, missing_fn) |