diff options
author | Suharsh Sivakumar <suharshs@google.com> | 2018-09-07 13:06:13 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-07 13:10:10 -0700 |
commit | 93aacda3051d686fffd694c74c98e2eb63bb2261 (patch) | |
tree | 5ffc51c2b5aecb02be14163744dea53f17fc026c /tensorflow/contrib/quantize | |
parent | eb71a1a3afbbe21407b2149d7adc4efa9e557b24 (diff) |
Remove dependency on graph_editor.
PiperOrigin-RevId: 212023248
Diffstat (limited to 'tensorflow/contrib/quantize')
-rw-r--r-- | tensorflow/contrib/quantize/BUILD | 3 | ||||
-rw-r--r-- | tensorflow/contrib/quantize/python/common.py | 26 | ||||
-rw-r--r-- | tensorflow/contrib/quantize/python/common_test.py | 25 | ||||
-rw-r--r-- | tensorflow/contrib/quantize/python/fold_batch_norms.py | 25 | ||||
-rw-r--r-- | tensorflow/contrib/quantize/python/quantize.py | 5 |
5 files changed, 66 insertions, 18 deletions
diff --git a/tensorflow/contrib/quantize/BUILD b/tensorflow/contrib/quantize/BUILD index 499fec4ffa..c59f667f6a 100644 --- a/tensorflow/contrib/quantize/BUILD +++ b/tensorflow/contrib/quantize/BUILD @@ -22,6 +22,7 @@ py_test( ":common", "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", "//tensorflow/python:session", "//tensorflow/python:variable_scope", @@ -89,7 +90,6 @@ py_library( ":common", ":graph_matcher", ":input_to_ops", - "//tensorflow/contrib/graph_editor:graph_editor_py", "//tensorflow/core:protos_all_py", "//tensorflow/python:array_ops", "//tensorflow/python:dtypes", @@ -171,7 +171,6 @@ py_library( ":graph_matcher", ":input_to_ops", ":quant_ops", - "//tensorflow/contrib/graph_editor:graph_editor_py", "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_ops", "//tensorflow/python:math_ops", diff --git a/tensorflow/contrib/quantize/python/common.py b/tensorflow/contrib/quantize/python/common.py index bf648e158e..b27117dd48 100644 --- a/tensorflow/contrib/quantize/python/common.py +++ b/tensorflow/contrib/quantize/python/common.py @@ -131,3 +131,29 @@ def DropStringPrefix(s, prefix): return s[len(prefix):] else: return s + + +def RerouteTensor(t0, t1, can_modify=None): + """Reroute the end of the tensor t0 to the ends of the tensor t1. + + Args: + t0: a tf.Tensor. + t1: a tf.Tensor. + can_modify: iterable of operations which can be modified. Any operation + outside within_ops will be left untouched by this function. + + Returns: + The number of individual modifications made by the function. + """ + nb_update_inputs = 0 + consumers = t1.consumers() + if can_modify is not None: + consumers = [c for c in consumers if c in can_modify] + consumers_indices = {} + for c in consumers: + consumers_indices[c] = [i for i, t in enumerate(c.inputs) if t is t1] + for c in consumers: + for i in consumers_indices[c]: + c._update_input(i, t0) # pylint: disable=protected-access + nb_update_inputs += 1 + return nb_update_inputs diff --git a/tensorflow/contrib/quantize/python/common_test.py b/tensorflow/contrib/quantize/python/common_test.py index 06c62f2d26..2b26302f8a 100644 --- a/tensorflow/contrib/quantize/python/common_test.py +++ b/tensorflow/contrib/quantize/python/common_test.py @@ -20,8 +20,10 @@ from __future__ import print_function from tensorflow.contrib.quantize.python import common from tensorflow.python.client import session +from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.framework import test_util +from tensorflow.python.ops import math_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import googletest @@ -62,6 +64,29 @@ class CommonTest(test_util.TensorFlowTestCase): _, step_val = sess.run([b, quantization_step_tensor]) self.assertEqual(step_val, 2) + def testRerouteTensor(self): + a = constant_op.constant(1, name='a') + b = constant_op.constant(2, name='b') + c = constant_op.constant(3, name='c') + d = constant_op.constant(4, name='d') + + add_ac = math_ops.add(a, c) + add_ad = math_ops.add(a, d) + + # Ensure that before rerouting the inputs are what we think. + self._CheckOpHasInputs(add_ac.op, [a, c]) + self._CheckOpHasInputs(add_ad.op, [a, d]) + + # references to tensor a should be replaced with b for all ops in + # can_modify. This means add_ac will be changed but add_ad will not. + common.RerouteTensor(b, a, can_modify=[add_ac.op]) + self._CheckOpHasInputs(add_ac.op, [b, c]) + self._CheckOpHasInputs(add_ad.op, [a, d]) + + def _CheckOpHasInputs(self, op, inputs): + for i in inputs: + self.assertIn(i, op.inputs) + if __name__ == '__main__': googletest.main() diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms.py b/tensorflow/contrib/quantize/python/fold_batch_norms.py index d9f179bee4..2971b28f45 100644 --- a/tensorflow/contrib/quantize/python/fold_batch_norms.py +++ b/tensorflow/contrib/quantize/python/fold_batch_norms.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function import re -from tensorflow.contrib import graph_editor from tensorflow.contrib.quantize.python import common from tensorflow.contrib.quantize.python import graph_matcher from tensorflow.contrib.quantize.python import input_to_ops @@ -134,8 +133,8 @@ def _FoldFusedBatchNorms(graph, is_training, freeze_batch_norm_delay): bias_add_tensor = math_ops.add( new_layer_tensor, bias_tensor, name='add_fold') - nodes_modified_count = graph_editor.reroute_ts(bias_add_tensor, - match.output_tensor) + nodes_modified_count = common.RerouteTensor(bias_add_tensor, + match.output_tensor) if nodes_modified_count == 0: raise ValueError('Folding batch norms failed, %s had no outputs.' % match.output_tensor.name) @@ -370,8 +369,9 @@ def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay, lambda: match.bn_decay_mean_tensor, name='freeze_moving_mean') - graph_editor.reroute_ts( - [bn_decay_mean_out], [match.bn_decay_mean_tensor], + common.RerouteTensor( + bn_decay_mean_out, + match.bn_decay_mean_tensor, can_modify=bn_decay_mean_consumers) bn_decay_var_consumers = list(match.bn_decay_var_tensor.consumers()) @@ -380,8 +380,9 @@ def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay, lambda: bn_decay_zero, lambda: match.bn_decay_var_tensor, name='freeze_moving_var') - graph_editor.reroute_ts( - [bn_decay_var_out], [match.bn_decay_var_tensor], + common.RerouteTensor( + bn_decay_var_out, + match.bn_decay_var_tensor, can_modify=bn_decay_var_consumers) correction_recip = utils.smart_cond( @@ -486,9 +487,8 @@ def _FoldUnfusedBatchNorms(graph, is_training, freeze_batch_norm_delay): activation = common.GetEndpointActivationOp(graph, bn) if activation: - nodes_modified_count = graph_editor.reroute_ts([folded_op.outputs[0]], - [original_op.outputs[0]], - can_modify=[activation]) + nodes_modified_count = common.RerouteTensor( + folded_op.outputs[0], original_op.outputs[0], can_modify=[activation]) if nodes_modified_count != 1: raise ValueError('Unexpected inputs to op: %s' % activation.name) continue @@ -497,9 +497,8 @@ def _FoldUnfusedBatchNorms(graph, is_training, freeze_batch_norm_delay): # operations instead of Relu* above. add_bypass_ctx = re.search(r'^(.*)/([^/]+)', bn).group(1) add_bypass = graph.get_operation_by_name(add_bypass_ctx + '/Add') - nodes_modified_count = graph_editor.reroute_ts([folded_op.outputs[0]], - [original_op.outputs[0]], - can_modify=[add_bypass]) + nodes_modified_count = common.RerouteTensor( + folded_op.outputs[0], original_op.outputs[0], can_modify=[add_bypass]) if nodes_modified_count != 1: raise ValueError('Unexpected inputs to op: %s' % add_bypass.name) diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py index 2ddbd73ea6..e88db0acd5 100644 --- a/tensorflow/contrib/quantize/python/quantize.py +++ b/tensorflow/contrib/quantize/python/quantize.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function import re -from tensorflow.contrib import graph_editor from tensorflow.contrib.quantize.python import common from tensorflow.contrib.quantize.python import graph_matcher from tensorflow.contrib.quantize.python import input_to_ops @@ -592,8 +591,8 @@ def _InsertQuantOp(context, name=name_prefix + '/delayed_quant') if consumers: - tensors_modified_count = graph_editor.reroute_ts( - [quant], [inputs], can_modify=consumers) + tensors_modified_count = common.RerouteTensor( + quant, inputs, can_modify=consumers) # Some operations can have multiple output tensors going to the same # consumer. Since consumers is a set, we need to ensure that # tensors_modified_count is greater than or equal to the length of the set |