aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/quantize
diff options
context:
space:
mode:
authorGravatar Suharsh Sivakumar <suharshs@google.com>2018-09-07 13:06:13 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-07 13:10:10 -0700
commit93aacda3051d686fffd694c74c98e2eb63bb2261 (patch)
tree5ffc51c2b5aecb02be14163744dea53f17fc026c /tensorflow/contrib/quantize
parenteb71a1a3afbbe21407b2149d7adc4efa9e557b24 (diff)
Remove dependency on graph_editor.
PiperOrigin-RevId: 212023248
Diffstat (limited to 'tensorflow/contrib/quantize')
-rw-r--r--tensorflow/contrib/quantize/BUILD3
-rw-r--r--tensorflow/contrib/quantize/python/common.py26
-rw-r--r--tensorflow/contrib/quantize/python/common_test.py25
-rw-r--r--tensorflow/contrib/quantize/python/fold_batch_norms.py25
-rw-r--r--tensorflow/contrib/quantize/python/quantize.py5
5 files changed, 66 insertions, 18 deletions
diff --git a/tensorflow/contrib/quantize/BUILD b/tensorflow/contrib/quantize/BUILD
index 499fec4ffa..c59f667f6a 100644
--- a/tensorflow/contrib/quantize/BUILD
+++ b/tensorflow/contrib/quantize/BUILD
@@ -22,6 +22,7 @@ py_test(
":common",
"//tensorflow/python:framework_ops",
"//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:session",
"//tensorflow/python:variable_scope",
@@ -89,7 +90,6 @@ py_library(
":common",
":graph_matcher",
":input_to_ops",
- "//tensorflow/contrib/graph_editor:graph_editor_py",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:dtypes",
@@ -171,7 +171,6 @@ py_library(
":graph_matcher",
":input_to_ops",
":quant_ops",
- "//tensorflow/contrib/graph_editor:graph_editor_py",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
diff --git a/tensorflow/contrib/quantize/python/common.py b/tensorflow/contrib/quantize/python/common.py
index bf648e158e..b27117dd48 100644
--- a/tensorflow/contrib/quantize/python/common.py
+++ b/tensorflow/contrib/quantize/python/common.py
@@ -131,3 +131,29 @@ def DropStringPrefix(s, prefix):
return s[len(prefix):]
else:
return s
+
+
+def RerouteTensor(t0, t1, can_modify=None):
+ """Reroute the end of the tensor t0 to the ends of the tensor t1.
+
+ Args:
+ t0: a tf.Tensor.
+ t1: a tf.Tensor.
+ can_modify: iterable of operations which can be modified. Any operation
+ outside within_ops will be left untouched by this function.
+
+ Returns:
+ The number of individual modifications made by the function.
+ """
+ nb_update_inputs = 0
+ consumers = t1.consumers()
+ if can_modify is not None:
+ consumers = [c for c in consumers if c in can_modify]
+ consumers_indices = {}
+ for c in consumers:
+ consumers_indices[c] = [i for i, t in enumerate(c.inputs) if t is t1]
+ for c in consumers:
+ for i in consumers_indices[c]:
+ c._update_input(i, t0) # pylint: disable=protected-access
+ nb_update_inputs += 1
+ return nb_update_inputs
diff --git a/tensorflow/contrib/quantize/python/common_test.py b/tensorflow/contrib/quantize/python/common_test.py
index 06c62f2d26..2b26302f8a 100644
--- a/tensorflow/contrib/quantize/python/common_test.py
+++ b/tensorflow/contrib/quantize/python/common_test.py
@@ -20,8 +20,10 @@ from __future__ import print_function
from tensorflow.contrib.quantize.python import common
from tensorflow.python.client import session
+from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
@@ -62,6 +64,29 @@ class CommonTest(test_util.TensorFlowTestCase):
_, step_val = sess.run([b, quantization_step_tensor])
self.assertEqual(step_val, 2)
+ def testRerouteTensor(self):
+ a = constant_op.constant(1, name='a')
+ b = constant_op.constant(2, name='b')
+ c = constant_op.constant(3, name='c')
+ d = constant_op.constant(4, name='d')
+
+ add_ac = math_ops.add(a, c)
+ add_ad = math_ops.add(a, d)
+
+ # Ensure that before rerouting the inputs are what we think.
+ self._CheckOpHasInputs(add_ac.op, [a, c])
+ self._CheckOpHasInputs(add_ad.op, [a, d])
+
+ # references to tensor a should be replaced with b for all ops in
+ # can_modify. This means add_ac will be changed but add_ad will not.
+ common.RerouteTensor(b, a, can_modify=[add_ac.op])
+ self._CheckOpHasInputs(add_ac.op, [b, c])
+ self._CheckOpHasInputs(add_ad.op, [a, d])
+
+ def _CheckOpHasInputs(self, op, inputs):
+ for i in inputs:
+ self.assertIn(i, op.inputs)
+
if __name__ == '__main__':
googletest.main()
diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms.py b/tensorflow/contrib/quantize/python/fold_batch_norms.py
index d9f179bee4..2971b28f45 100644
--- a/tensorflow/contrib/quantize/python/fold_batch_norms.py
+++ b/tensorflow/contrib/quantize/python/fold_batch_norms.py
@@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function
import re
-from tensorflow.contrib import graph_editor
from tensorflow.contrib.quantize.python import common
from tensorflow.contrib.quantize.python import graph_matcher
from tensorflow.contrib.quantize.python import input_to_ops
@@ -134,8 +133,8 @@ def _FoldFusedBatchNorms(graph, is_training, freeze_batch_norm_delay):
bias_add_tensor = math_ops.add(
new_layer_tensor, bias_tensor, name='add_fold')
- nodes_modified_count = graph_editor.reroute_ts(bias_add_tensor,
- match.output_tensor)
+ nodes_modified_count = common.RerouteTensor(bias_add_tensor,
+ match.output_tensor)
if nodes_modified_count == 0:
raise ValueError('Folding batch norms failed, %s had no outputs.' %
match.output_tensor.name)
@@ -370,8 +369,9 @@ def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay,
lambda: match.bn_decay_mean_tensor,
name='freeze_moving_mean')
- graph_editor.reroute_ts(
- [bn_decay_mean_out], [match.bn_decay_mean_tensor],
+ common.RerouteTensor(
+ bn_decay_mean_out,
+ match.bn_decay_mean_tensor,
can_modify=bn_decay_mean_consumers)
bn_decay_var_consumers = list(match.bn_decay_var_tensor.consumers())
@@ -380,8 +380,9 @@ def _ComputeBatchNormCorrections(context, match, freeze_batch_norm_delay,
lambda: bn_decay_zero,
lambda: match.bn_decay_var_tensor,
name='freeze_moving_var')
- graph_editor.reroute_ts(
- [bn_decay_var_out], [match.bn_decay_var_tensor],
+ common.RerouteTensor(
+ bn_decay_var_out,
+ match.bn_decay_var_tensor,
can_modify=bn_decay_var_consumers)
correction_recip = utils.smart_cond(
@@ -486,9 +487,8 @@ def _FoldUnfusedBatchNorms(graph, is_training, freeze_batch_norm_delay):
activation = common.GetEndpointActivationOp(graph, bn)
if activation:
- nodes_modified_count = graph_editor.reroute_ts([folded_op.outputs[0]],
- [original_op.outputs[0]],
- can_modify=[activation])
+ nodes_modified_count = common.RerouteTensor(
+ folded_op.outputs[0], original_op.outputs[0], can_modify=[activation])
if nodes_modified_count != 1:
raise ValueError('Unexpected inputs to op: %s' % activation.name)
continue
@@ -497,9 +497,8 @@ def _FoldUnfusedBatchNorms(graph, is_training, freeze_batch_norm_delay):
# operations instead of Relu* above.
add_bypass_ctx = re.search(r'^(.*)/([^/]+)', bn).group(1)
add_bypass = graph.get_operation_by_name(add_bypass_ctx + '/Add')
- nodes_modified_count = graph_editor.reroute_ts([folded_op.outputs[0]],
- [original_op.outputs[0]],
- can_modify=[add_bypass])
+ nodes_modified_count = common.RerouteTensor(
+ folded_op.outputs[0], original_op.outputs[0], can_modify=[add_bypass])
if nodes_modified_count != 1:
raise ValueError('Unexpected inputs to op: %s' % add_bypass.name)
diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py
index 2ddbd73ea6..e88db0acd5 100644
--- a/tensorflow/contrib/quantize/python/quantize.py
+++ b/tensorflow/contrib/quantize/python/quantize.py
@@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function
import re
-from tensorflow.contrib import graph_editor
from tensorflow.contrib.quantize.python import common
from tensorflow.contrib.quantize.python import graph_matcher
from tensorflow.contrib.quantize.python import input_to_ops
@@ -592,8 +591,8 @@ def _InsertQuantOp(context,
name=name_prefix + '/delayed_quant')
if consumers:
- tensors_modified_count = graph_editor.reroute_ts(
- [quant], [inputs], can_modify=consumers)
+ tensors_modified_count = common.RerouteTensor(
+ quant, inputs, can_modify=consumers)
# Some operations can have multiple output tensors going to the same
# consumer. Since consumers is a set, we need to ensure that
# tensors_modified_count is greater than or equal to the length of the set