path: root/tensorflow/contrib/quantize/python/quant_ops.py
Diffstat (limited to 'tensorflow/contrib/quantize/python/quant_ops.py')
-rw-r--r--  tensorflow/contrib/quantize/python/quant_ops.py  |  57
1 file changed, 37 insertions(+), 20 deletions(-)
diff --git a/tensorflow/contrib/quantize/python/quant_ops.py b/tensorflow/contrib/quantize/python/quant_ops.py
index f80d427ff0..0a38ef9fcd 100644
--- a/tensorflow/contrib/quantize/python/quant_ops.py
+++ b/tensorflow/contrib/quantize/python/quant_ops.py
@@ -22,12 +22,15 @@ from tensorflow.contrib.framework.python.ops import add_arg_scope
from tensorflow.contrib.framework.python.ops import model_variable
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.training import moving_averages
+EPSILON = 1e-5
+
@add_arg_scope
def FixedQuantize(inputs, init_min=-6.0, init_max=6.0, scope=None):
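
Note: the new EPSILON constant pads the observed batch range so the minimum stays strictly below the maximum even for constant inputs, while 0.0 is still forced into the range. A minimal sketch of that padding logic in plain Python (illustrative only, not the TensorFlow ops used below):

# Sketch of the range padding introduced by this change; plain Python,
# names are illustrative.
EPSILON = 1e-5

def padded_range(values):
  batch_min = min(values) - EPSILON          # pad the low end
  batch_max = max(values) + EPSILON          # pad the high end
  # 0.0 must always fall inside [min; max] for the quantization scheme.
  return min(batch_min, 0.0), max(batch_max, 0.0)

print(padded_range([3.0, 3.0, 3.0]))  # approximately (0.0, 3.00001)
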
@@ -130,10 +133,12 @@ def LastValueQuantize(inputs,
batch_min = inputs
else:
batch_min = math_ops.reduce_min(inputs, name='BatchMin')
- # TFLite requires that 0.0 if always in the [min; max] range.
+batch_min -= EPSILON
+# B-eng requires that 0.0 is always in the [min; max] range.
batch_min = math_ops.minimum(batch_min, 0.0)
- assign_min = state_ops.assign(min_var, batch_min, name='AssignMinLast')
- ops.add_to_collection(updates_collection, assign_min.op)
+ assign_min_op = state_ops.assign(
+ min_var, batch_min, name='AssignMinLast').op
+ ops.add_to_collection(updates_collection, assign_min_op)
if per_channel:
if input_dim >= 2:
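
Note: in the per-channel branch above, the batch statistic is reduced over every axis except the channel axis, so each channel gets its own [min; max] range. A hedged sketch of that reduction for a 4-D NHWC activation tensor (shapes and names are assumptions, not taken from this diff):

# Per-channel range reduction sketch (TensorFlow 1.x style).
import tensorflow as tf

inputs = tf.random_uniform([8, 14, 14, 32])              # NHWC activations
per_channel_min = tf.reduce_min(inputs, axis=[0, 1, 2])  # shape [32]
per_channel_max = tf.reduce_max(inputs, axis=[0, 1, 2])  # shape [32]
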
@@ -143,15 +148,17 @@ def LastValueQuantize(inputs,
batch_max = inputs
else:
batch_max = math_ops.reduce_max(inputs, name='BatchMax')
- # TFLite requires that 0.0 if always in the [min; max] range.
+batch_max += EPSILON
+# B-eng requires that 0.0 is always in the [min; max] range.
batch_max = math_ops.maximum(batch_max, 0.0)
- assign_max = state_ops.assign(max_var, batch_max, name='AssignMaxLast')
- ops.add_to_collection(updates_collection, assign_max.op)
+ assign_max_op = state_ops.assign(
+ max_var, batch_max, name='AssignMaxLast').op
+ ops.add_to_collection(updates_collection, assign_max_op)
return _FakeQuantWithMinMaxVars(
inputs,
- assign_min,
- assign_max,
+ batch_min,
+ batch_max,
per_channel=per_channel,
num_bits=num_bits,
narrow_range=narrow_range)
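
Note: with this hunk, LastValueQuantize quantizes against the freshly computed batch_min/batch_max tensors instead of the assign outputs, and the variable writes reach the graph only through updates_collection. A hedged sketch of how a training loop would then run those update ops, assuming the default collection is GraphKeys.UPDATE_OPS (the toy loss and optimizer are illustrative, not from this diff):

# Sketch: group the collected update ops (e.g. AssignMinLast/AssignMaxLast)
# with the train step so the range variables actually get written.
import tensorflow as tf

x = tf.Variable(0.5)
loss = tf.square(x - 1.0)
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
train_op = tf.group(train_step, *update_ops)
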
@@ -244,9 +251,9 @@ def MovingAvgQuantize(inputs,
batch_min = math_ops.reduce_min(inputs, name='BatchMin')
# B-eng requires that 0.0 is always in the [min; max] range.
batch_min = math_ops.minimum(batch_min, 0.0)
- assign_min = moving_averages.assign_moving_average(
- min_var, batch_min, ema_decay, name='AssignMinEma')
- ops.add_to_collection(updates_collection, assign_min.op)
+ assign_min_op = moving_averages.assign_moving_average(
+ min_var, batch_min, ema_decay, name='AssignMinEma').op
+ ops.add_to_collection(updates_collection, assign_min_op)
if per_channel:
if input_dim >= 2:
@@ -258,14 +265,14 @@ def MovingAvgQuantize(inputs,
batch_max = math_ops.reduce_max(inputs, name='BatchMax')
# B-eng requires that 0.0 is always in the [min; max] range.
batch_max = math_ops.maximum(batch_max, 0.0)
- assign_max = moving_averages.assign_moving_average(
- max_var, batch_max, ema_decay, name='AssignMaxEma')
- ops.add_to_collection(updates_collection, assign_max.op)
+ assign_max_op = moving_averages.assign_moving_average(
+ max_var, batch_max, ema_decay, name='AssignMaxEma').op
+ ops.add_to_collection(updates_collection, assign_max_op)
return _FakeQuantWithMinMaxVars(
inputs,
- assign_min,
- assign_max,
+ min_var,
+ max_var,
per_channel=per_channel,
num_bits=num_bits,
narrow_range=narrow_range)
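
Note: MovingAvgQuantize similarly quantizes against min_var/max_var, the variables holding the exponential moving averages, while the EMA assigns go through updates_collection. A hedged sketch of the moving-average update that assign_moving_average applies (approximate semantics, plain Python):

# Approximate EMA update: move the stored value a small step toward
# the new batch statistic.
def ema_update(current, batch_value, decay):
  return current - (1.0 - decay) * (current - batch_value)

min_ema = 0.0
for batch_min in [-1.0, -1.2, -0.9]:
  min_ema = ema_update(min_ema, batch_min, decay=0.999)
print(min_ema)  # drifts slowly from 0.0 toward the observed batch minima
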
@@ -294,10 +301,20 @@ def _FakeQuantWithMinMaxVars(inputs, min_var, max_var, per_channel, num_bits,
if per_channel:
assert len(min_var.get_shape()) == 1
assert len(max_var.get_shape()) == 1
- return array_ops.fake_quant_with_min_max_vars_per_channel(
- inputs, min_var, max_var, num_bits=num_bits, narrow_range=narrow_range)
+ with ops.control_dependencies([check_ops.assert_less(min_var, max_var)]):
+ return array_ops.fake_quant_with_min_max_vars_per_channel(
+ inputs,
+ min_var,
+ max_var,
+ num_bits=num_bits,
+ narrow_range=narrow_range)
else:
assert min_var.get_shape() == [] # pylint: disable=g-explicit-bool-comparison
assert max_var.get_shape() == [] # pylint: disable=g-explicit-bool-comparison
- return array_ops.fake_quant_with_min_max_vars(
- inputs, min_var, max_var, num_bits=num_bits, narrow_range=narrow_range)
+ with ops.control_dependencies([check_ops.assert_less(min_var, max_var)]):
+ return array_ops.fake_quant_with_min_max_vars(
+ inputs,
+ min_var,
+ max_var,
+ num_bits=num_bits,
+ narrow_range=narrow_range)
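
Note: the new control dependencies make the fake-quant ops fail fast whenever min_var is not strictly less than max_var, the situation the EPSILON padding above is meant to rule out. A rough, hedged sketch of why that invariant matters: the quantization scale is derived from (max - min), so an empty or inverted range would give a zero or negative step size (plain Python, ignores the zero-point nudging the real op performs):

# Rough fake-quant round trip over a [min, max] range.
def fake_quant(x, range_min, range_max, num_bits=8):
  levels = (1 << num_bits) - 1                 # 255 steps for 8 bits
  scale = (range_max - range_min) / levels     # requires max > min
  x = min(max(x, range_min), range_max)        # clamp into range
  q = round((x - range_min) / scale)           # integer level index
  return q * scale + range_min                 # dequantized value

print(fake_quant(0.37, 0.0, 1.0))  # about 0.3686, snapped to one of 256 levels
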