aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/python
diff options
context:
space:
mode:
authorGravatar Nupur Garg <nupurgarg@google.com>2018-08-20 15:22:18 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-20 15:29:11 -0700
commitfe747c18093737065f482c15f14d95981fd55ef4 (patch)
treec5cc25358360c6fa39a7663bfed36c950d2b234d /tensorflow/contrib/lite/python
parent2e5887f7dc4855167bd833738be4bbfc36d6328e (diff)
Minor fixes to TocoConverter.
PiperOrigin-RevId: 209494423
Diffstat (limited to 'tensorflow/contrib/lite/python')
-rw-r--r--tensorflow/contrib/lite/python/convert.py31
-rw-r--r--tensorflow/contrib/lite/python/lite.py11
-rw-r--r--tensorflow/contrib/lite/python/tflite_convert.py6
3 files changed, 45 insertions, 3 deletions
diff --git a/tensorflow/contrib/lite/python/convert.py b/tensorflow/contrib/lite/python/convert.py
index 11d4bdbe82..7378fcfe10 100644
--- a/tensorflow/contrib/lite/python/convert.py
+++ b/tensorflow/contrib/lite/python/convert.py
@@ -26,6 +26,7 @@ from tensorflow.contrib.lite.python import lite_constants
from tensorflow.contrib.lite.toco import model_flags_pb2 as _model_flags_pb2
from tensorflow.contrib.lite.toco import toco_flags_pb2 as _toco_flags_pb2
from tensorflow.python.platform import resource_loader as _resource_loader
+from tensorflow.python.util import deprecation
from tensorflow.python.util.lazy_loader import LazyLoader
@@ -223,7 +224,8 @@ def build_toco_convert_protos(input_tensors,
return model, toco
-def toco_convert(input_data, input_tensors, output_tensors, *args, **kwargs):
+def toco_convert_impl(input_data, input_tensors, output_tensors, *args,
+ **kwargs):
""""Convert a model using TOCO.
Typically this function is used to convert from TensorFlow GraphDef to TFLite.
@@ -252,3 +254,30 @@ def toco_convert(input_data, input_tensors, output_tensors, *args, **kwargs):
toco_flags.SerializeToString(),
input_data.SerializeToString())
return data
+
+
+@deprecation.deprecated(None, "Use `lite.TocoConverter` instead.")
+def toco_convert(input_data, input_tensors, output_tensors, *args, **kwargs):
+ """"Convert a model using TOCO.
+
+ Typically this function is used to convert from TensorFlow GraphDef to TFLite.
+ Conversion can be customized by providing arguments that are forwarded to
+ `build_toco_convert_protos` (see documentation for details).
+
+ Args:
+ input_data: Input data (i.e. often `sess.graph_def`),
+ input_tensors: List of input tensors. Type and shape are computed using
+ `foo.get_shape()` and `foo.dtype`.
+ output_tensors: List of output tensors (only .name is used from this).
+ *args: See `build_toco_convert_protos`,
+ **kwargs: See `build_toco_convert_protos`.
+
+ Returns:
+ The converted data. For example if TFLite was the destination, then
+ this will be a tflite flatbuffer in a bytes array.
+
+ Raises:
+ Defined in `build_toco_convert_protos`.
+ """
+ return toco_convert_impl(input_data, input_tensors, output_tensors, *args,
+ **kwargs)
diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py
index 5ec52035ad..2313bfa3b6 100644
--- a/tensorflow/contrib/lite/python/lite.py
+++ b/tensorflow/contrib/lite/python/lite.py
@@ -41,7 +41,8 @@ from google.protobuf.message import DecodeError
from tensorflow.contrib.lite.python import lite_constants as constants
from tensorflow.contrib.lite.python.convert import build_toco_convert_protos # pylint: disable=unused-import
from tensorflow.contrib.lite.python.convert import tensor_name as _tensor_name
-from tensorflow.contrib.lite.python.convert import toco_convert
+from tensorflow.contrib.lite.python.convert import toco_convert # pylint: disable=unused-import
+from tensorflow.contrib.lite.python.convert import toco_convert_impl as _toco_convert_impl
from tensorflow.contrib.lite.python.convert import toco_convert_protos # pylint: disable=unused-import
from tensorflow.contrib.lite.python.convert_saved_model import freeze_saved_model as _freeze_saved_model
from tensorflow.contrib.lite.python.convert_saved_model import get_tensors_from_tensor_names as _get_tensors_from_tensor_names
@@ -110,6 +111,7 @@ class TocoConverter(object):
Example usage:
+ ```python
# Converting a GraphDef from session.
converter = lite.TocoConverter.from_session(sess, in_tensors, out_tensors)
tflite_model = converter.convert()
@@ -124,6 +126,11 @@ class TocoConverter(object):
# Converting a SavedModel.
converter = lite.TocoConverter.from_saved_model(saved_model_dir)
tflite_model = converter.convert()
+
+ # Converting a tf.keras model.
+ converter = lite.TocoConverter.from_keras_model_file(keras_model)
+ tflite_model = converter.convert()
+ ```
"""
def __init__(self, graph_def, input_tensors, output_tensors):
@@ -354,7 +361,7 @@ class TocoConverter(object):
quantized_stats = None
# Converts model.
- result = toco_convert(
+ result = _toco_convert_impl(
input_data=self._graph_def,
input_tensors=self._input_tensors,
output_tensors=self._output_tensors,
diff --git a/tensorflow/contrib/lite/python/tflite_convert.py b/tensorflow/contrib/lite/python/tflite_convert.py
index a76cc39635..7d7a4ba94a 100644
--- a/tensorflow/contrib/lite/python/tflite_convert.py
+++ b/tensorflow/contrib/lite/python/tflite_convert.py
@@ -47,6 +47,9 @@ def _get_toco_converter(flags):
Returns:
TocoConverter object.
+
+ Raises:
+ ValueError: Invalid flags.
"""
# Parse input and output arrays.
input_arrays = _parse_array(flags.input_arrays)
@@ -77,6 +80,9 @@ def _get_toco_converter(flags):
elif flags.keras_model_file:
converter_fn = lite.TocoConverter.from_keras_model_file
converter_kwargs["model_file"] = flags.keras_model_file
+ else:
+ raise ValueError("--graph_def_file, --saved_model_dir, or "
+ "--keras_model_file must be specified.")
return converter_fn(**converter_kwargs)