diff options
author | Nupur Garg <nupurgarg@google.com> | 2018-06-22 15:07:08 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-06-22 15:10:19 -0700 |
commit | d3931c804ce9619ac0a0c84f42b43ce70ade93a7 (patch) | |
tree | 7661d71fce9330eb5cd268f787c31105f0b1cb83 | |
parent | fdea22dbcc1a22f1dc76e3a2172591222ca165d4 (diff) |
Updated toco_converter API.
PiperOrigin-RevId: 201748091
-rw-r--r-- | tensorflow/contrib/lite/python/convert.py | 19 |
1 files changed, 1 insertions, 18 deletions
diff --git a/tensorflow/contrib/lite/python/convert.py b/tensorflow/contrib/lite/python/convert.py index c038c88945..b0c7614ad2 100644 --- a/tensorflow/contrib/lite/python/convert.py +++ b/tensorflow/contrib/lite/python/convert.py @@ -25,7 +25,6 @@ import tempfile as _tempfile from tensorflow.contrib.lite.python import lite_constants from tensorflow.contrib.lite.toco import model_flags_pb2 as _model_flags_pb2 from tensorflow.contrib.lite.toco import toco_flags_pb2 as _toco_flags_pb2 -from tensorflow.python.framework import dtypes as _dtypes from tensorflow.python.platform import resource_loader as _resource_loader from tensorflow.python.util.lazy_loader import LazyLoader @@ -202,29 +201,13 @@ def build_toco_convert_protos(input_tensors, if dump_graphviz_dir: toco.dump_graphviz_dir = dump_graphviz_dir toco.dump_graphviz_include_video = dump_graphviz_video + model = _model_flags_pb2.ModelFlags() model.change_concat_input_ranges = change_concat_input_ranges for idx, input_tensor in enumerate(input_tensors): - if input_tensor.dtype == _dtypes.float32: - tflite_input_type = lite_constants.FLOAT - elif input_tensor.dtype == _dtypes.int32: - tflite_input_type = lite_constants.INT32 - elif input_tensor.dtype == _dtypes.int64: - tflite_input_type = lite_constants.INT64 - elif input_tensor.dtype == _dtypes.uint8: - tflite_input_type = lite_constants.QUANTIZED_UINT8 - # TODO(aselle): Insert strings when they are available - else: - raise ValueError("Tensors %s not known type %r" % (input_tensor.name, - input_tensor.dtype)) - input_array = model.input_arrays.add() - if inference_type == lite_constants.QUANTIZED_UINT8: - if tflite_input_type == lite_constants.FLOAT: - tflite_input_type = lite_constants.QUANTIZED_UINT8 input_array.mean_value, input_array.std_value = quantized_input_stats[idx] - input_array.name = tensor_name(input_tensor) input_array.shape.dims.extend(map(int, input_tensor.get_shape())) |