diff options
Diffstat (limited to 'tensorflow/contrib/quantize/python/quant_ops.py')
-rw-r--r-- | tensorflow/contrib/quantize/python/quant_ops.py | 57 |
1 files changed, 37 insertions, 20 deletions
diff --git a/tensorflow/contrib/quantize/python/quant_ops.py b/tensorflow/contrib/quantize/python/quant_ops.py index f80d427ff0..0a38ef9fcd 100644 --- a/tensorflow/contrib/quantize/python/quant_ops.py +++ b/tensorflow/contrib/quantize/python/quant_ops.py @@ -22,12 +22,15 @@ from tensorflow.contrib.framework.python.ops import add_arg_scope from tensorflow.contrib.framework.python.ops import model_variable from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.training import moving_averages +EPSILON = 1e-5 + @add_arg_scope def FixedQuantize(inputs, init_min=-6.0, init_max=6.0, scope=None): @@ -130,10 +133,12 @@ def LastValueQuantize(inputs, batch_min = inputs else: batch_min = math_ops.reduce_min(inputs, name='BatchMin') - # TFLite requires that 0.0 if always in the [min; max] range. + batch_min -= EPSILON + # B-eng requires that 0.0 if always in the [min; max] range. batch_min = math_ops.minimum(batch_min, 0.0) - assign_min = state_ops.assign(min_var, batch_min, name='AssignMinLast') - ops.add_to_collection(updates_collection, assign_min.op) + assign_min_op = state_ops.assign( + min_var, batch_min, name='AssignMinLast').op + ops.add_to_collection(updates_collection, assign_min_op) if per_channel: if input_dim >= 2: @@ -143,15 +148,17 @@ def LastValueQuantize(inputs, batch_max = inputs else: batch_max = math_ops.reduce_max(inputs, name='BatchMax') - # TFLite requires that 0.0 if always in the [min; max] range. + batch_max += EPSILON + # B-eng requires that 0.0 if always in the [min; max] range. batch_max = math_ops.maximum(batch_max, 0.0) - assign_max = state_ops.assign(max_var, batch_max, name='AssignMaxLast') - ops.add_to_collection(updates_collection, assign_max.op) + assign_max_op = state_ops.assign( + max_var, batch_max, name='AssignMaxLast').op + ops.add_to_collection(updates_collection, assign_max_op) return _FakeQuantWithMinMaxVars( inputs, - assign_min, - assign_max, + batch_min, + batch_max, per_channel=per_channel, num_bits=num_bits, narrow_range=narrow_range) @@ -244,9 +251,9 @@ def MovingAvgQuantize(inputs, batch_min = math_ops.reduce_min(inputs, name='BatchMin') # B-eng requires that 0.0 if always in the [min; max] range. batch_min = math_ops.minimum(batch_min, 0.0) - assign_min = moving_averages.assign_moving_average( - min_var, batch_min, ema_decay, name='AssignMinEma') - ops.add_to_collection(updates_collection, assign_min.op) + assign_min_op = moving_averages.assign_moving_average( + min_var, batch_min, ema_decay, name='AssignMinEma').op + ops.add_to_collection(updates_collection, assign_min_op) if per_channel: if input_dim >= 2: @@ -258,14 +265,14 @@ def MovingAvgQuantize(inputs, batch_max = math_ops.reduce_max(inputs, name='BatchMax') # B-eng requires that 0.0 if always in the [min; max] range. batch_max = math_ops.maximum(batch_max, 0.0) - assign_max = moving_averages.assign_moving_average( - max_var, batch_max, ema_decay, name='AssignMaxEma') - ops.add_to_collection(updates_collection, assign_max.op) + assign_max_op = moving_averages.assign_moving_average( + max_var, batch_max, ema_decay, name='AssignMaxEma').op + ops.add_to_collection(updates_collection, assign_max_op) return _FakeQuantWithMinMaxVars( inputs, - assign_min, - assign_max, + min_var, + max_var, per_channel=per_channel, num_bits=num_bits, narrow_range=narrow_range) @@ -294,10 +301,20 @@ def _FakeQuantWithMinMaxVars(inputs, min_var, max_var, per_channel, num_bits, if per_channel: assert len(min_var.get_shape()) == 1 assert len(max_var.get_shape()) == 1 - return array_ops.fake_quant_with_min_max_vars_per_channel( - inputs, min_var, max_var, num_bits=num_bits, narrow_range=narrow_range) + with ops.control_dependencies([check_ops.assert_less(min_var, max_var)]): + return array_ops.fake_quant_with_min_max_vars_per_channel( + inputs, + min_var, + max_var, + num_bits=num_bits, + narrow_range=narrow_range) else: assert min_var.get_shape() == [] # pylint: disable=g-explicit-bool-comparison assert max_var.get_shape() == [] # pylint: disable=g-explicit-bool-comparison - return array_ops.fake_quant_with_min_max_vars( - inputs, min_var, max_var, num_bits=num_bits, narrow_range=narrow_range) + with ops.control_dependencies([check_ops.assert_less(min_var, max_var)]): + return array_ops.fake_quant_with_min_max_vars( + inputs, + min_var, + max_var, + num_bits=num_bits, + narrow_range=narrow_range) |