aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/python/lite.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/python/lite.py')
-rw-r--r--tensorflow/contrib/lite/python/lite.py11
1 files changed, 11 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py
index 44dfb97b84..2be24455d8 100644
--- a/tensorflow/contrib/lite/python/lite.py
+++ b/tensorflow/contrib/lite/python/lite.py
@@ -40,6 +40,7 @@ from google.protobuf import text_format as _text_format
from google.protobuf.message import DecodeError
from tensorflow.contrib.lite.python import lite_constants as constants
from tensorflow.contrib.lite.python.convert import build_toco_convert_protos # pylint: disable=unused-import
+from tensorflow.contrib.lite.python.convert import ConverterMode
from tensorflow.contrib.lite.python.convert import tensor_name as _tensor_name
from tensorflow.contrib.lite.python.convert import toco_convert # pylint: disable=unused-import
from tensorflow.contrib.lite.python.convert import toco_convert_graph_def as _toco_convert_graph_def
@@ -113,6 +114,8 @@ class TocoConverter(object):
output file. (default None)
dump_graphviz_video: Boolean indicating whether to dump the graph after
every graph transformation. (default False)
+ converter_mode: Experimental flag, subject to change. ConverterMode
+ indicating which converter to use. (default ConverterMode.DEFAULT)
Example usage:
@@ -179,6 +182,7 @@ class TocoConverter(object):
self.post_training_quantize = False
self.dump_graphviz_dir = None
self.dump_graphviz_video = False
+ self.converter_mode = ConverterMode.DEFAULT
# Attributes are used by models that cannot be loaded into TensorFlow.
if not self._has_valid_tensors():
@@ -389,6 +393,7 @@ class TocoConverter(object):
ValueError:
Input shape is not specified.
None value for dimension in input_tensor.
+ ConverterMode option is unsupported for the model.
"""
# Checks dimensions in input tensor.
if self._has_valid_tensors():
@@ -439,12 +444,18 @@ class TocoConverter(object):
# Converts model.
if self._has_valid_tensors():
+ converter_kwargs["converter_mode"] = self.converter_mode
result = _toco_convert_impl(
input_data=self._graph_def,
input_tensors=self._input_tensors,
output_tensors=self._output_tensors,
**converter_kwargs)
else:
+ # Graphs without valid tensors cannot be loaded into tf.Session since they
+ # contain TFLite operation(s) that cannot be resolved in TensorFlow.
+ if self.converter_mode != ConverterMode.DEFAULT:
+ raise ValueError("This model can only be converted with the default "
+ "converter.")
result = _toco_convert_graph_def(
input_data=self._graph_def,
input_arrays_with_shape=self._input_arrays_with_shape,