diff options
Diffstat (limited to 'tensorflow/contrib/quantize/python/common.py')
-rw-r--r-- | tensorflow/contrib/quantize/python/common.py | 26 |
1 files changed, 26 insertions, 0 deletions
diff --git a/tensorflow/contrib/quantize/python/common.py b/tensorflow/contrib/quantize/python/common.py index bf648e158e..b27117dd48 100644 --- a/tensorflow/contrib/quantize/python/common.py +++ b/tensorflow/contrib/quantize/python/common.py @@ -131,3 +131,29 @@ def DropStringPrefix(s, prefix): return s[len(prefix):] else: return s + + +def RerouteTensor(t0, t1, can_modify=None): + """Reroute the end of the tensor t0 to the ends of the tensor t1. + + Args: + t0: a tf.Tensor. + t1: a tf.Tensor. + can_modify: iterable of operations which can be modified. Any operation + outside within_ops will be left untouched by this function. + + Returns: + The number of individual modifications made by the function. + """ + nb_update_inputs = 0 + consumers = t1.consumers() + if can_modify is not None: + consumers = [c for c in consumers if c in can_modify] + consumers_indices = {} + for c in consumers: + consumers_indices[c] = [i for i, t in enumerate(c.inputs) if t is t1] + for c in consumers: + for i in consumers_indices[c]: + c._update_input(i, t0) # pylint: disable=protected-access + nb_update_inputs += 1 + return nb_update_inputs |