diff options
author | 2018-06-23 20:40:06 -0700 | |
---|---|---|
committer | 2018-06-23 20:40:06 -0700 | |
commit | 5c43b888d9e482f0b8ff5c5ff24cc0793f17a862 (patch) | |
tree | 630c9982edef4d97db02e33fa2b8f3e3c1c49161 /tensorflow/contrib/lite/python | |
parent | cb401f09be5b816e704a70babc0facad63e84636 (diff) | |
parent | 1d569b8713bf2559c8ad0855bfd48219f38feea9 (diff) |
Merge commit for internal changes
Diffstat (limited to 'tensorflow/contrib/lite/python')
-rw-r--r-- | tensorflow/contrib/lite/python/convert.py | 29 | ||||
-rw-r--r-- | tensorflow/contrib/lite/python/lite.py | 10 | ||||
-rw-r--r-- | tensorflow/contrib/lite/python/tflite_convert.py | 35 |
3 files changed, 30 insertions, 44 deletions
diff --git a/tensorflow/contrib/lite/python/convert.py b/tensorflow/contrib/lite/python/convert.py index c038c88945..0ea2630f71 100644 --- a/tensorflow/contrib/lite/python/convert.py +++ b/tensorflow/contrib/lite/python/convert.py @@ -25,7 +25,6 @@ import tempfile as _tempfile from tensorflow.contrib.lite.python import lite_constants from tensorflow.contrib.lite.toco import model_flags_pb2 as _model_flags_pb2 from tensorflow.contrib.lite.toco import toco_flags_pb2 as _toco_flags_pb2 -from tensorflow.python.framework import dtypes as _dtypes from tensorflow.python.platform import resource_loader as _resource_loader from tensorflow.python.util.lazy_loader import LazyLoader @@ -135,11 +134,11 @@ def build_toco_convert_protos(input_tensors, input_tensors: List of input tensors. Type and shape are computed using `foo.get_shape()` and `foo.dtype`. output_tensors: List of output tensors (only .name is used from this). - inference_type: Target data type of arrays in the output file. Currently - must be `{FLOAT, QUANTIZED_UINT8}`. (default FLOAT) - inference_input_type: Target data type of input arrays. Allows for a - different type for input arrays in the case of quantization. Currently - must be `{FLOAT, QUANTIZED_UINT8}`. (default `inference_type`) + inference_type: Target data type of real-number arrays in the output file. + Must be `{FLOAT, QUANTIZED_UINT8}`. (default FLOAT) + inference_input_type: Target data type of real-number input arrays. Allows + for a different type for input arrays in the case of quantization. + Must be `{FLOAT, QUANTIZED_UINT8}`. (default `inference_type`) input_format: Type of data to read Currently must be `{TENSORFLOW_GRAPHDEF}`. (default TENSORFLOW_GRAPHDEF) output_format: Output file format. Currently must be `{TFLITE, @@ -202,29 +201,13 @@ def build_toco_convert_protos(input_tensors, if dump_graphviz_dir: toco.dump_graphviz_dir = dump_graphviz_dir toco.dump_graphviz_include_video = dump_graphviz_video + model = _model_flags_pb2.ModelFlags() model.change_concat_input_ranges = change_concat_input_ranges for idx, input_tensor in enumerate(input_tensors): - if input_tensor.dtype == _dtypes.float32: - tflite_input_type = lite_constants.FLOAT - elif input_tensor.dtype == _dtypes.int32: - tflite_input_type = lite_constants.INT32 - elif input_tensor.dtype == _dtypes.int64: - tflite_input_type = lite_constants.INT64 - elif input_tensor.dtype == _dtypes.uint8: - tflite_input_type = lite_constants.QUANTIZED_UINT8 - # TODO(aselle): Insert strings when they are available - else: - raise ValueError("Tensors %s not known type %r" % (input_tensor.name, - input_tensor.dtype)) - input_array = model.input_arrays.add() - if inference_type == lite_constants.QUANTIZED_UINT8: - if tflite_input_type == lite_constants.FLOAT: - tflite_input_type = lite_constants.QUANTIZED_UINT8 input_array.mean_value, input_array.std_value = quantized_input_stats[idx] - input_array.name = tensor_name(input_tensor) input_array.shape.dims.extend(map(int, input_tensor.get_shape())) diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py index 88dda7290b..69a2f638af 100644 --- a/tensorflow/contrib/lite/python/lite.py +++ b/tensorflow/contrib/lite/python/lite.py @@ -67,11 +67,11 @@ class TocoConverter(object): Attributes: - inference_type: Target data type of arrays in the output file. Currently - must be `{FLOAT, QUANTIZED_UINT8}`. (default FLOAT) - inference_input_type: Target data type of input arrays. Allows for a - different type for input arrays in the case of quantization. Currently - must be `{FLOAT, QUANTIZED_UINT8}`. (default `inference_type`) + inference_type: Target data type of real-number arrays in the output file. + Must be `{FLOAT, QUANTIZED_UINT8}`. (default FLOAT) + inference_input_type: Target data type of real-number input arrays. Allows + for a different type for input arrays in the case of quantization. + Must be `{FLOAT, QUANTIZED_UINT8}`. (default `inference_type`) output_format: Output file format. Currently must be `{TFLITE, GRAPHVIZ_DOT}`. (default TFLITE) quantized_input_stats: Dict of strings representing input tensor names diff --git a/tensorflow/contrib/lite/python/tflite_convert.py b/tensorflow/contrib/lite/python/tflite_convert.py index f497533bed..d18a29834b 100644 --- a/tensorflow/contrib/lite/python/tflite_convert.py +++ b/tensorflow/contrib/lite/python/tflite_convert.py @@ -23,19 +23,15 @@ import os import sys from tensorflow.contrib.lite.python import lite +from tensorflow.contrib.lite.python import lite_constants from tensorflow.contrib.lite.toco import toco_flags_pb2 as _toco_flags_pb2 from tensorflow.contrib.lite.toco import types_pb2 as _types_pb2 from tensorflow.python.platform import app -def _parse_array(values): +def _parse_array(values, type_fn=str): if values: - return values.split(",") - - -def _parse_int_array(values): - if values: - return [int(val) for val in values.split(",")] + return [type_fn(val) for val in values.split(",") if val] def _parse_set(values): @@ -57,7 +53,8 @@ def _get_toco_converter(flags): input_shapes = None if flags.input_shapes: input_shapes_list = [ - _parse_int_array(shape) for shape in flags.input_shapes.split(":") + _parse_array(shape, type_fn=int) + for shape in flags.input_shapes.split(":") ] input_shapes = dict(zip(input_arrays, input_shapes_list)) output_arrays = _parse_array(flags.output_arrays) @@ -103,8 +100,8 @@ def _convert_model(flags): if flags.mean_values and flags.std_dev_values: input_arrays = converter.get_input_arrays() - std_dev_values = _parse_int_array(flags.std_dev_values) - mean_values = _parse_int_array(flags.mean_values) + std_dev_values = _parse_array(flags.std_dev_values, type_fn=int) + mean_values = _parse_array(flags.mean_values, type_fn=int) quant_stats = zip(mean_values, std_dev_values) if ((not flags.input_arrays and len(input_arrays) > 1) or (len(input_arrays) != len(quant_stats))): @@ -130,6 +127,9 @@ def _convert_model(flags): if flags.allow_custom_ops: converter.allow_custom_ops = flags.allow_custom_ops if flags.quantize_weights: + if flags.inference_type == lite_constants.QUANTIZED_UINT8: + raise ValueError("--quantized_weights is not supported with " + "--inference_type=QUANTIZED_UINT8") converter.quantize_weights = flags.quantize_weights if flags.dump_graphviz_dir: converter.dump_graphviz_dir = flags.dump_graphviz_dir @@ -200,6 +200,9 @@ def _check_flags(flags, unparsed): raise ValueError("--default_ranges_min and --default_ranges_max must be " "used together") + if flags.dump_graphviz_video and not flags.dump_graphviz: + raise ValueError("--dump_graphviz_video must be used with --dump_graphviz") + def run_main(_): """Main in toco_convert.py.""" @@ -235,13 +238,13 @@ def run_main(_): "--inference_type", type=str.upper, choices=["FLOAT", "QUANTIZED_UINT8"], - help="Target data type of arrays in the output file.") + help="Target data type of real-number arrays in the output file.") parser.add_argument( "--inference_input_type", type=str.upper, choices=["FLOAT", "QUANTIZED_UINT8"], - help=("Target data type of input arrays. Allows for a different type for " - "input arrays in the case of quantization.")) + help=("Target data type of real-number input arrays. Allows for a " + "different type for input arrays in the case of quantization.")) # Input and output arrays flags. parser.add_argument( @@ -275,12 +278,12 @@ def run_main(_): "--std_dev_values", type=str, help=("Standard deviation of training data for each input tensor, " - "comma-separated. Used for quantization. (default None)")) + "comma-separated integers. Used for quantization. (default None)")) parser.add_argument( "--mean_values", type=str, - help=("Mean of training data for each input tensor, comma-separated. " - "Used for quantization. (default None)")) + help=("Mean of training data for each input tensor, comma-separated " + "integers. Used for quantization. (default None)")) parser.add_argument( "--default_ranges_min", type=int, |