aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/python/convert.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/python/convert.py')
-rw-r--r--tensorflow/contrib/lite/python/convert.py9
1 files changed, 8 insertions, 1 deletions
diff --git a/tensorflow/contrib/lite/python/convert.py b/tensorflow/contrib/lite/python/convert.py
index 0ea2630f71..ec49738fb5 100644
--- a/tensorflow/contrib/lite/python/convert.py
+++ b/tensorflow/contrib/lite/python/convert.py
@@ -115,6 +115,7 @@ def build_toco_convert_protos(input_tensors,
inference_type=lite_constants.FLOAT,
inference_input_type=None,
input_format=lite_constants.TENSORFLOW_GRAPHDEF,
+ input_shapes=None,
output_format=lite_constants.TFLITE,
quantized_input_stats=None,
default_ranges_stats=None,
@@ -141,6 +142,8 @@ def build_toco_convert_protos(input_tensors,
Must be `{FLOAT, QUANTIZED_UINT8}`. (default `inference_type`)
input_format: Type of data to read Currently must be
`{TENSORFLOW_GRAPHDEF}`. (default TENSORFLOW_GRAPHDEF)
+ input_shapes: Input array shape. It needs to be a list of the same length
+ as `input_tensors`, or None. (default None)
output_format: Output file format. Currently must be `{TFLITE,
GRAPHVIZ_DOT}`. (default TFLITE)
quantized_input_stats: List of tuples of integers representing the mean and
@@ -209,7 +212,11 @@ def build_toco_convert_protos(input_tensors,
if inference_type == lite_constants.QUANTIZED_UINT8:
input_array.mean_value, input_array.std_value = quantized_input_stats[idx]
input_array.name = tensor_name(input_tensor)
- input_array.shape.dims.extend(map(int, input_tensor.get_shape()))
+ if input_shapes is None:
+ shape = input_tensor.get_shape()
+ else:
+ shape = input_shapes[idx]
+ input_array.shape.dims.extend(map(int, shape))
for output_tensor in output_tensors:
model.output_arrays.append(tensor_name(output_tensor))