aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/python
diff options
context:
space:
mode:
authorGravatar Mingxing Tan <tanmingxing@google.com>2018-06-23 20:40:06 -0700
committerGravatar Mingxing Tan <tanmingxing@google.com>2018-06-23 20:40:06 -0700
commit5c43b888d9e482f0b8ff5c5ff24cc0793f17a862 (patch)
tree630c9982edef4d97db02e33fa2b8f3e3c1c49161 /tensorflow/contrib/lite/python
parentcb401f09be5b816e704a70babc0facad63e84636 (diff)
parent1d569b8713bf2559c8ad0855bfd48219f38feea9 (diff)
Merge commit for internal changes
Diffstat (limited to 'tensorflow/contrib/lite/python')
-rw-r--r--tensorflow/contrib/lite/python/convert.py29
-rw-r--r--tensorflow/contrib/lite/python/lite.py10
-rw-r--r--tensorflow/contrib/lite/python/tflite_convert.py35
3 files changed, 30 insertions, 44 deletions
diff --git a/tensorflow/contrib/lite/python/convert.py b/tensorflow/contrib/lite/python/convert.py
index c038c88945..0ea2630f71 100644
--- a/tensorflow/contrib/lite/python/convert.py
+++ b/tensorflow/contrib/lite/python/convert.py
@@ -25,7 +25,6 @@ import tempfile as _tempfile
from tensorflow.contrib.lite.python import lite_constants
from tensorflow.contrib.lite.toco import model_flags_pb2 as _model_flags_pb2
from tensorflow.contrib.lite.toco import toco_flags_pb2 as _toco_flags_pb2
-from tensorflow.python.framework import dtypes as _dtypes
from tensorflow.python.platform import resource_loader as _resource_loader
from tensorflow.python.util.lazy_loader import LazyLoader
@@ -135,11 +134,11 @@ def build_toco_convert_protos(input_tensors,
input_tensors: List of input tensors. Type and shape are computed using
`foo.get_shape()` and `foo.dtype`.
output_tensors: List of output tensors (only .name is used from this).
- inference_type: Target data type of arrays in the output file. Currently
- must be `{FLOAT, QUANTIZED_UINT8}`. (default FLOAT)
- inference_input_type: Target data type of input arrays. Allows for a
- different type for input arrays in the case of quantization. Currently
- must be `{FLOAT, QUANTIZED_UINT8}`. (default `inference_type`)
+ inference_type: Target data type of real-number arrays in the output file.
+ Must be `{FLOAT, QUANTIZED_UINT8}`. (default FLOAT)
+ inference_input_type: Target data type of real-number input arrays. Allows
+ for a different type for input arrays in the case of quantization.
+ Must be `{FLOAT, QUANTIZED_UINT8}`. (default `inference_type`)
input_format: Type of data to read Currently must be
`{TENSORFLOW_GRAPHDEF}`. (default TENSORFLOW_GRAPHDEF)
output_format: Output file format. Currently must be `{TFLITE,
@@ -202,29 +201,13 @@ def build_toco_convert_protos(input_tensors,
if dump_graphviz_dir:
toco.dump_graphviz_dir = dump_graphviz_dir
toco.dump_graphviz_include_video = dump_graphviz_video
+
model = _model_flags_pb2.ModelFlags()
model.change_concat_input_ranges = change_concat_input_ranges
for idx, input_tensor in enumerate(input_tensors):
- if input_tensor.dtype == _dtypes.float32:
- tflite_input_type = lite_constants.FLOAT
- elif input_tensor.dtype == _dtypes.int32:
- tflite_input_type = lite_constants.INT32
- elif input_tensor.dtype == _dtypes.int64:
- tflite_input_type = lite_constants.INT64
- elif input_tensor.dtype == _dtypes.uint8:
- tflite_input_type = lite_constants.QUANTIZED_UINT8
- # TODO(aselle): Insert strings when they are available
- else:
- raise ValueError("Tensors %s not known type %r" % (input_tensor.name,
- input_tensor.dtype))
-
input_array = model.input_arrays.add()
-
if inference_type == lite_constants.QUANTIZED_UINT8:
- if tflite_input_type == lite_constants.FLOAT:
- tflite_input_type = lite_constants.QUANTIZED_UINT8
input_array.mean_value, input_array.std_value = quantized_input_stats[idx]
-
input_array.name = tensor_name(input_tensor)
input_array.shape.dims.extend(map(int, input_tensor.get_shape()))
diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py
index 88dda7290b..69a2f638af 100644
--- a/tensorflow/contrib/lite/python/lite.py
+++ b/tensorflow/contrib/lite/python/lite.py
@@ -67,11 +67,11 @@ class TocoConverter(object):
Attributes:
- inference_type: Target data type of arrays in the output file. Currently
- must be `{FLOAT, QUANTIZED_UINT8}`. (default FLOAT)
- inference_input_type: Target data type of input arrays. Allows for a
- different type for input arrays in the case of quantization. Currently
- must be `{FLOAT, QUANTIZED_UINT8}`. (default `inference_type`)
+ inference_type: Target data type of real-number arrays in the output file.
+ Must be `{FLOAT, QUANTIZED_UINT8}`. (default FLOAT)
+ inference_input_type: Target data type of real-number input arrays. Allows
+ for a different type for input arrays in the case of quantization.
+ Must be `{FLOAT, QUANTIZED_UINT8}`. (default `inference_type`)
output_format: Output file format. Currently must be `{TFLITE,
GRAPHVIZ_DOT}`. (default TFLITE)
quantized_input_stats: Dict of strings representing input tensor names
diff --git a/tensorflow/contrib/lite/python/tflite_convert.py b/tensorflow/contrib/lite/python/tflite_convert.py
index f497533bed..d18a29834b 100644
--- a/tensorflow/contrib/lite/python/tflite_convert.py
+++ b/tensorflow/contrib/lite/python/tflite_convert.py
@@ -23,19 +23,15 @@ import os
import sys
from tensorflow.contrib.lite.python import lite
+from tensorflow.contrib.lite.python import lite_constants
from tensorflow.contrib.lite.toco import toco_flags_pb2 as _toco_flags_pb2
from tensorflow.contrib.lite.toco import types_pb2 as _types_pb2
from tensorflow.python.platform import app
-def _parse_array(values):
+def _parse_array(values, type_fn=str):
if values:
- return values.split(",")
-
-
-def _parse_int_array(values):
- if values:
- return [int(val) for val in values.split(",")]
+ return [type_fn(val) for val in values.split(",") if val]
def _parse_set(values):
@@ -57,7 +53,8 @@ def _get_toco_converter(flags):
input_shapes = None
if flags.input_shapes:
input_shapes_list = [
- _parse_int_array(shape) for shape in flags.input_shapes.split(":")
+ _parse_array(shape, type_fn=int)
+ for shape in flags.input_shapes.split(":")
]
input_shapes = dict(zip(input_arrays, input_shapes_list))
output_arrays = _parse_array(flags.output_arrays)
@@ -103,8 +100,8 @@ def _convert_model(flags):
if flags.mean_values and flags.std_dev_values:
input_arrays = converter.get_input_arrays()
- std_dev_values = _parse_int_array(flags.std_dev_values)
- mean_values = _parse_int_array(flags.mean_values)
+ std_dev_values = _parse_array(flags.std_dev_values, type_fn=int)
+ mean_values = _parse_array(flags.mean_values, type_fn=int)
quant_stats = zip(mean_values, std_dev_values)
if ((not flags.input_arrays and len(input_arrays) > 1) or
(len(input_arrays) != len(quant_stats))):
@@ -130,6 +127,9 @@ def _convert_model(flags):
if flags.allow_custom_ops:
converter.allow_custom_ops = flags.allow_custom_ops
if flags.quantize_weights:
+ if flags.inference_type == lite_constants.QUANTIZED_UINT8:
+ raise ValueError("--quantized_weights is not supported with "
+ "--inference_type=QUANTIZED_UINT8")
converter.quantize_weights = flags.quantize_weights
if flags.dump_graphviz_dir:
converter.dump_graphviz_dir = flags.dump_graphviz_dir
@@ -200,6 +200,9 @@ def _check_flags(flags, unparsed):
raise ValueError("--default_ranges_min and --default_ranges_max must be "
"used together")
+ if flags.dump_graphviz_video and not flags.dump_graphviz:
+ raise ValueError("--dump_graphviz_video must be used with --dump_graphviz")
+
def run_main(_):
"""Main in toco_convert.py."""
@@ -235,13 +238,13 @@ def run_main(_):
"--inference_type",
type=str.upper,
choices=["FLOAT", "QUANTIZED_UINT8"],
- help="Target data type of arrays in the output file.")
+ help="Target data type of real-number arrays in the output file.")
parser.add_argument(
"--inference_input_type",
type=str.upper,
choices=["FLOAT", "QUANTIZED_UINT8"],
- help=("Target data type of input arrays. Allows for a different type for "
- "input arrays in the case of quantization."))
+ help=("Target data type of real-number input arrays. Allows for a "
+ "different type for input arrays in the case of quantization."))
# Input and output arrays flags.
parser.add_argument(
@@ -275,12 +278,12 @@ def run_main(_):
"--std_dev_values",
type=str,
help=("Standard deviation of training data for each input tensor, "
- "comma-separated. Used for quantization. (default None)"))
+ "comma-separated integers. Used for quantization. (default None)"))
parser.add_argument(
"--mean_values",
type=str,
- help=("Mean of training data for each input tensor, comma-separated. "
- "Used for quantization. (default None)"))
+ help=("Mean of training data for each input tensor, comma-separated "
+ "integers. Used for quantization. (default None)"))
parser.add_argument(
"--default_ranges_min",
type=int,