diff options
author | Nupur Garg <nupurgarg@google.com> | 2018-09-13 13:08:26 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-13 13:16:56 -0700 |
commit | 2646bf2d2bfb717c828db6391563b431f760a7d3 (patch) | |
tree | e462fa3d8e43e6e2aea55f0b188ce393b2105d14 | |
parent | cdc7f0fbce230b5eef30b6f0049495af3983aea0 (diff) |
Internal change.
PiperOrigin-RevId: 212864677
-rw-r--r-- | tensorflow/contrib/lite/python/convert.py | 43 | ||||
-rw-r--r-- | tensorflow/contrib/lite/python/lite.py | 11 | ||||
-rw-r--r-- | tensorflow/contrib/lite/python/lite_test.py | 22 | ||||
-rw-r--r-- | tensorflow/contrib/lite/python/tflite_convert.py | 11 |
4 files changed, 82 insertions, 5 deletions
diff --git a/tensorflow/contrib/lite/python/convert.py b/tensorflow/contrib/lite/python/convert.py index 1c5516ae7c..1f48a826d4 100644 --- a/tensorflow/contrib/lite/python/convert.py +++ b/tensorflow/contrib/lite/python/convert.py @@ -18,6 +18,8 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import enum # pylint: disable=g-bad-import-order + import os as _os import platform as _platform import subprocess as _subprocess @@ -30,7 +32,6 @@ from tensorflow.python.platform import resource_loader as _resource_loader from tensorflow.python.util import deprecation from tensorflow.python.util.lazy_loader import LazyLoader - # Lazy load since some of the performance benchmark skylark rules # break dependencies. _toco_python = LazyLoader( @@ -52,6 +53,31 @@ if _toco_from_proto_bin and not _os.path.exists(_toco_from_proto_bin): _toco_from_proto_bin = "toco_from_protos" +class ConverterMode(enum.Enum): + """Enum class defining the converters available to generate TFLite models. + + WARNING: Experimental interface, subject to change. + """ + # Convert model using TOCO such that all ops are TensorFlow Lite native ops. + # + # This is the only supported mode for any models that contain operations that + # cannot be resolved in TensorFlow. + DEFAULT = "DEFAULT" + + # Convert model using TOCO such that only unsupported operations are + # represented as TensorFlow ops. + # WARNING: Experimental interface, subject to change. + TOCO_EXTENDED = "TOCO_EXTENDED" + + # Convert model using TOCO such that all operations are represented as + # TensorFlow ops. + # WARNING: Experimental interface, subject to change. + TOCO_EXTENDED_ALL = "TOCO_EXTENDED_ALL" + + def __str__(self): + return self.value + + def toco_convert_protos(model_flags_str, toco_flags_str, input_data_str): """Convert `input_data_str` according to model and toco parameters. @@ -128,7 +154,8 @@ def build_toco_convert_protos(input_tensors, change_concat_input_ranges=False, post_training_quantize=False, dump_graphviz_dir=None, - dump_graphviz_video=False): + dump_graphviz_video=False, + converter_mode=ConverterMode.DEFAULT): """Builds protocol buffers describing a conversion of a model using TOCO. Typically this is to convert from TensorFlow GraphDef to TFLite, in which @@ -183,6 +210,8 @@ def build_toco_convert_protos(input_tensors, output file. (default None) dump_graphviz_video: Boolean indicating whether to dump the graph after every graph transformation. (default False) + converter_mode: Experimental flag, subject to change. ConverterMode + indicating which converter to use. (default ConverterMode.DEFAULT) Returns: model_flags, toco_flags: two protocol buffers describing the conversion @@ -211,6 +240,11 @@ def build_toco_convert_protos(input_tensors, if dump_graphviz_dir: toco.dump_graphviz_dir = dump_graphviz_dir toco.dump_graphviz_include_video = dump_graphviz_video + if converter_mode == ConverterMode.TOCO_EXTENDED: + toco.allow_eager_ops = True + elif converter_mode == ConverterMode.TOCO_EXTENDED_ALL: + toco.allow_eager_ops = True + toco.force_eager_ops = True model = _model_flags_pb2.ModelFlags() model.change_concat_input_ranges = change_concat_input_ranges @@ -301,9 +335,8 @@ def toco_convert_impl(input_data, input_tensors, output_tensors, *args, Raises: Defined in `build_toco_convert_protos`. """ - model_flags, toco_flags = build_toco_convert_protos(input_tensors, - output_tensors, - *args, **kwargs) + model_flags, toco_flags = build_toco_convert_protos( + input_tensors, output_tensors, *args, **kwargs) data = toco_convert_protos(model_flags.SerializeToString(), toco_flags.SerializeToString(), input_data.SerializeToString()) diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py index 44dfb97b84..2be24455d8 100644 --- a/tensorflow/contrib/lite/python/lite.py +++ b/tensorflow/contrib/lite/python/lite.py @@ -40,6 +40,7 @@ from google.protobuf import text_format as _text_format from google.protobuf.message import DecodeError from tensorflow.contrib.lite.python import lite_constants as constants from tensorflow.contrib.lite.python.convert import build_toco_convert_protos # pylint: disable=unused-import +from tensorflow.contrib.lite.python.convert import ConverterMode from tensorflow.contrib.lite.python.convert import tensor_name as _tensor_name from tensorflow.contrib.lite.python.convert import toco_convert # pylint: disable=unused-import from tensorflow.contrib.lite.python.convert import toco_convert_graph_def as _toco_convert_graph_def @@ -113,6 +114,8 @@ class TocoConverter(object): output file. (default None) dump_graphviz_video: Boolean indicating whether to dump the graph after every graph transformation. (default False) + converter_mode: Experimental flag, subject to change. ConverterMode + indicating which converter to use. (default ConverterMode.DEFAULT) Example usage: @@ -179,6 +182,7 @@ class TocoConverter(object): self.post_training_quantize = False self.dump_graphviz_dir = None self.dump_graphviz_video = False + self.converter_mode = ConverterMode.DEFAULT # Attributes are used by models that cannot be loaded into TensorFlow. if not self._has_valid_tensors(): @@ -389,6 +393,7 @@ class TocoConverter(object): ValueError: Input shape is not specified. None value for dimension in input_tensor. + ConverterMode option is unsupported for the model. """ # Checks dimensions in input tensor. if self._has_valid_tensors(): @@ -439,12 +444,18 @@ class TocoConverter(object): # Converts model. if self._has_valid_tensors(): + converter_kwargs["converter_mode"] = self.converter_mode result = _toco_convert_impl( input_data=self._graph_def, input_tensors=self._input_tensors, output_tensors=self._output_tensors, **converter_kwargs) else: + # Graphs without valid tensors cannot be loaded into tf.Session since they + # contain TFLite operation(s) that cannot be resolved in TensorFlow. + if self.converter_mode != ConverterMode.DEFAULT: + raise ValueError("This model can only be converted with the default " + "converter.") result = _toco_convert_graph_def( input_data=self._graph_def, input_arrays_with_shape=self._input_arrays_with_shape, diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py index 3f8ea433ff..f112ed5cdd 100644 --- a/tensorflow/contrib/lite/python/lite_test.py +++ b/tensorflow/contrib/lite/python/lite_test.py @@ -402,6 +402,28 @@ class FromSessionTest(test_util.TensorFlowTestCase): # Ensure that the quantized weights tflite model is smaller. self.assertTrue(len(quantized_tflite) < len(float_tflite)) + def testExtendedMode(self): + in_tensor = array_ops.placeholder( + shape=[1, 16, 16, 3], dtype=dtypes.float32) + out_tensor = in_tensor + in_tensor + sess = session.Session() + + # Convert model and ensure model is not None. + converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor]) + converter.converter_mode = lite.ConverterMode.TOCO_EXTENDED_ALL + tflite_model = converter.convert() + self.assertTrue(tflite_model) + + # Ensures the model contains TensorFlow ops. + # TODO(nupurgarg): Check values once there is a Python delegate interface. + interpreter = Interpreter(model_content=tflite_model) + with self.assertRaises(RuntimeError) as error: + interpreter.allocate_tensors() + self.assertIn( + 'Regular TensorFlow ops are not supported by this interpreter. Make ' + 'sure you invoke the Eager delegate before inference.', + str(error.exception)) + class FromFrozenGraphFile(test_util.TensorFlowTestCase): diff --git a/tensorflow/contrib/lite/python/tflite_convert.py b/tensorflow/contrib/lite/python/tflite_convert.py index cc08ed3fe9..c0ff7f37f9 100644 --- a/tensorflow/contrib/lite/python/tflite_convert.py +++ b/tensorflow/contrib/lite/python/tflite_convert.py @@ -140,8 +140,11 @@ def _convert_model(flags): if flags.change_concat_input_ranges: converter.change_concat_input_ranges = ( flags.change_concat_input_ranges == "TRUE") + if flags.allow_custom_ops: converter.allow_custom_ops = flags.allow_custom_ops + if flags.converter_mode: + converter.converter_mode = flags.converter_mode if flags.post_training_quantize: converter.post_training_quantize = flags.post_training_quantize @@ -363,6 +366,8 @@ def run_main(_): help=("Boolean to change behavior of min/max ranges for inputs and " "outputs of the concat operator for quantized models. Changes the " "ranges of concat operator overlap when true. (default False)")) + + # Permitted ops flags. parser.add_argument( "--allow_custom_ops", action="store_true", @@ -371,6 +376,12 @@ def run_main(_): "created for any op that is unknown. The developer will need to " "provide these to the TensorFlow Lite runtime with a custom " "resolver. (default False)")) + parser.add_argument( + "--converter_mode", + type=lite.ConverterMode, + choices=list(lite.ConverterMode), + help=("Experimental flag, subject to change. ConverterMode indicating " + "which converter to use. (default ConverterMode.DEFAULT)")) # Logging flags. parser.add_argument( |