about summary refs log tree commit diff homepage
diff options
context:
space:
mode:
author: Nupur Garg <nupurgarg@google.com> 2018-09-13 13:08:26 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-09-13 13:16:56 -0700
commit 2646bf2d2bfb717c828db6391563b431f760a7d3 (patch)
tree e462fa3d8e43e6e2aea55f0b188ce393b2105d14
parent cdc7f0fbce230b5eef30b6f0049495af3983aea0 (diff)
Internal change.
PiperOrigin-RevId: 212864677
-rw-r--r--  tensorflow/contrib/lite/python/convert.py         43
-rw-r--r--  tensorflow/contrib/lite/python/lite.py            11
-rw-r--r--  tensorflow/contrib/lite/python/lite_test.py       22
-rw-r--r--  tensorflow/contrib/lite/python/tflite_convert.py  11
4 files changed, 82 insertions, 5 deletions
diff --git a/tensorflow/contrib/lite/python/convert.py b/tensorflow/contrib/lite/python/convert.py
index 1c5516ae7c..1f48a826d4 100644
--- a/tensorflow/contrib/lite/python/convert.py
+++ b/tensorflow/contrib/lite/python/convert.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import enum # pylint: disable=g-bad-import-order
+
import os as _os
import platform as _platform
import subprocess as _subprocess
@@ -30,7 +32,6 @@ from tensorflow.python.platform import resource_loader as _resource_loader
from tensorflow.python.util import deprecation
from tensorflow.python.util.lazy_loader import LazyLoader
-
# Lazy load since some of the performance benchmark skylark rules
# break dependencies.
_toco_python = LazyLoader(
@@ -52,6 +53,31 @@ if _toco_from_proto_bin and not _os.path.exists(_toco_from_proto_bin):
_toco_from_proto_bin = "toco_from_protos"
+class ConverterMode(enum.Enum):
+ """Enum class defining the converters available to generate TFLite models.
+
+ WARNING: Experimental interface, subject to change.
+ """
+ # Convert model using TOCO such that all ops are TensorFlow Lite native ops.
+ #
+ # This is the only supported mode for any models that contain operations that
+ # cannot be resolved in TensorFlow.
+ DEFAULT = "DEFAULT"
+
+ # Convert model using TOCO such that only unsupported operations are
+ # represented as TensorFlow ops.
+ # WARNING: Experimental interface, subject to change.
+ TOCO_EXTENDED = "TOCO_EXTENDED"
+
+ # Convert model using TOCO such that all operations are represented as
+ # TensorFlow ops.
+ # WARNING: Experimental interface, subject to change.
+ TOCO_EXTENDED_ALL = "TOCO_EXTENDED_ALL"
+
+ def __str__(self):
+ return self.value
+
+
def toco_convert_protos(model_flags_str, toco_flags_str, input_data_str):
"""Convert `input_data_str` according to model and toco parameters.
@@ -128,7 +154,8 @@ def build_toco_convert_protos(input_tensors,
change_concat_input_ranges=False,
post_training_quantize=False,
dump_graphviz_dir=None,
- dump_graphviz_video=False):
+ dump_graphviz_video=False,
+ converter_mode=ConverterMode.DEFAULT):
"""Builds protocol buffers describing a conversion of a model using TOCO.
Typically this is to convert from TensorFlow GraphDef to TFLite, in which
@@ -183,6 +210,8 @@ def build_toco_convert_protos(input_tensors,
output file. (default None)
dump_graphviz_video: Boolean indicating whether to dump the graph after
every graph transformation. (default False)
+ converter_mode: Experimental flag, subject to change. ConverterMode
+ indicating which converter to use. (default ConverterMode.DEFAULT)
Returns:
model_flags, toco_flags: two protocol buffers describing the conversion
@@ -211,6 +240,11 @@ def build_toco_convert_protos(input_tensors,
if dump_graphviz_dir:
toco.dump_graphviz_dir = dump_graphviz_dir
toco.dump_graphviz_include_video = dump_graphviz_video
+ if converter_mode == ConverterMode.TOCO_EXTENDED:
+ toco.allow_eager_ops = True
+ elif converter_mode == ConverterMode.TOCO_EXTENDED_ALL:
+ toco.allow_eager_ops = True
+ toco.force_eager_ops = True
model = _model_flags_pb2.ModelFlags()
model.change_concat_input_ranges = change_concat_input_ranges
@@ -301,9 +335,8 @@ def toco_convert_impl(input_data, input_tensors, output_tensors, *args,
Raises:
Defined in `build_toco_convert_protos`.
"""
- model_flags, toco_flags = build_toco_convert_protos(input_tensors,
- output_tensors,
- *args, **kwargs)
+ model_flags, toco_flags = build_toco_convert_protos(
+ input_tensors, output_tensors, *args, **kwargs)
data = toco_convert_protos(model_flags.SerializeToString(),
toco_flags.SerializeToString(),
input_data.SerializeToString())
diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py
index 44dfb97b84..2be24455d8 100644
--- a/tensorflow/contrib/lite/python/lite.py
+++ b/tensorflow/contrib/lite/python/lite.py
@@ -40,6 +40,7 @@ from google.protobuf import text_format as _text_format
from google.protobuf.message import DecodeError
from tensorflow.contrib.lite.python import lite_constants as constants
from tensorflow.contrib.lite.python.convert import build_toco_convert_protos # pylint: disable=unused-import
+from tensorflow.contrib.lite.python.convert import ConverterMode
from tensorflow.contrib.lite.python.convert import tensor_name as _tensor_name
from tensorflow.contrib.lite.python.convert import toco_convert # pylint: disable=unused-import
from tensorflow.contrib.lite.python.convert import toco_convert_graph_def as _toco_convert_graph_def
@@ -113,6 +114,8 @@ class TocoConverter(object):
output file. (default None)
dump_graphviz_video: Boolean indicating whether to dump the graph after
every graph transformation. (default False)
+ converter_mode: Experimental flag, subject to change. ConverterMode
+ indicating which converter to use. (default ConverterMode.DEFAULT)
Example usage:
@@ -179,6 +182,7 @@ class TocoConverter(object):
self.post_training_quantize = False
self.dump_graphviz_dir = None
self.dump_graphviz_video = False
+ self.converter_mode = ConverterMode.DEFAULT
# Attributes are used by models that cannot be loaded into TensorFlow.
if not self._has_valid_tensors():
@@ -389,6 +393,7 @@ class TocoConverter(object):
ValueError:
Input shape is not specified.
None value for dimension in input_tensor.
+ ConverterMode option is unsupported for the model.
"""
# Checks dimensions in input tensor.
if self._has_valid_tensors():
@@ -439,12 +444,18 @@ class TocoConverter(object):
# Converts model.
if self._has_valid_tensors():
+ converter_kwargs["converter_mode"] = self.converter_mode
result = _toco_convert_impl(
input_data=self._graph_def,
input_tensors=self._input_tensors,
output_tensors=self._output_tensors,
**converter_kwargs)
else:
+ # Graphs without valid tensors cannot be loaded into tf.Session since they
+ # contain TFLite operation(s) that cannot be resolved in TensorFlow.
+ if self.converter_mode != ConverterMode.DEFAULT:
+ raise ValueError("This model can only be converted with the default "
+ "converter.")
result = _toco_convert_graph_def(
input_data=self._graph_def,
input_arrays_with_shape=self._input_arrays_with_shape,
diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py
index 3f8ea433ff..f112ed5cdd 100644
--- a/tensorflow/contrib/lite/python/lite_test.py
+++ b/tensorflow/contrib/lite/python/lite_test.py
@@ -402,6 +402,28 @@ class FromSessionTest(test_util.TensorFlowTestCase):
# Ensure that the quantized weights tflite model is smaller.
self.assertTrue(len(quantized_tflite) < len(float_tflite))
+ def testExtendedMode(self):
+ in_tensor = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32)
+ out_tensor = in_tensor + in_tensor
+ sess = session.Session()
+
+ # Convert model and ensure model is not None.
+ converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ converter.converter_mode = lite.ConverterMode.TOCO_EXTENDED_ALL
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ # Ensures the model contains TensorFlow ops.
+ # TODO(nupurgarg): Check values once there is a Python delegate interface.
+ interpreter = Interpreter(model_content=tflite_model)
+ with self.assertRaises(RuntimeError) as error:
+ interpreter.allocate_tensors()
+ self.assertIn(
+ 'Regular TensorFlow ops are not supported by this interpreter. Make '
+ 'sure you invoke the Eager delegate before inference.',
+ str(error.exception))
+
class FromFrozenGraphFile(test_util.TensorFlowTestCase):
diff --git a/tensorflow/contrib/lite/python/tflite_convert.py b/tensorflow/contrib/lite/python/tflite_convert.py
index cc08ed3fe9..c0ff7f37f9 100644
--- a/tensorflow/contrib/lite/python/tflite_convert.py
+++ b/tensorflow/contrib/lite/python/tflite_convert.py
@@ -140,8 +140,11 @@ def _convert_model(flags):
if flags.change_concat_input_ranges:
converter.change_concat_input_ranges = (
flags.change_concat_input_ranges == "TRUE")
+
if flags.allow_custom_ops:
converter.allow_custom_ops = flags.allow_custom_ops
+ if flags.converter_mode:
+ converter.converter_mode = flags.converter_mode
if flags.post_training_quantize:
converter.post_training_quantize = flags.post_training_quantize
@@ -363,6 +366,8 @@ def run_main(_):
help=("Boolean to change behavior of min/max ranges for inputs and "
"outputs of the concat operator for quantized models. Changes the "
"ranges of concat operator overlap when true. (default False)"))
+
+ # Permitted ops flags.
parser.add_argument(
"--allow_custom_ops",
action="store_true",
@@ -371,6 +376,12 @@ def run_main(_):
"created for any op that is unknown. The developer will need to "
"provide these to the TensorFlow Lite runtime with a custom "
"resolver. (default False)"))
+ parser.add_argument(
+ "--converter_mode",
+ type=lite.ConverterMode,
+ choices=list(lite.ConverterMode),
+ help=("Experimental flag, subject to change. ConverterMode indicating "
+ "which converter to use. (default ConverterMode.DEFAULT)"))
# Logging flags.
parser.add_argument(