diff options
author | 2016-07-13 13:46:46 -0800 | |
---|---|---|
committer | 2016-07-13 15:02:08 -0700 | |
commit | e24388242026245244435235ea66fd3693942c67 (patch) | |
tree | 131d3f98f11b6074ec11c74b870c66c081f8ce02 /tensorflow/contrib/quantization | |
parent | c9b8301ca7632d5a7a3a565e686569a5cc0635e4 (diff) |
Fixes a bug in TensorFlow quantization, where a negative constant node
had its min and max values erroneously calculated, causing an exception
to be raised.
Change: 127362133
Diffstat (limited to 'tensorflow/contrib/quantization')
-rw-r--r-- | tensorflow/contrib/quantization/tools/quantize_graph.py | 4 | ||||
-rw-r--r-- | tensorflow/contrib/quantization/tools/quantize_graph_test.py | 8 |
2 files changed, 11 insertions, 1 deletions
diff --git a/tensorflow/contrib/quantization/tools/quantize_graph.py b/tensorflow/contrib/quantization/tools/quantize_graph.py index 3ed2ee07f7..d999797f81 100644 --- a/tensorflow/contrib/quantization/tools/quantize_graph.py +++ b/tensorflow/contrib/quantization/tools/quantize_graph.py @@ -243,8 +243,10 @@ def quantize_weight_eightbit(input_node, quantization_mode): if min_value == max_value: if abs(min_value) < 0.000001: max_value = min_value + 1.0 - else: + elif min_value > 0: max_value = 2 * min_value + else: + max_value = min_value / 2.0 sess = tf.Session() with sess.as_default(): diff --git a/tensorflow/contrib/quantization/tools/quantize_graph_test.py b/tensorflow/contrib/quantization/tools/quantize_graph_test.py index 428fafcccb..df3eac5f2c 100644 --- a/tensorflow/contrib/quantization/tools/quantize_graph_test.py +++ b/tensorflow/contrib/quantization/tools/quantize_graph_test.py @@ -194,6 +194,14 @@ def test_graph(float_graph_def, input_map, output_names): class QuantizeGraphTest(tf.test.TestCase): + def test_negative_const_problem(self): + shape_constant_name = "shape_constant" + shape_constant = quantize_graph.create_constant_node( + shape_constant_name, value=-0.8, dtype=tf.float32, shape=[1]) + quantization_result = quantize_graph.quantize_weight_eightbit( + shape_constant, b"MIN_COMBINED") + self.assertEqual(4, len(quantization_result)) + def test_odd_padding_problem(self): """Tests one error case we ran into in a real graph.""" test_conv(1, 4, 4, 1, 3, 1, 2, b"SAME", |