aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/python
diff options
context:
space:
mode:
authorGravatar Yunlu Li <yunluli@google.com>2018-08-30 17:19:43 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-30 17:24:24 -0700
commit9efc469cc5fa7eb0f8d8d8a7662cb6c1fbcb5b1a (patch)
treec01018a17eeb3c9397f4606c445e666af2d080e4 /tensorflow/contrib/lite/python
parent11f970c61bc21ae81b76cdee58871f0509ec9f0f (diff)
Change the data type of mean_values and std_dev_values to float.
PiperOrigin-RevId: 211010293
Diffstat (limited to 'tensorflow/contrib/lite/python')
-rw-r--r--tensorflow/contrib/lite/python/convert.py10
-rw-r--r--tensorflow/contrib/lite/python/lite.py6
-rw-r--r--tensorflow/contrib/lite/python/lite_test.py1
-rw-r--r--tensorflow/contrib/lite/python/tflite_convert.py15
4 files changed, 23 insertions, 9 deletions
diff --git a/tensorflow/contrib/lite/python/convert.py b/tensorflow/contrib/lite/python/convert.py
index 0b2192e031..69a3d562b3 100644
--- a/tensorflow/contrib/lite/python/convert.py
+++ b/tensorflow/contrib/lite/python/convert.py
@@ -149,9 +149,11 @@ def build_toco_convert_protos(input_tensors,
as `input_tensors`, or None. (default None)
output_format: Output file format. Currently must be `{TFLITE,
GRAPHVIZ_DOT}`. (default TFLITE)
- quantized_input_stats: List of tuples of integers representing the mean and
+ quantized_input_stats: List of tuples of floats representing the mean and
standard deviation. Each tuple maps to the corresponding input tensor.
- Only need if `inference_type` is `QUANTIZED_UINT8`. (default None)
+ Only need if `inference_input_type` is `QUANTIZED_UINT8`.
+ real_input_value = (quantized_input_value - mean_value) / std_dev_value.
+ (default None)
default_ranges_stats: Tuple of integers representing (min, max) range values
for all arrays without a specified range. Intended for experimenting with
quantization via "dummy quantization". (default None)
@@ -197,6 +199,8 @@ def build_toco_convert_protos(input_tensors,
toco.inference_type = inference_type
if inference_input_type:
toco.inference_input_type = inference_input_type
+ else:
+ toco.inference_input_type = toco.inference_type
toco.drop_control_dependency = drop_control_dependency
toco.reorder_across_fake_quant = reorder_across_fake_quant
toco.allow_custom_ops = allow_custom_ops
@@ -212,7 +216,7 @@ def build_toco_convert_protos(input_tensors,
model.change_concat_input_ranges = change_concat_input_ranges
for idx, input_tensor in enumerate(input_tensors):
input_array = model.input_arrays.add()
- if inference_type == lite_constants.QUANTIZED_UINT8:
+ if toco.inference_input_type == lite_constants.QUANTIZED_UINT8:
input_array.mean_value, input_array.std_value = quantized_input_stats[idx]
input_array.name = tensor_name(input_tensor)
if input_shapes is None:
diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py
index a4c9a2381c..80cbb12825 100644
--- a/tensorflow/contrib/lite/python/lite.py
+++ b/tensorflow/contrib/lite/python/lite.py
@@ -78,9 +78,11 @@ class TocoConverter(object):
output_format: Output file format. Currently must be `{TFLITE,
GRAPHVIZ_DOT}`. (default TFLITE)
quantized_input_stats: Dict of strings representing input tensor names
- mapped to tuple of integers representing the mean and standard deviation
+ mapped to tuple of floats representing the mean and standard deviation
of the training data (e.g., {"foo" : (0., 1.)}). Only need if
- `inference_type` is `QUANTIZED_UINT8`. (default {})
+ `inference_input_type` is `QUANTIZED_UINT8`.
+ real_input_value = (quantized_input_value - mean_value) / std_dev_value.
+ (default {})
default_ranges_stats: Tuple of integers representing (min, max) range values
for all arrays without a specified range. Intended for experimenting with
quantization via "dummy quantization". (default None)
diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py
index 8c9cfa943f..e6aa8b0d99 100644
--- a/tensorflow/contrib/lite/python/lite_test.py
+++ b/tensorflow/contrib/lite/python/lite_test.py
@@ -319,6 +319,7 @@ class FromSessionTest(test_util.TensorFlowTestCase):
# Convert model and ensure model is not None.
converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
converter.inference_input_type = lite_constants.QUANTIZED_UINT8
+ converter.quantized_input_stats = {'Placeholder': (0., 1.)} # mean, std_dev
tflite_model = converter.convert()
self.assertTrue(tflite_model)
diff --git a/tensorflow/contrib/lite/python/tflite_convert.py b/tensorflow/contrib/lite/python/tflite_convert.py
index ce12a9abde..dc078ffd21 100644
--- a/tensorflow/contrib/lite/python/tflite_convert.py
+++ b/tensorflow/contrib/lite/python/tflite_convert.py
@@ -109,8 +109,14 @@ def _convert_model(flags):
if flags.mean_values and flags.std_dev_values:
input_arrays = converter.get_input_arrays()
- std_dev_values = _parse_array(flags.std_dev_values, type_fn=int)
- mean_values = _parse_array(flags.mean_values, type_fn=int)
+ std_dev_values = _parse_array(flags.std_dev_values, type_fn=float)
+
+ # In quantized inference, mean_value has to be integer so that the real
+ # value 0.0 is exactly representable.
+ if flags.inference_type == lite_constants.QUANTIZED_UINT8:
+ mean_values = _parse_array(flags.mean_values, type_fn=int)
+ else:
+ mean_values = _parse_array(flags.mean_values, type_fn=float)
quant_stats = list(zip(mean_values, std_dev_values))
if ((not flags.input_arrays and len(input_arrays) > 1) or
(len(input_arrays) != len(quant_stats))):
@@ -293,12 +299,13 @@ def run_main(_):
"--std_dev_values",
type=str,
help=("Standard deviation of training data for each input tensor, "
- "comma-separated integers. Used for quantization. (default None)"))
+ "comma-separated floats. Used for quantized input tensors. "
+ "(default None)"))
parser.add_argument(
"--mean_values",
type=str,
help=("Mean of training data for each input tensor, comma-separated "
- "integers. Used for quantization. (default None)"))
+ "floats. Used for quantized input tensors. (default None)"))
parser.add_argument(
"--default_ranges_min",
type=int,