path: root/tensorflow
author     Nupur Garg <nupurgarg@google.com>                 2018-08-28 18:16:23 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-08-28 18:23:02 -0700
commit     2e7352e57c541908cd700bb0fe53a04b456392c9 (patch)
tree       2064e341d1a7f154d0cb1d42910359c7dd3e5a02 /tensorflow
parent     c4099e6ee8ba3846f2b7e70445806bc3055c5624 (diff)
Add more model support to TocoConverter.
PiperOrigin-RevId: 210643904
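
For orientation, here is a minimal sketch of the conversion flow this change enables, based on the object detection test added to lite_test.py in this change; the .pbtxt path and output filename are illustrative, not part of the API:

from tensorflow.contrib.lite.python import lite

# Convert a frozen graph containing custom TFLite ops (e.g. the
# TFLite_Detection_PostProcess op) that cannot be imported into TensorFlow.
# Array names and shapes mirror the test data used in this change.
converter = lite.TocoConverter.from_frozen_graph(
    'tflite_graph.pbtxt',
    input_arrays=['normalized_input_image_tensor'],
    output_arrays=['TFLite_Detection_PostProcess'],
    input_shapes={'normalized_input_image_tensor': [1, 300, 300, 3]})
converter.allow_custom_ops = True
tflite_model = converter.convert()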
Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/contrib/lite/python/BUILD              |   3
-rw-r--r--  tensorflow/contrib/lite/python/convert.py         |  48
-rw-r--r--  tensorflow/contrib/lite/python/convert_test.py    |  89
-rw-r--r--  tensorflow/contrib/lite/python/lite.py            | 171
-rw-r--r--  tensorflow/contrib/lite/python/lite_test.py       | 113
-rw-r--r--  tensorflow/contrib/lite/python/tflite_convert.py  |  10
-rwxr-xr-x  tensorflow/workspace.bzl                          |  12
7 files changed, 393 insertions, 53 deletions
diff --git a/tensorflow/contrib/lite/python/BUILD b/tensorflow/contrib/lite/python/BUILD
index 47f0c8e9a2..6e30251eff 100644
--- a/tensorflow/contrib/lite/python/BUILD
+++ b/tensorflow/contrib/lite/python/BUILD
@@ -70,7 +70,7 @@ py_library(
py_test(
name = "lite_test",
srcs = ["lite_test.py"],
- data = [":interpreter_test_data"],
+ data = ["@tflite_mobilenet_ssd_quant_protobuf//:tflite_graph.pbtxt"],
srcs_version = "PY2AND3",
tags = [
"no_oss",
@@ -130,6 +130,7 @@ py_test(
],
deps = [
":convert",
+ ":interpreter",
":op_hint",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
diff --git a/tensorflow/contrib/lite/python/convert.py b/tensorflow/contrib/lite/python/convert.py
index 12cc66dc55..0b2192e031 100644
--- a/tensorflow/contrib/lite/python/convert.py
+++ b/tensorflow/contrib/lite/python/convert.py
@@ -226,6 +226,54 @@ def build_toco_convert_protos(input_tensors,
return model, toco
+def toco_convert_graph_def(input_data, input_arrays_with_shape, output_arrays,
+ *args, **kwargs):
+ """"Convert a model using TOCO.
+
+ This function is used to convert GraphDefs that cannot be loaded into
+ TensorFlow to TFLite. Conversion can be customized by providing arguments
+ that are forwarded to `build_toco_convert_protos` (see documentation for
+ details).
+
+ Args:
+    input_data: Input data (i.e. often `sess.graph_def`).
+    input_arrays_with_shape: List of tuples, each pairing an input tensor name
+      (string) with its shape (list of integers)
+      (e.g., [("foo", [1, 16, 16, 3])]). Use only when graph cannot be loaded
+      into TensorFlow and when `input_tensors` is None. (default None)
+    output_arrays: List of output tensors to freeze graph with. Use only when
+      graph cannot be loaded into TensorFlow and when `output_tensors` is None.
+      (default None)
+    *args: See `build_toco_convert_protos`.
+ **kwargs: See `build_toco_convert_protos`.
+
+ Returns:
+    The converted data. For example, if TFLite is the destination format, this
+    is a TFLite flatbuffer in a bytes array.
+
+ Raises:
+ Defined in `build_toco_convert_protos`.
+ """
+ model_flags, toco_flags = build_toco_convert_protos(
+ input_tensors=[], output_tensors=[], *args, **kwargs)
+
+ for idx, (name, shape) in enumerate(input_arrays_with_shape):
+ input_array = model_flags.input_arrays.add()
+ if kwargs["inference_type"] == lite_constants.QUANTIZED_UINT8:
+ input_array.mean_value, input_array.std_value = kwargs[
+ "quantized_input_stats"][idx]
+ input_array.name = name
+ input_array.shape.dims.extend(map(int, shape))
+
+ for name in output_arrays:
+ model_flags.output_arrays.append(name)
+
+ data = toco_convert_protos(model_flags.SerializeToString(),
+ toco_flags.SerializeToString(),
+ input_data.SerializeToString())
+ return data
+
+
def toco_convert_impl(input_data, input_tensors, output_tensors, *args,
**kwargs):
""""Convert a model using TOCO.
diff --git a/tensorflow/contrib/lite/python/convert_test.py b/tensorflow/contrib/lite/python/convert_test.py
index bc05514cec..59f537b82a 100644
--- a/tensorflow/contrib/lite/python/convert_test.py
+++ b/tensorflow/contrib/lite/python/convert_test.py
@@ -17,9 +17,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import numpy as np
+
from tensorflow.contrib.lite.python import convert
from tensorflow.contrib.lite.python import lite_constants
from tensorflow.contrib.lite.python import op_hint
+from tensorflow.contrib.lite.python.interpreter import Interpreter
from tensorflow.python.client import session
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
@@ -37,9 +40,12 @@ class ConvertTest(test_util.TensorFlowTestCase):
dtype=dtypes.float32)
out_tensor = in_tensor + in_tensor
sess = session.Session()
+
# Try running on valid graph
- result = convert.toco_convert(sess.graph_def, [in_tensor], [out_tensor])
- self.assertTrue(result)
+ tflite_model = convert.toco_convert(sess.graph_def, [in_tensor],
+ [out_tensor])
+ self.assertTrue(tflite_model)
+
    # TODO(aselle): remove tests that fail (we must get TOCO to not crash
    # fatally all the time).
# Try running on identity graph (known fail)
@@ -52,11 +58,85 @@ class ConvertTest(test_util.TensorFlowTestCase):
out_tensor = array_ops.fake_quant_with_min_max_args(in_tensor + in_tensor,
min=0., max=1.)
sess = session.Session()
- result = convert.toco_convert(
+
+ tflite_model = convert.toco_convert(
sess.graph_def, [in_tensor], [out_tensor],
inference_type=lite_constants.QUANTIZED_UINT8,
quantized_input_stats=[(0., 1.)])
- self.assertTrue(result)
+ self.assertTrue(tflite_model)
+
+ def testGraphDefBasic(self):
+ in_tensor = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32, name="input")
+ _ = in_tensor + in_tensor
+ sess = session.Session()
+
+ tflite_model = convert.toco_convert_graph_def(
+ sess.graph_def, [("input", [1, 16, 16, 3])], ["add"],
+ inference_type=lite_constants.FLOAT)
+ self.assertTrue(tflite_model)
+
+ # Check values from converted model.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(1, len(input_details))
+ self.assertEqual("input", input_details[0]["name"])
+ self.assertEqual(np.float32, input_details[0]["dtype"])
+ self.assertTrue(([1, 16, 16, 3] == input_details[0]["shape"]).all())
+ self.assertEqual((0., 0.), input_details[0]["quantization"])
+
+ output_details = interpreter.get_output_details()
+ self.assertEqual(1, len(output_details))
+ self.assertEqual("add", output_details[0]["name"])
+ self.assertEqual(np.float32, output_details[0]["dtype"])
+ self.assertTrue(([1, 16, 16, 3] == output_details[0]["shape"]).all())
+ self.assertEqual((0., 0.), output_details[0]["quantization"])
+
+ def testGraphDefQuantization(self):
+ in_tensor_1 = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32, name="inputA")
+ in_tensor_2 = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32, name="inputB")
+ _ = array_ops.fake_quant_with_min_max_args(
+ in_tensor_1 + in_tensor_2, min=0., max=1., name="output")
+ sess = session.Session()
+
+ input_arrays_map = [("inputA", [1, 16, 16, 3]), ("inputB", [1, 16, 16, 3])]
+ output_arrays = ["output"]
+ tflite_model = convert.toco_convert_graph_def(
+ sess.graph_def,
+ input_arrays_map,
+ output_arrays,
+ inference_type=lite_constants.QUANTIZED_UINT8,
+ quantized_input_stats=[(0., 1.), (0., 1.)])
+ self.assertTrue(tflite_model)
+
+ # Check values from converted model.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(2, len(input_details))
+ self.assertEqual("inputA", input_details[0]["name"])
+ self.assertEqual(np.uint8, input_details[0]["dtype"])
+ self.assertTrue(([1, 16, 16, 3] == input_details[0]["shape"]).all())
+ self.assertEqual((1., 0.),
+ input_details[0]["quantization"]) # scale, zero_point
+
+ self.assertEqual("inputB", input_details[1]["name"])
+ self.assertEqual(np.uint8, input_details[1]["dtype"])
+ self.assertTrue(([1, 16, 16, 3] == input_details[1]["shape"]).all())
+ self.assertEqual((1., 0.),
+ input_details[1]["quantization"]) # scale, zero_point
+
+ output_details = interpreter.get_output_details()
+ self.assertEqual(1, len(output_details))
+ self.assertEqual("output", output_details[0]["name"])
+ self.assertEqual(np.uint8, output_details[0]["dtype"])
+ self.assertTrue(([1, 16, 16, 3] == output_details[0]["shape"]).all())
+ self.assertTrue(output_details[0]["quantization"][0] > 0) # scale
class ConvertTestOpHint(test_util.TensorFlowTestCase):
@@ -243,7 +323,6 @@ class ConvertTestOpHint(test_util.TensorFlowTestCase):
with self.test_session() as sess:
stubbed_graphdef = op_hint.convert_op_hints_to_stubs(
graph_def=sess.graph_def)
- print(stubbed_graphdef)
self.assertCountEqual(
self._getGraphOpTypes(
stubbed_graphdef,
diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py
index 2313bfa3b6..a4c9a2381c 100644
--- a/tensorflow/contrib/lite/python/lite.py
+++ b/tensorflow/contrib/lite/python/lite.py
@@ -42,6 +42,7 @@ from tensorflow.contrib.lite.python import lite_constants as constants
from tensorflow.contrib.lite.python.convert import build_toco_convert_protos # pylint: disable=unused-import
from tensorflow.contrib.lite.python.convert import tensor_name as _tensor_name
from tensorflow.contrib.lite.python.convert import toco_convert # pylint: disable=unused-import
+from tensorflow.contrib.lite.python.convert import toco_convert_graph_def as _toco_convert_graph_def
from tensorflow.contrib.lite.python.convert import toco_convert_impl as _toco_convert_impl
from tensorflow.contrib.lite.python.convert import toco_convert_protos # pylint: disable=unused-import
from tensorflow.contrib.lite.python.convert_saved_model import freeze_saved_model as _freeze_saved_model
@@ -55,6 +56,7 @@ from tensorflow.python import keras as _keras
from tensorflow.python.client import session as _session
from tensorflow.python.framework import graph_util as _tf_graph_util
from tensorflow.python.framework import ops as _ops
+from tensorflow.python.framework.errors_impl import NotFoundError as _NotFoundError
from tensorflow.python.framework.importer import import_graph_def as _import_graph_def
from tensorflow.python.saved_model import signature_constants as _signature_constants
from tensorflow.python.saved_model import tag_constants as _tag_constants
@@ -133,7 +135,12 @@ class TocoConverter(object):
```
"""
- def __init__(self, graph_def, input_tensors, output_tensors):
+ def __init__(self,
+ graph_def,
+ input_tensors,
+ output_tensors,
+ input_arrays_with_shape=None,
+ output_arrays=None):
"""Constructor for TocoConverter.
Args:
@@ -142,6 +149,17 @@ class TocoConverter(object):
input_tensors: List of input tensors. Type and shape are computed using
`foo.get_shape()` and `foo.dtype`.
output_tensors: List of output tensors (only .name is used from this).
+      input_arrays_with_shape: List of tuples, each pairing an input tensor
+        name (string) with its shape (list of integers)
+        (e.g., [("foo", [1, 16, 16, 3])]). Use only when graph cannot be loaded
+ into TensorFlow and when `input_tensors` and `output_tensors` are None.
+ (default None)
+ output_arrays: List of output tensors to freeze graph with. Use only when
+ graph cannot be loaded into TensorFlow and when `input_tensors` and
+ `output_tensors` are None. (default None)
+
+ Raises:
+ ValueError: Invalid arguments.
"""
self._graph_def = graph_def
self._input_tensors = input_tensors
@@ -159,6 +177,15 @@ class TocoConverter(object):
self.dump_graphviz_dir = None
self.dump_graphviz_video = False
+    # These attributes are used when the model cannot be loaded into
+    # TensorFlow.
+ if not self._has_valid_tensors():
+ if not input_arrays_with_shape or not output_arrays:
+ raise ValueError(
+ "If input_tensors and output_tensors are None, both "
+ "input_arrays_with_shape and output_arrays must be defined.")
+ self._input_arrays_with_shape = input_arrays_with_shape
+ self._output_arrays = output_arrays
+
@classmethod
def from_session(cls, sess, input_tensors, output_tensors):
"""Creates a TocoConverter class from a TensorFlow Session.
@@ -200,6 +227,7 @@ class TocoConverter(object):
Unable to parse input file.
The graph is not frozen.
input_arrays or output_arrays contains an invalid tensor name.
+      input_shapes is not correctly defined when required.
"""
with _ops.Graph().as_default():
with _session.Session() as sess:
@@ -222,20 +250,44 @@ class TocoConverter(object):
except (_text_format.ParseError, DecodeError):
raise ValueError(
"Unable to parse input file '{}'.".format(graph_def_file))
- _import_graph_def(graph_def, name="")
-
- # Get input and output tensors.
- input_tensors = _get_tensors_from_tensor_names(sess.graph, input_arrays)
- output_tensors = _get_tensors_from_tensor_names(sess.graph,
- output_arrays)
- _set_tensor_shapes(input_tensors, input_shapes)
-
- # Check if graph is frozen.
- if not _is_frozen_graph(sess):
- raise ValueError("Please freeze the graph using freeze_graph.py.")
- # Create TocoConverter class.
- return cls(sess.graph_def, input_tensors, output_tensors)
+ # Handles models with custom TFLite ops that cannot be resolved in
+ # TensorFlow.
+ load_model_in_session = True
+ try:
+ _import_graph_def(graph_def, name="")
+ except _NotFoundError:
+ load_model_in_session = False
+
+ if load_model_in_session:
+ # Check if graph is frozen.
+ if not _is_frozen_graph(sess):
+ raise ValueError("Please freeze the graph using freeze_graph.py.")
+
+ # Get input and output tensors.
+ input_tensors = _get_tensors_from_tensor_names(
+ sess.graph, input_arrays)
+ output_tensors = _get_tensors_from_tensor_names(
+ sess.graph, output_arrays)
+ _set_tensor_shapes(input_tensors, input_shapes)
+
+ return cls(sess.graph_def, input_tensors, output_tensors)
+ else:
+ if not input_shapes:
+ raise ValueError("input_shapes must be defined for this model.")
+ if set(input_arrays) != set(input_shapes.keys()):
+ raise ValueError("input_shapes must contain a value for each item "
+ "in input_array.")
+
+ input_arrays_with_shape = [
+ (name, input_shapes[name]) for name in input_arrays
+ ]
+ return cls(
+ graph_def,
+ input_tensors=None,
+ output_tensors=None,
+ input_arrays_with_shape=input_arrays_with_shape,
+ output_arrays=output_arrays)
@classmethod
def from_saved_model(cls,
@@ -330,25 +382,25 @@ class TocoConverter(object):
None value for dimension in input_tensor.
"""
# Checks dimensions in input tensor.
- for tensor in self._input_tensors:
- if not tensor.get_shape():
- raise ValueError("Provide an input shape for input array '{0}'.".format(
- _tensor_name(tensor)))
- shape = tensor.get_shape().as_list()
- if None in shape[1:]:
- raise ValueError(
- "None is only supported in the 1st dimension. Tensor '{0}' has "
- "invalid shape '{1}'.".format(_tensor_name(tensor), shape))
- elif shape[0] is None:
- self._set_batch_size(batch_size=1)
+ if self._has_valid_tensors():
+ for tensor in self._input_tensors:
+ if not tensor.get_shape():
+ raise ValueError("Provide an input shape for input array "
+ "'{0}'.".format(_tensor_name(tensor)))
+ shape = tensor.get_shape().as_list()
+ if None in shape[1:]:
+ raise ValueError(
+ "None is only supported in the 1st dimension. Tensor '{0}' has "
+ "invalid shape '{1}'.".format(_tensor_name(tensor), shape))
+ elif shape[0] is None:
+ self._set_batch_size(batch_size=1)
# Get quantization stats. Ensures there is one stat per name if the stats
# are specified.
if self.quantized_input_stats:
quantized_stats = []
invalid_stats = []
- for tensor in self._input_tensors:
- name = _tensor_name(tensor)
+ for name in self.get_input_arrays():
if name in self.quantized_input_stats:
quantized_stats.append(self.quantized_input_stats[name])
else:
@@ -360,24 +412,35 @@ class TocoConverter(object):
else:
quantized_stats = None
+ converter_kwargs = {
+ "inference_type": self.inference_type,
+ "inference_input_type": self.inference_input_type,
+ "input_format": constants.TENSORFLOW_GRAPHDEF,
+ "output_format": self.output_format,
+ "quantized_input_stats": quantized_stats,
+ "default_ranges_stats": self.default_ranges_stats,
+ "drop_control_dependency": self.drop_control_dependency,
+ "reorder_across_fake_quant": self.reorder_across_fake_quant,
+ "change_concat_input_ranges": self.change_concat_input_ranges,
+ "allow_custom_ops": self.allow_custom_ops,
+ "quantize_weights": self.quantize_weights,
+ "dump_graphviz_dir": self.dump_graphviz_dir,
+ "dump_graphviz_video": self.dump_graphviz_video
+ }
+
# Converts model.
- result = _toco_convert_impl(
- input_data=self._graph_def,
- input_tensors=self._input_tensors,
- output_tensors=self._output_tensors,
- inference_type=self.inference_type,
- inference_input_type=self.inference_input_type,
- input_format=constants.TENSORFLOW_GRAPHDEF,
- output_format=self.output_format,
- quantized_input_stats=quantized_stats,
- default_ranges_stats=self.default_ranges_stats,
- drop_control_dependency=self.drop_control_dependency,
- reorder_across_fake_quant=self.reorder_across_fake_quant,
- change_concat_input_ranges=self.change_concat_input_ranges,
- allow_custom_ops=self.allow_custom_ops,
- quantize_weights=self.quantize_weights,
- dump_graphviz_dir=self.dump_graphviz_dir,
- dump_graphviz_video=self.dump_graphviz_video)
+ if self._has_valid_tensors():
+ result = _toco_convert_impl(
+ input_data=self._graph_def,
+ input_tensors=self._input_tensors,
+ output_tensors=self._output_tensors,
+ **converter_kwargs)
+ else:
+ result = _toco_convert_graph_def(
+ input_data=self._graph_def,
+ input_arrays_with_shape=self._input_arrays_with_shape,
+ output_arrays=self._output_arrays,
+ **converter_kwargs)
return result
def get_input_arrays(self):
@@ -386,7 +449,18 @@ class TocoConverter(object):
Returns:
List of strings.
"""
- return [_tensor_name(tensor) for tensor in self._input_tensors]
+ if self._has_valid_tensors():
+ return [_tensor_name(tensor) for tensor in self._input_tensors]
+ else:
+ return [name for name, _ in self._input_arrays_with_shape]
+
+ def _has_valid_tensors(self):
+ """Checks if the input and output tensors have been initialized.
+
+ Returns:
+ Bool.
+ """
+ return self._input_tensors and self._output_tensors
def _set_batch_size(self, batch_size):
"""Sets the first dimension of the input tensor to `batch_size`.
@@ -394,7 +468,14 @@ class TocoConverter(object):
Args:
batch_size: Batch size for the model. Replaces the first dimension of an
input size array if undefined. (default 1)
+
+ Raises:
+ ValueError: input_tensor is not defined.
"""
+ if not self._has_valid_tensors():
+ raise ValueError("The batch size cannot be set for this model. Please "
+ "use input_shapes parameter.")
+
for tensor in self._input_tensors:
shape = tensor.get_shape().as_list()
shape[0] = batch_size
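
A condensed sketch of the direct constructor path added above, mirroring the FromConstructor tests that follow; `graph_def` is assumed to be a GraphDef that fails to import into TensorFlow:

from tensorflow.contrib.lite.python import lite

# The tensor-based arguments are bypassed entirely; input names, shapes, and
# output names are supplied as plain strings and integer lists instead.
converter = lite.TocoConverter(
    graph_def,
    input_tensors=None,
    output_tensors=None,
    input_arrays_with_shape=[('input', [1, 16, 16, 3])],
    output_arrays=['output'])
assert converter.get_input_arrays() == ['input']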
diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/lite_test.py
index 2f13684228..8c9cfa943f 100644
--- a/tensorflow/contrib/lite/python/lite_test.py
+++ b/tensorflow/contrib/lite/python/lite_test.py
@@ -35,11 +35,51 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.variables import global_variables_initializer as _global_variables_initializer
from tensorflow.python.platform import gfile
+from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import test
from tensorflow.python.saved_model import saved_model
from tensorflow.python.training.training_util import write_graph
+class FromConstructor(test_util.TensorFlowTestCase):
+
+ # Tests invalid constructors using a dummy value for the GraphDef.
+ def testInvalidConstructor(self):
+ message = ('If input_tensors and output_tensors are None, both '
+ 'input_arrays_with_shape and output_arrays must be defined.')
+
+ # `output_arrays` is not defined.
+ with self.assertRaises(ValueError) as error:
+ lite.TocoConverter(
+ None, None, [], input_arrays_with_shape=[('input', [3, 9])])
+ self.assertEqual(message, str(error.exception))
+
+ # `input_arrays_with_shape` is not defined.
+ with self.assertRaises(ValueError) as error:
+ lite.TocoConverter(None, [], None, output_arrays=['output'])
+ self.assertEqual(message, str(error.exception))
+
+ # Tests valid constructors using a dummy value for the GraphDef.
+ def testValidConstructor(self):
+ converter = lite.TocoConverter(
+ None,
+ None,
+ None,
+ input_arrays_with_shape=[('input', [3, 9])],
+ output_arrays=['output'])
+ self.assertFalse(converter._has_valid_tensors())
+ self.assertEqual(converter.get_input_arrays(), ['input'])
+
+ with self.assertRaises(ValueError) as error:
+ converter._set_batch_size(1)
+ self.assertEqual(
+        'The batch size cannot be set for this model. Please use the '
+        'input_shapes parameter.', str(error.exception))
+
+ converter = lite.TocoConverter(None, ['input_tensor'], ['output_tensor'])
+ self.assertTrue(converter._has_valid_tensors())
+
+
class FromSessionTest(test_util.TensorFlowTestCase):
def testFloat(self):
@@ -490,6 +530,79 @@ class FromFrozenGraphFile(test_util.TensorFlowTestCase):
'Unable to parse input file \'{}\'.'.format(graph_def_file),
str(error.exception))
+ # TODO(nupurgarg): Test model loading in open source.
+ def _initObjectDetectionArgs(self):
+ # Initializes the arguments required for the object detection model.
+ self._graph_def_file = resource_loader.get_path_to_datafile(
+ 'testdata/tflite_graph.pbtxt')
+ self._input_arrays = ['normalized_input_image_tensor']
+ self._output_arrays = [
+ 'TFLite_Detection_PostProcess', 'TFLite_Detection_PostProcess:1',
+ 'TFLite_Detection_PostProcess:2', 'TFLite_Detection_PostProcess:3'
+ ]
+ self._input_shapes = {'normalized_input_image_tensor': [1, 300, 300, 3]}
+
+ def testTFLiteGraphDef(self):
+ # Tests the object detection model that cannot be loaded in TensorFlow.
+ self._initObjectDetectionArgs()
+
+ converter = lite.TocoConverter.from_frozen_graph(
+ self._graph_def_file, self._input_arrays, self._output_arrays,
+ self._input_shapes)
+ converter.allow_custom_ops = True
+ tflite_model = converter.convert()
+ self.assertTrue(tflite_model)
+
+ # Check values from converted model.
+ interpreter = Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ self.assertEqual(1, len(input_details))
+ self.assertEqual('normalized_input_image_tensor', input_details[0]['name'])
+ self.assertEqual(np.float32, input_details[0]['dtype'])
+ self.assertTrue(([1, 300, 300, 3] == input_details[0]['shape']).all())
+ self.assertEqual((0., 0.), input_details[0]['quantization'])
+
+ output_details = interpreter.get_output_details()
+ self.assertEqual(4, len(output_details))
+ self.assertEqual('TFLite_Detection_PostProcess', output_details[0]['name'])
+ self.assertEqual(np.float32, output_details[0]['dtype'])
+ self.assertTrue(([1, 10, 4] == output_details[0]['shape']).all())
+ self.assertEqual((0., 0.), output_details[0]['quantization'])
+
+ self.assertEqual('TFLite_Detection_PostProcess:1',
+ output_details[1]['name'])
+ self.assertTrue(([1, 10] == output_details[1]['shape']).all())
+ self.assertEqual('TFLite_Detection_PostProcess:2',
+ output_details[2]['name'])
+ self.assertTrue(([1, 10] == output_details[2]['shape']).all())
+ self.assertEqual('TFLite_Detection_PostProcess:3',
+ output_details[3]['name'])
+ self.assertTrue(([1] == output_details[3]['shape']).all())
+
+ def testTFLiteGraphDefInvalid(self):
+ # Tests invalid cases for the model that cannot be loaded in TensorFlow.
+ self._initObjectDetectionArgs()
+
+ # Missing `input_shapes`.
+ with self.assertRaises(ValueError) as error:
+ lite.TocoConverter.from_frozen_graph(
+ self._graph_def_file, self._input_arrays, self._output_arrays)
+ self.assertEqual('input_shapes must be defined for this model.',
+ str(error.exception))
+
+ # `input_shapes` does not contain the names in `input_arrays`.
+ with self.assertRaises(ValueError) as error:
+ lite.TocoConverter.from_frozen_graph(
+ self._graph_def_file,
+ self._input_arrays,
+ self._output_arrays,
+ input_shapes={'invalid-value': [1, 19]})
+ self.assertEqual(
+        'input_shapes must contain a value for each item in input_arrays.',
+ str(error.exception))
+
class FromSavedModelTest(test_util.TensorFlowTestCase):
diff --git a/tensorflow/contrib/lite/python/tflite_convert.py b/tensorflow/contrib/lite/python/tflite_convert.py
index 46bdb3e553..ce12a9abde 100644
--- a/tensorflow/contrib/lite/python/tflite_convert.py
+++ b/tensorflow/contrib/lite/python/tflite_convert.py
@@ -132,7 +132,8 @@ def _convert_model(flags):
if flags.reorder_across_fake_quant:
converter.reorder_across_fake_quant = flags.reorder_across_fake_quant
if flags.change_concat_input_ranges:
- converter.change_concat_input_ranges = flags.change_concat_input_ranges
+ converter.change_concat_input_ranges = (
+ flags.change_concat_input_ranges == "TRUE")
if flags.allow_custom_ops:
converter.allow_custom_ops = flags.allow_custom_ops
if flags.quantize_weights:
@@ -333,9 +334,14 @@ def run_main(_):
"the graph. Results in a graph that differs from the quantized "
"training graph, potentially causing differing arithmetic "
"behavior. (default False)"))
+  # This flag is used as --change_concat_input_ranges=true or
+  # --change_concat_input_ranges=false to make the setting explicit, which
+  # keeps usage consistent with other tools where the flag's default differs.
+  # The default value here is False.
parser.add_argument(
"--change_concat_input_ranges",
- action="store_true",
+ type=str.upper,
+ choices=["TRUE", "FALSE"],
help=("Boolean to change behavior of min/max ranges for inputs and "
"outputs of the concat operator for quantized models. Changes the "
"ranges of concat operator overlap when true. (default False)"))
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 941b27cb59..de38f8c0c2 100755
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -767,6 +767,7 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
],
build_file = str(Label("//third_party:tflite_mobilenet.BUILD")),
)
+
tf_http_archive(
name = "tflite_mobilenet_ssd_quant",
sha256 = "a809cd290b4d6a2e8a9d5dad076e0bd695b8091974e0eed1052b480b2f21b6dc",
@@ -778,6 +779,17 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
)
tf_http_archive(
+ name = "tflite_mobilenet_ssd_quant_protobuf",
+ sha256 = "09280972c5777f1aa775ef67cb4ac5d5ed21970acd8535aeca62450ef14f0d79",
+ urls = [
+ "https://mirror.bazel.build/storage.googleapis.com/download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_quantized_300x300_coco14_sync_2018_07_18.tar.gz",
+ "http://storage.googleapis.com/download.tensorflow.org/models/object_detection/ssd_mobilenet_v1_quantized_300x300_coco14_sync_2018_07_18.tar.gz",
+ ],
+ strip_prefix = "ssd_mobilenet_v1_quantized_300x300_coco14_sync_2018_07_18",
+ build_file = str(Label("//third_party:tflite_mobilenet.BUILD")),
+ )
+
+ tf_http_archive(
name = "tflite_conv_actions_frozen",
sha256 = "d947b38cba389b5e2d0bfc3ea6cc49c784e187b41a071387b3742d1acac7691e",
urls = [