aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/python
diff options
context:
space:
mode:
authorGravatar Nupur Garg <nupurgarg@google.com>2018-09-26 19:28:14 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-26 19:33:54 -0700
commit85258e06edf424492905fd032b02ff4d420b9da1 (patch)
treeba9e114333dbe20b5a68c188e4695d133b13c56a /tensorflow/contrib/lite/python
parent5b971c7eae5f2049a4725b16a4a44b688d3506b0 (diff)
Rename TocoConverter to TFLiteConverter.
PiperOrigin-RevId: 214710175
Diffstat (limited to 'tensorflow/contrib/lite/python')
-rw-r--r--tensorflow/contrib/lite/python/convert.py7
-rw-r--r--tensorflow/contrib/lite/python/lite.py94
-rw-r--r--tensorflow/contrib/lite/python/lite_test.py171
-rw-r--r--tensorflow/contrib/lite/python/tflite_convert.py12
4 files changed, 206 insertions, 78 deletions
diff --git a/tensorflow/contrib/lite/python/convert.py b/tensorflow/contrib/lite/python/convert.py
index 1f48a826d4..627be8f44f 100644
--- a/tensorflow/contrib/lite/python/convert.py
+++ b/tensorflow/contrib/lite/python/convert.py
@@ -343,13 +343,14 @@ def toco_convert_impl(input_data, input_tensors, output_tensors, *args,
return data
-@deprecation.deprecated(None, "Use `lite.TocoConverter` instead.")
+@deprecation.deprecated(None, "Use `lite.TFLiteConverter` instead.")
def toco_convert(input_data, input_tensors, output_tensors, *args, **kwargs):
- """"Convert a model using TOCO.
+ """Convert a model using TOCO.
Typically this function is used to convert from TensorFlow GraphDef to TFLite.
Conversion can be customized by providing arguments that are forwarded to
- `build_toco_convert_protos` (see documentation for details).
+ `build_toco_convert_protos` (see documentation for details). This function has
+ been deprecated. Please use `lite.TFLiteConverter` instead.
Args:
input_data: Input data (i.e. often `sess.graph_def`),
diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py
index 2be24455d8..09365f101f 100644
--- a/tensorflow/contrib/lite/python/lite.py
+++ b/tensorflow/contrib/lite/python/lite.py
@@ -17,6 +17,7 @@
EXPERIMENTAL: APIs here are unstable and likely to change without notice.
@@TocoConverter
+@@TFLiteConverter
@@toco_convert
@@toco_convert_protos
@@Interpreter
@@ -62,9 +63,10 @@ from tensorflow.python.framework.importer import import_graph_def as _import_gra
from tensorflow.python.lib.io import file_io as _file_io
from tensorflow.python.saved_model import signature_constants as _signature_constants
from tensorflow.python.saved_model import tag_constants as _tag_constants
+from tensorflow.python.util import deprecation as _deprecation
-class TocoConverter(object):
+class TFLiteConverter(object):
"""Convert a TensorFlow model into `output_format` using TOCO.
This is used to convert from a TensorFlow GraphDef or SavedModel into either a
@@ -121,22 +123,22 @@ class TocoConverter(object):
```python
# Converting a GraphDef from session.
- converter = lite.TocoConverter.from_session(sess, in_tensors, out_tensors)
+ converter = lite.TFLiteConverter.from_session(sess, in_tensors, out_tensors)
tflite_model = converter.convert()
open("converted_model.tflite", "wb").write(tflite_model)
# Converting a GraphDef from file.
- converter = lite.TocoConverter.from_frozen_graph(
+ converter = lite.TFLiteConverter.from_frozen_graph(
graph_def_file, input_arrays, output_arrays)
tflite_model = converter.convert()
open("converted_model.tflite", "wb").write(tflite_model)
# Converting a SavedModel.
- converter = lite.TocoConverter.from_saved_model(saved_model_dir)
+ converter = lite.TFLiteConverter.from_saved_model(saved_model_dir)
tflite_model = converter.convert()
# Converting a tf.keras model.
- converter = lite.TocoConverter.from_keras_model_file(keras_model)
+ converter = lite.TFLiteConverter.from_keras_model_file(keras_model)
tflite_model = converter.convert()
```
"""
@@ -147,10 +149,9 @@ class TocoConverter(object):
output_tensors,
input_arrays_with_shape=None,
output_arrays=None):
- """Constructor for TocoConverter.
+ """Constructor for TFLiteConverter.
Args:
-
graph_def: Frozen TensorFlow GraphDef.
input_tensors: List of input tensors. Type and shape are computed using
`foo.get_shape()` and `foo.dtype`.
@@ -158,8 +159,8 @@ class TocoConverter(object):
input_arrays_with_shape: Tuple of strings representing input tensor names
and list of integers representing input shapes
(e.g., [("foo" : [1, 16, 16, 3])]). Use only when graph cannot be loaded
- into TensorFlow and when `input_tensors` and `output_tensors` are None.
- (default None)
+ into TensorFlow and when `input_tensors` and `output_tensors` are
+ None. (default None)
output_arrays: List of output tensors to freeze graph with. Use only when
graph cannot be loaded into TensorFlow and when `input_tensors` and
`output_tensors` are None. (default None)
@@ -195,7 +196,7 @@ class TocoConverter(object):
@classmethod
def from_session(cls, sess, input_tensors, output_tensors):
- """Creates a TocoConverter class from a TensorFlow Session.
+ """Creates a TFLiteConverter class from a TensorFlow Session.
Args:
sess: TensorFlow Session.
@@ -204,7 +205,7 @@ class TocoConverter(object):
output_tensors: List of output tensors (only .name is used from this).
Returns:
- TocoConverter class.
+ TFLiteConverter class.
"""
graph_def = _freeze_graph(sess, output_tensors)
return cls(graph_def, input_tensors, output_tensors)
@@ -215,7 +216,7 @@ class TocoConverter(object):
input_arrays,
output_arrays,
input_shapes=None):
- """Creates a TocoConverter class from a file containing a frozen GraphDef.
+ """Creates a TFLiteConverter class from a file containing a frozen GraphDef.
Args:
graph_def_file: Full filepath of file containing frozen GraphDef.
@@ -224,10 +225,10 @@ class TocoConverter(object):
input_shapes: Dict of strings representing input tensor names to list of
integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
Automatically determined when input shapes is None (e.g., {"foo" :
- None}). (default None)
+ None}). (default None)
Returns:
- TocoConverter class.
+ TFLiteConverter class.
Raises:
IOError:
@@ -310,7 +311,7 @@ class TocoConverter(object):
output_arrays=None,
tag_set=None,
signature_key=None):
- """Creates a TocoConverter class from a SavedModel.
+ """Creates a TFLiteConverter class from a SavedModel.
Args:
saved_model_dir: SavedModel directory to convert.
@@ -319,7 +320,7 @@ class TocoConverter(object):
input_shapes: Dict of strings representing input tensor names to list of
integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
Automatically determined when input shapes is None (e.g., {"foo" :
- None}). (default None)
+ None}). (default None)
output_arrays: List of output tensors to freeze graph with. Uses output
arrays from SignatureDef when none are provided. (default None)
tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to
@@ -328,7 +329,7 @@ class TocoConverter(object):
(default DEFAULT_SERVING_SIGNATURE_DEF_KEY)
Returns:
- TocoConverter class.
+ TFLiteConverter class.
"""
if tag_set is None:
tag_set = set([_tag_constants.SERVING])
@@ -346,7 +347,7 @@ class TocoConverter(object):
input_arrays=None,
input_shapes=None,
output_arrays=None):
- """Creates a TocoConverter class from a tf.keras model file.
+ """Creates a TFLiteConverter class from a tf.keras model file.
Args:
model_file: Full filepath of HDF5 file containing the tf.keras model.
@@ -355,12 +356,12 @@ class TocoConverter(object):
input_shapes: Dict of strings representing input tensor names to list of
integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
Automatically determined when input shapes is None (e.g., {"foo" :
- None}). (default None)
+ None}). (default None)
output_arrays: List of output tensors to freeze graph with. Uses output
arrays from SignatureDef when none are provided. (default None)
Returns:
- TocoConverter class.
+ TFLiteConverter class.
"""
_keras.backend.clear_session()
_keras.backend.set_learning_phase(False)
@@ -502,6 +503,59 @@ class TocoConverter(object):
tensor.set_shape(shape)
+class TocoConverter(object):
+ """Convert a TensorFlow model into `output_format` using TOCO.
+
+ This class has been deprecated. Please use `lite.TFLiteConverter` instead.
+ """
+
+ @classmethod
+ @_deprecation.deprecated(None,
+ "Use `lite.TFLiteConverter.from_session` instead.")
+ def from_session(cls, sess, input_tensors, output_tensors):
+ """Creates a TocoConverter class from a TensorFlow Session."""
+ return TFLiteConverter.from_session(sess, input_tensors, output_tensors)
+
+ @classmethod
+ @_deprecation.deprecated(
+ None, "Use `lite.TFLiteConverter.from_frozen_graph` instead.")
+ def from_frozen_graph(cls,
+ graph_def_file,
+ input_arrays,
+ output_arrays,
+ input_shapes=None):
+ """Creates a TocoConverter class from a file containing a frozen graph."""
+ return TFLiteConverter.from_frozen_graph(graph_def_file, input_arrays,
+ output_arrays, input_shapes)
+
+ @classmethod
+ @_deprecation.deprecated(
+ None, "Use `lite.TFLiteConverter.from_saved_model` instead.")
+ def from_saved_model(cls,
+ saved_model_dir,
+ input_arrays=None,
+ input_shapes=None,
+ output_arrays=None,
+ tag_set=None,
+ signature_key=None):
+ """Creates a TocoConverter class from a SavedModel."""
+ return TFLiteConverter.from_saved_model(saved_model_dir, input_arrays,
+ input_shapes, output_arrays,
+ tag_set, signature_key)
+
+ @classmethod
+ @_deprecation.deprecated(
+ None, "Use `lite.TFLiteConverter.from_keras_model_file` instead.")
+ def from_keras_model_file(cls,
+ model_file,
+ input_arrays=None,
+ input_shapes=None,
+ output_arrays=None):
+ """Creates a TocoConverter class from a tf.keras model file."""
+ return TFLiteConverter.from_keras_model_file(model_file, input_arrays,
+ input_shapes, output_arrays)
+
+
def _is_frozen_graph(sess):
"""Determines if the graph is frozen.
diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py
index f112ed5cdd..33f8fc1e8c 100644
--- a/tensorflow/contrib/lite/python/lite_test.py
+++ b/tensorflow/contrib/lite/python/lite_test.py
@@ -50,18 +50,18 @@ class FromConstructor(test_util.TensorFlowTestCase):
# `output_arrays` is not defined.
with self.assertRaises(ValueError) as error:
- lite.TocoConverter(
+ lite.TFLiteConverter(
None, None, [], input_arrays_with_shape=[('input', [3, 9])])
self.assertEqual(message, str(error.exception))
# `input_arrays_with_shape` is not defined.
with self.assertRaises(ValueError) as error:
- lite.TocoConverter(None, [], None, output_arrays=['output'])
+ lite.TFLiteConverter(None, [], None, output_arrays=['output'])
self.assertEqual(message, str(error.exception))
# Tests valid constructors using a dummy value for the GraphDef.
def testValidConstructor(self):
- converter = lite.TocoConverter(
+ converter = lite.TFLiteConverter(
None,
None,
None,
@@ -76,7 +76,7 @@ class FromConstructor(test_util.TensorFlowTestCase):
'The batch size cannot be set for this model. Please use '
'input_shapes parameter.', str(error.exception))
- converter = lite.TocoConverter(None, ['input_tensor'], ['output_tensor'])
+ converter = lite.TFLiteConverter(None, ['input_tensor'], ['output_tensor'])
self.assertTrue(converter._has_valid_tensors())
@@ -89,7 +89,8 @@ class FromSessionTest(test_util.TensorFlowTestCase):
sess = session.Session()
# Convert model and ensure model is not None.
- converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
+ [out_tensor])
tflite_model = converter.convert()
self.assertTrue(tflite_model)
@@ -121,7 +122,7 @@ class FromSessionTest(test_util.TensorFlowTestCase):
sess = session.Session()
# Convert model and ensure model is not None.
- converter = lite.TocoConverter.from_session(
+ converter = lite.TFLiteConverter.from_session(
sess, [in_tensor_1, in_tensor_2], [out_tensor])
converter.inference_type = lite_constants.QUANTIZED_UINT8
converter.quantized_input_stats = {
@@ -166,7 +167,7 @@ class FromSessionTest(test_util.TensorFlowTestCase):
sess = session.Session()
# Convert model and ensure model is not None.
- converter = lite.TocoConverter.from_session(
+ converter = lite.TFLiteConverter.from_session(
sess, [in_tensor_1, in_tensor_2], [out_tensor])
converter.inference_type = lite_constants.QUANTIZED_UINT8
converter.quantized_input_stats = {'inputA': (0., 1.)} # mean, std_dev
@@ -182,7 +183,8 @@ class FromSessionTest(test_util.TensorFlowTestCase):
sess = session.Session()
# Test invalid shape. None after 1st dimension.
- converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
+ [out_tensor])
with self.assertRaises(ValueError) as error:
converter.convert()
self.assertEqual('Provide an input shape for input array \'Placeholder\'.',
@@ -195,7 +197,8 @@ class FromSessionTest(test_util.TensorFlowTestCase):
sess = session.Session()
# Test invalid shape. None after 1st dimension.
- converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
+ [out_tensor])
with self.assertRaises(ValueError) as error:
converter.convert()
self.assertEqual(
@@ -210,7 +213,8 @@ class FromSessionTest(test_util.TensorFlowTestCase):
sess = session.Session()
# Convert model and ensure model is not None.
- converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
+ [out_tensor])
tflite_model = converter.convert()
self.assertTrue(tflite_model)
@@ -242,7 +246,8 @@ class FromSessionTest(test_util.TensorFlowTestCase):
sess.run(_global_variables_initializer())
# Convert model and ensure model is not None.
- converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
+ [out_tensor])
tflite_model = converter.convert()
self.assertTrue(tflite_model)
@@ -272,7 +277,8 @@ class FromSessionTest(test_util.TensorFlowTestCase):
sess = session.Session()
# Convert model and ensure model is not None.
- converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
+ [out_tensor])
converter.output_format = lite_constants.GRAPHVIZ_DOT
graphviz_output = converter.convert()
self.assertTrue(graphviz_output)
@@ -285,7 +291,8 @@ class FromSessionTest(test_util.TensorFlowTestCase):
sess = session.Session()
# Convert model and ensure model is not None.
- converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
+ [out_tensor])
graphviz_dir = self.get_temp_dir()
converter.dump_graphviz_dir = graphviz_dir
tflite_model = converter.convert()
@@ -299,7 +306,8 @@ class FromSessionTest(test_util.TensorFlowTestCase):
self.assertTrue(num_items_graphviz)
# Convert model and ensure model is not None.
- converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
+ [out_tensor])
graphviz_dir = self.get_temp_dir()
converter.dump_graphviz_dir = graphviz_dir
converter.dump_graphviz_video = True
@@ -317,7 +325,8 @@ class FromSessionTest(test_util.TensorFlowTestCase):
sess = session.Session()
# Convert model and ensure model is not None.
- converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
+ [out_tensor])
converter.inference_input_type = lite_constants.QUANTIZED_UINT8
converter.quantized_input_stats = {'Placeholder': (0., 1.)} # mean, std_dev
tflite_model = converter.convert()
@@ -347,7 +356,8 @@ class FromSessionTest(test_util.TensorFlowTestCase):
sess = session.Session()
# Convert model and ensure model is not None.
- converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
+ [out_tensor])
converter.inference_type = lite_constants.QUANTIZED_UINT8
converter.quantized_input_stats = {'Placeholder': (0., 1.)} # mean, std_dev
converter.default_ranges_stats = (0, 6) # min, max
@@ -387,13 +397,13 @@ class FromSessionTest(test_util.TensorFlowTestCase):
sess = session.Session()
# Convert float model.
- float_converter = lite.TocoConverter.from_session(sess, [in_tensor_1],
- [out_tensor])
+ float_converter = lite.TFLiteConverter.from_session(sess, [in_tensor_1],
+ [out_tensor])
float_tflite = float_converter.convert()
self.assertTrue(float_tflite)
# Convert quantized weights model.
- quantized_converter = lite.TocoConverter.from_session(
+ quantized_converter = lite.TFLiteConverter.from_session(
sess, [in_tensor_1], [out_tensor])
quantized_converter.post_training_quantize = True
quantized_tflite = quantized_converter.convert()
@@ -409,7 +419,8 @@ class FromSessionTest(test_util.TensorFlowTestCase):
sess = session.Session()
# Convert model and ensure model is not None.
- converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ converter = lite.TFLiteConverter.from_session(sess, [in_tensor],
+ [out_tensor])
converter.converter_mode = lite.ConverterMode.TOCO_EXTENDED_ALL
tflite_model = converter.convert()
self.assertTrue(tflite_model)
@@ -424,6 +435,22 @@ class FromSessionTest(test_util.TensorFlowTestCase):
'sure you invoke the Eager delegate before inference.',
str(error.exception))
+ def testFloatTocoConverter(self):
+ """Tests deprecated test TocoConverter."""
+ in_tensor = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32)
+ out_tensor = in_tensor + in_tensor
+ sess = session.Session()
+
+ # Convert model and ensure model is not None.
+ converter = lite.TocoConverter.from_session(sess, [in_tensor], [out_tensor])
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ # Ensure the interpreter is able to load.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
class FromFrozenGraphFile(test_util.TensorFlowTestCase):
@@ -439,8 +466,8 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
sess.close()
# Convert model and ensure model is not None.
- converter = lite.TocoConverter.from_frozen_graph(graph_def_file,
- ['Placeholder'], ['add'])
+ converter = lite.TFLiteConverter.from_frozen_graph(graph_def_file,
+ ['Placeholder'], ['add'])
tflite_model = converter.convert()
self.assertTrue(tflite_model)
@@ -474,7 +501,7 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
sess.close()
# Convert model and ensure model is not None.
- converter = lite.TocoConverter.from_frozen_graph(
+ converter = lite.TFLiteConverter.from_frozen_graph(
graph_def_file, ['Placeholder'], ['add'],
input_shapes={'Placeholder': [1, 16, 16, 3]})
tflite_model = converter.convert()
@@ -503,8 +530,8 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
# Ensure the graph with variables cannot be converted.
with self.assertRaises(ValueError) as error:
- lite.TocoConverter.from_frozen_graph(graph_def_file, ['Placeholder'],
- ['add'])
+ lite.TFLiteConverter.from_frozen_graph(graph_def_file, ['Placeholder'],
+ ['add'])
self.assertEqual('Please freeze the graph using freeze_graph.py.',
str(error.exception))
@@ -520,8 +547,8 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
sess.close()
# Convert model and ensure model is not None.
- converter = lite.TocoConverter.from_frozen_graph(graph_def_file,
- ['Placeholder'], ['add'])
+ converter = lite.TFLiteConverter.from_frozen_graph(graph_def_file,
+ ['Placeholder'], ['add'])
tflite_model = converter.convert()
self.assertTrue(tflite_model)
@@ -545,8 +572,8 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
def testInvalidFileNotFound(self):
with self.assertRaises(IOError) as error:
- lite.TocoConverter.from_frozen_graph('invalid_file', ['Placeholder'],
- ['add'])
+ lite.TFLiteConverter.from_frozen_graph('invalid_file', ['Placeholder'],
+ ['add'])
self.assertEqual('File \'invalid_file\' does not exist.',
str(error.exception))
@@ -558,8 +585,8 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
# Attempts to convert the invalid model.
with self.assertRaises(IOError) as error:
- lite.TocoConverter.from_frozen_graph(graph_def_file, ['Placeholder'],
- ['add'])
+ lite.TFLiteConverter.from_frozen_graph(graph_def_file, ['Placeholder'],
+ ['add'])
self.assertEqual(
'Unable to parse input file \'{}\'.'.format(graph_def_file),
str(error.exception))
@@ -580,7 +607,7 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
# Tests the object detection model that cannot be loaded in TensorFlow.
self._initObjectDetectionArgs()
- converter = lite.TocoConverter.from_frozen_graph(
+ converter = lite.TFLiteConverter.from_frozen_graph(
self._graph_def_file, self._input_arrays, self._output_arrays,
self._input_shapes)
converter.allow_custom_ops = True
@@ -621,7 +648,7 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
# Missing `input_shapes`.
with self.assertRaises(ValueError) as error:
- lite.TocoConverter.from_frozen_graph(
+ lite.TFLiteConverter.from_frozen_graph(
self._graph_def_file, self._input_arrays, self._output_arrays)
self.assertEqual('input_shapes must be defined for this model.',
str(error.exception))
@@ -632,7 +659,7 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
# `input_shapes` does not contain the names in `input_arrays`.
with self.assertRaises(ValueError) as error:
- lite.TocoConverter.from_frozen_graph(
+ lite.TFLiteConverter.from_frozen_graph(
self._graph_def_file,
self._input_arrays,
self._output_arrays,
@@ -641,6 +668,27 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
'input_shapes must contain a value for each item in input_array.',
str(error.exception))
+ def testFloatTocoConverter(self):
+ in_tensor = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32)
+ _ = in_tensor + in_tensor
+ sess = session.Session()
+
+ # Write graph to file.
+ graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
+ write_graph(sess.graph_def, '', graph_def_file, False)
+ sess.close()
+
+ # Convert model and ensure model is not None.
+ converter = lite.TocoConverter.from_frozen_graph(graph_def_file,
+ ['Placeholder'], ['add'])
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ # Ensure the model is able to load.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
class FromSavedModelTest(test_util.TensorFlowTestCase):
@@ -663,7 +711,7 @@ class FromSavedModelTest(test_util.TensorFlowTestCase):
saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])
# Convert model and ensure model is not None.
- converter = lite.TocoConverter.from_saved_model(saved_model_dir)
+ converter = lite.TFLiteConverter.from_saved_model(saved_model_dir)
tflite_model = converter.convert()
self.assertTrue(tflite_model)
@@ -693,7 +741,7 @@ class FromSavedModelTest(test_util.TensorFlowTestCase):
"""Test a SavedModel, with None in input tensor's shape."""
saved_model_dir = self._createSavedModel(shape=[None, 16, 16, 3])
- converter = lite.TocoConverter.from_saved_model(saved_model_dir)
+ converter = lite.TFLiteConverter.from_saved_model(saved_model_dir)
tflite_model = converter.convert()
self.assertTrue(tflite_model)
@@ -724,7 +772,7 @@ class FromSavedModelTest(test_util.TensorFlowTestCase):
"""Test a SavedModel ordering of input arrays."""
saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])
- converter = lite.TocoConverter.from_saved_model(
+ converter = lite.TFLiteConverter.from_saved_model(
saved_model_dir, input_arrays=['inputB', 'inputA'])
tflite_model = converter.convert()
self.assertTrue(tflite_model)
@@ -757,7 +805,7 @@ class FromSavedModelTest(test_util.TensorFlowTestCase):
saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])
# Check case where input shape is given.
- converter = lite.TocoConverter.from_saved_model(
+ converter = lite.TFLiteConverter.from_saved_model(
saved_model_dir,
input_arrays=['inputA'],
input_shapes={'inputA': [1, 16, 16, 3]})
@@ -766,12 +814,25 @@ class FromSavedModelTest(test_util.TensorFlowTestCase):
self.assertTrue(tflite_model)
# Check case where input shape is None.
- converter = lite.TocoConverter.from_saved_model(
+ converter = lite.TFLiteConverter.from_saved_model(
saved_model_dir, input_arrays=['inputA'], input_shapes={'inputA': None})
tflite_model = converter.convert()
self.assertTrue(tflite_model)
+ def testSimpleModelTocoConverter(self):
+ """Test a SavedModel with deprecated TocoConverter."""
+ saved_model_dir = self._createSavedModel(shape=[1, 16, 16, 3])
+
+ # Convert model and ensure model is not None.
+ converter = lite.TocoConverter.from_saved_model(saved_model_dir)
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ # Ensure the model is able to load.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
class FromKerasFile(test_util.TensorFlowTestCase):
@@ -805,7 +866,7 @@ class FromKerasFile(test_util.TensorFlowTestCase):
"""Test a Sequential tf.keras model with default inputs."""
keras_file = self._getSequentialModel()
- converter = lite.TocoConverter.from_keras_model_file(keras_file)
+ converter = lite.TFLiteConverter.from_keras_model_file(keras_file)
tflite_model = converter.convert()
self.assertTrue(tflite_model)
@@ -845,13 +906,13 @@ class FromKerasFile(test_util.TensorFlowTestCase):
# Invalid input array raises error.
with self.assertRaises(ValueError) as error:
- lite.TocoConverter.from_keras_model_file(
+ lite.TFLiteConverter.from_keras_model_file(
keras_file, input_arrays=['invalid-input'])
self.assertEqual("Invalid tensors 'invalid-input' were found.",
str(error.exception))
# Valid input array.
- converter = lite.TocoConverter.from_keras_model_file(
+ converter = lite.TFLiteConverter.from_keras_model_file(
keras_file, input_arrays=['dense_input'])
tflite_model = converter.convert()
os.remove(keras_file)
@@ -863,13 +924,13 @@ class FromKerasFile(test_util.TensorFlowTestCase):
# Passing in shape of invalid input array has no impact as long as all input
# arrays have a shape.
- converter = lite.TocoConverter.from_keras_model_file(
+ converter = lite.TFLiteConverter.from_keras_model_file(
keras_file, input_shapes={'invalid-input': [2, 3]})
tflite_model = converter.convert()
self.assertTrue(tflite_model)
# Passing in shape of valid input array.
- converter = lite.TocoConverter.from_keras_model_file(
+ converter = lite.TFLiteConverter.from_keras_model_file(
keras_file, input_shapes={'dense_input': [2, 3]})
tflite_model = converter.convert()
os.remove(keras_file)
@@ -890,13 +951,13 @@ class FromKerasFile(test_util.TensorFlowTestCase):
# Invalid output array raises error.
with self.assertRaises(ValueError) as error:
- lite.TocoConverter.from_keras_model_file(
+ lite.TFLiteConverter.from_keras_model_file(
keras_file, output_arrays=['invalid-output'])
self.assertEqual("Invalid tensors 'invalid-output' were found.",
str(error.exception))
# Valid output array.
- converter = lite.TocoConverter.from_keras_model_file(
+ converter = lite.TFLiteConverter.from_keras_model_file(
keras_file, output_arrays=['time_distributed/Reshape_1'])
tflite_model = converter.convert()
os.remove(keras_file)
@@ -926,7 +987,7 @@ class FromKerasFile(test_util.TensorFlowTestCase):
os.close(fd)
# Convert to TFLite model.
- converter = lite.TocoConverter.from_keras_model_file(keras_file)
+ converter = lite.TFLiteConverter.from_keras_model_file(keras_file)
tflite_model = converter.convert()
self.assertTrue(tflite_model)
@@ -991,7 +1052,7 @@ class FromKerasFile(test_util.TensorFlowTestCase):
os.close(fd)
# Convert to TFLite model.
- converter = lite.TocoConverter.from_keras_model_file(keras_file)
+ converter = lite.TFLiteConverter.from_keras_model_file(keras_file)
tflite_model = converter.convert()
self.assertTrue(tflite_model)
@@ -1052,7 +1113,7 @@ class FromKerasFile(test_util.TensorFlowTestCase):
os.close(fd)
# Convert to TFLite model.
- converter = lite.TocoConverter.from_keras_model_file(keras_file)
+ converter = lite.TFLiteConverter.from_keras_model_file(keras_file)
tflite_model = converter.convert()
self.assertTrue(tflite_model)
@@ -1086,6 +1147,18 @@ class FromKerasFile(test_util.TensorFlowTestCase):
np.testing.assert_almost_equal(tflite_result, keras_result, 5)
os.remove(keras_file)
+ def testSequentialModelTocoConverter(self):
+ """Test a Sequential tf.keras model with deprecated TocoConverter."""
+ keras_file = self._getSequentialModel()
+
+ converter = lite.TocoConverter.from_keras_model_file(keras_file)
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ # Ensure the model is able to load.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/contrib/lite/python/tflite_convert.py b/tensorflow/contrib/lite/python/tflite_convert.py
index c0ff7f37f9..d6d9052a4e 100644
--- a/tensorflow/contrib/lite/python/tflite_convert.py
+++ b/tensorflow/contrib/lite/python/tflite_convert.py
@@ -40,13 +40,13 @@ def _parse_set(values):
def _get_toco_converter(flags):
- """Makes a TocoConverter object based on the flags provided.
+ """Makes a TFLiteConverter object based on the flags provided.
Args:
flags: argparse.Namespace object containing TFLite flags.
Returns:
- TocoConverter object.
+ TFLiteConverter object.
Raises:
ValueError: Invalid flags.
@@ -68,17 +68,17 @@ def _get_toco_converter(flags):
"output_arrays": output_arrays
}
- # Create TocoConverter.
+ # Create TFLiteConverter.
if flags.graph_def_file:
- converter_fn = lite.TocoConverter.from_frozen_graph
+ converter_fn = lite.TFLiteConverter.from_frozen_graph
converter_kwargs["graph_def_file"] = flags.graph_def_file
elif flags.saved_model_dir:
- converter_fn = lite.TocoConverter.from_saved_model
+ converter_fn = lite.TFLiteConverter.from_saved_model
converter_kwargs["saved_model_dir"] = flags.saved_model_dir
converter_kwargs["tag_set"] = _parse_set(flags.saved_model_tag_set)
converter_kwargs["signature_key"] = flags.saved_model_signature_key
elif flags.keras_model_file:
- converter_fn = lite.TocoConverter.from_keras_model_file
+ converter_fn = lite.TFLiteConverter.from_keras_model_file
converter_kwargs["model_file"] = flags.keras_model_file
else:
raise ValueError("--graph_def_file, --saved_model_dir, or "