aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
author    Nupur Garg <nupurgarg@google.com>    2018-06-22 15:07:08 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-06-22 15:10:19 -0700
commit d3931c804ce9619ac0a0c84f42b43ce70ade93a7 (patch)
tree   7661d71fce9330eb5cd268f787c31105f0b1cb83
parent fdea22dbcc1a22f1dc76e3a2172591222ca165d4 (diff)
Updated toco_converter API.
PiperOrigin-RevId: 201748091
-rw-r--r--  tensorflow/contrib/lite/python/convert.py  19
1 file changed, 1 insertion(+), 18 deletions(-)
diff --git a/tensorflow/contrib/lite/python/convert.py b/tensorflow/contrib/lite/python/convert.py
index c038c88945..b0c7614ad2 100644
--- a/tensorflow/contrib/lite/python/convert.py
+++ b/tensorflow/contrib/lite/python/convert.py
@@ -25,7 +25,6 @@ import tempfile as _tempfile
from tensorflow.contrib.lite.python import lite_constants
from tensorflow.contrib.lite.toco import model_flags_pb2 as _model_flags_pb2
from tensorflow.contrib.lite.toco import toco_flags_pb2 as _toco_flags_pb2
-from tensorflow.python.framework import dtypes as _dtypes
from tensorflow.python.platform import resource_loader as _resource_loader
from tensorflow.python.util.lazy_loader import LazyLoader
@@ -202,29 +201,13 @@ def build_toco_convert_protos(input_tensors,
if dump_graphviz_dir:
toco.dump_graphviz_dir = dump_graphviz_dir
toco.dump_graphviz_include_video = dump_graphviz_video
+
model = _model_flags_pb2.ModelFlags()
model.change_concat_input_ranges = change_concat_input_ranges
for idx, input_tensor in enumerate(input_tensors):
- if input_tensor.dtype == _dtypes.float32:
- tflite_input_type = lite_constants.FLOAT
- elif input_tensor.dtype == _dtypes.int32:
- tflite_input_type = lite_constants.INT32
- elif input_tensor.dtype == _dtypes.int64:
- tflite_input_type = lite_constants.INT64
- elif input_tensor.dtype == _dtypes.uint8:
- tflite_input_type = lite_constants.QUANTIZED_UINT8
- # TODO(aselle): Insert strings when they are available
- else:
- raise ValueError("Tensors %s not known type %r" % (input_tensor.name,
- input_tensor.dtype))
-
input_array = model.input_arrays.add()
-
if inference_type == lite_constants.QUANTIZED_UINT8:
- if tflite_input_type == lite_constants.FLOAT:
- tflite_input_type = lite_constants.QUANTIZED_UINT8
input_array.mean_value, input_array.std_value = quantized_input_stats[idx]
-
input_array.name = tensor_name(input_tensor)
input_array.shape.dims.extend(map(int, input_tensor.get_shape()))