aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/quantization
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-07-13 13:46:46 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-07-13 15:02:08 -0700
commite24388242026245244435235ea66fd3693942c67 (patch)
tree131d3f98f11b6074ec11c74b870c66c081f8ce02 /tensorflow/contrib/quantization
parentc9b8301ca7632d5a7a3a565e686569a5cc0635e4 (diff)
Fixes a bug in TensorFlow quantization, where a negative constant node
had its min and max values erroneously calculated, causing an exception to be raised. Change: 127362133
Diffstat (limited to 'tensorflow/contrib/quantization')
-rw-r--r--tensorflow/contrib/quantization/tools/quantize_graph.py4
-rw-r--r--tensorflow/contrib/quantization/tools/quantize_graph_test.py8
2 files changed, 11 insertions, 1 deletions
diff --git a/tensorflow/contrib/quantization/tools/quantize_graph.py b/tensorflow/contrib/quantization/tools/quantize_graph.py
index 3ed2ee07f7..d999797f81 100644
--- a/tensorflow/contrib/quantization/tools/quantize_graph.py
+++ b/tensorflow/contrib/quantization/tools/quantize_graph.py
@@ -243,8 +243,10 @@ def quantize_weight_eightbit(input_node, quantization_mode):
if min_value == max_value:
if abs(min_value) < 0.000001:
max_value = min_value + 1.0
- else:
+ elif min_value > 0:
max_value = 2 * min_value
+ else:
+ max_value = min_value / 2.0
sess = tf.Session()
with sess.as_default():
diff --git a/tensorflow/contrib/quantization/tools/quantize_graph_test.py b/tensorflow/contrib/quantization/tools/quantize_graph_test.py
index 428fafcccb..df3eac5f2c 100644
--- a/tensorflow/contrib/quantization/tools/quantize_graph_test.py
+++ b/tensorflow/contrib/quantization/tools/quantize_graph_test.py
@@ -194,6 +194,14 @@ def test_graph(float_graph_def, input_map, output_names):
class QuantizeGraphTest(tf.test.TestCase):
+ def test_negative_const_problem(self):
+ shape_constant_name = "shape_constant"
+ shape_constant = quantize_graph.create_constant_node(
+ shape_constant_name, value=-0.8, dtype=tf.float32, shape=[1])
+ quantization_result = quantize_graph.quantize_weight_eightbit(
+ shape_constant, b"MIN_COMBINED")
+ self.assertEqual(4, len(quantization_result))
+
def test_odd_padding_problem(self):
"""Tests one error case we ran into in a real graph."""
test_conv(1, 4, 4, 1, 3, 1, 2, b"SAME",