diff options
author | 2018-08-20 15:22:18 -0700 | |
---|---|---|
committer | 2018-08-20 15:29:11 -0700 | |
commit | fe747c18093737065f482c15f14d95981fd55ef4 (patch) | |
tree | c5cc25358360c6fa39a7663bfed36c950d2b234d /tensorflow/contrib/lite/python | |
parent | 2e5887f7dc4855167bd833738be4bbfc36d6328e (diff) |
Minor fixes to TocoConverter.
PiperOrigin-RevId: 209494423
Diffstat (limited to 'tensorflow/contrib/lite/python')
-rw-r--r-- | tensorflow/contrib/lite/python/convert.py | 31 | ||||
-rw-r--r-- | tensorflow/contrib/lite/python/lite.py | 11 | ||||
-rw-r--r-- | tensorflow/contrib/lite/python/tflite_convert.py | 6 |
3 files changed, 45 insertions, 3 deletions
diff --git a/tensorflow/contrib/lite/python/convert.py b/tensorflow/contrib/lite/python/convert.py index 11d4bdbe82..7378fcfe10 100644 --- a/tensorflow/contrib/lite/python/convert.py +++ b/tensorflow/contrib/lite/python/convert.py @@ -26,6 +26,7 @@ from tensorflow.contrib.lite.python import lite_constants from tensorflow.contrib.lite.toco import model_flags_pb2 as _model_flags_pb2 from tensorflow.contrib.lite.toco import toco_flags_pb2 as _toco_flags_pb2 from tensorflow.python.platform import resource_loader as _resource_loader +from tensorflow.python.util import deprecation from tensorflow.python.util.lazy_loader import LazyLoader @@ -223,7 +224,8 @@ def build_toco_convert_protos(input_tensors, return model, toco -def toco_convert(input_data, input_tensors, output_tensors, *args, **kwargs): +def toco_convert_impl(input_data, input_tensors, output_tensors, *args, + **kwargs): """"Convert a model using TOCO. Typically this function is used to convert from TensorFlow GraphDef to TFLite. @@ -252,3 +254,30 @@ def toco_convert(input_data, input_tensors, output_tensors, *args, **kwargs): toco_flags.SerializeToString(), input_data.SerializeToString()) return data + + +@deprecation.deprecated(None, "Use `lite.TocoConverter` instead.") +def toco_convert(input_data, input_tensors, output_tensors, *args, **kwargs): + """"Convert a model using TOCO. + + Typically this function is used to convert from TensorFlow GraphDef to TFLite. + Conversion can be customized by providing arguments that are forwarded to + `build_toco_convert_protos` (see documentation for details). + + Args: + input_data: Input data (i.e. often `sess.graph_def`), + input_tensors: List of input tensors. Type and shape are computed using + `foo.get_shape()` and `foo.dtype`. + output_tensors: List of output tensors (only .name is used from this). + *args: See `build_toco_convert_protos`, + **kwargs: See `build_toco_convert_protos`. + + Returns: + The converted data. For example if TFLite was the destination, then + this will be a tflite flatbuffer in a bytes array. + + Raises: + Defined in `build_toco_convert_protos`. + """ + return toco_convert_impl(input_data, input_tensors, output_tensors, *args, + **kwargs) diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py index 5ec52035ad..2313bfa3b6 100644 --- a/tensorflow/contrib/lite/python/lite.py +++ b/tensorflow/contrib/lite/python/lite.py @@ -41,7 +41,8 @@ from google.protobuf.message import DecodeError from tensorflow.contrib.lite.python import lite_constants as constants from tensorflow.contrib.lite.python.convert import build_toco_convert_protos # pylint: disable=unused-import from tensorflow.contrib.lite.python.convert import tensor_name as _tensor_name -from tensorflow.contrib.lite.python.convert import toco_convert +from tensorflow.contrib.lite.python.convert import toco_convert # pylint: disable=unused-import +from tensorflow.contrib.lite.python.convert import toco_convert_impl as _toco_convert_impl from tensorflow.contrib.lite.python.convert import toco_convert_protos # pylint: disable=unused-import from tensorflow.contrib.lite.python.convert_saved_model import freeze_saved_model as _freeze_saved_model from tensorflow.contrib.lite.python.convert_saved_model import get_tensors_from_tensor_names as _get_tensors_from_tensor_names @@ -110,6 +111,7 @@ class TocoConverter(object): Example usage: + ```python # Converting a GraphDef from session. converter = lite.TocoConverter.from_session(sess, in_tensors, out_tensors) tflite_model = converter.convert() @@ -124,6 +126,11 @@ class TocoConverter(object): # Converting a SavedModel. converter = lite.TocoConverter.from_saved_model(saved_model_dir) tflite_model = converter.convert() + + # Converting a tf.keras model. + converter = lite.TocoConverter.from_keras_model_file(keras_model) + tflite_model = converter.convert() + ``` """ def __init__(self, graph_def, input_tensors, output_tensors): @@ -354,7 +361,7 @@ class TocoConverter(object): quantized_stats = None # Converts model. - result = toco_convert( + result = _toco_convert_impl( input_data=self._graph_def, input_tensors=self._input_tensors, output_tensors=self._output_tensors, diff --git a/tensorflow/contrib/lite/python/tflite_convert.py b/tensorflow/contrib/lite/python/tflite_convert.py index a76cc39635..7d7a4ba94a 100644 --- a/tensorflow/contrib/lite/python/tflite_convert.py +++ b/tensorflow/contrib/lite/python/tflite_convert.py @@ -47,6 +47,9 @@ def _get_toco_converter(flags): Returns: TocoConverter object. + + Raises: + ValueError: Invalid flags. """ # Parse input and output arrays. input_arrays = _parse_array(flags.input_arrays) @@ -77,6 +80,9 @@ def _get_toco_converter(flags): elif flags.keras_model_file: converter_fn = lite.TocoConverter.from_keras_model_file converter_kwargs["model_file"] = flags.keras_model_file + else: + raise ValueError("--graph_def_file, --saved_model_dir, or " + "--keras_model_file must be specified.") return converter_fn(**converter_kwargs) |