author     Nupur Garg <nupurgarg@google.com>  2018-10-03 10:51:17 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-10-03 10:55:31 -0700
commit  560624bff65b7b502da2c52f9b250d9181c4a3f7 (patch)
tree    29d3aab2396c231223952515333ce2f2c08f8e30 /tensorflow/contrib/lite
parent  af1458a9c1a3bc8d49a1e55386950b4941ab1815 (diff)
Internal change.
PiperOrigin-RevId: 215589009
Diffstat (limited to 'tensorflow/contrib/lite')
-rw-r--r--  tensorflow/contrib/lite/python/interpreter.py                              | 17
-rw-r--r--  tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc | 19
-rw-r--r--  tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h  |  1
-rw-r--r--  tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib.py      | 81
-rw-r--r--  tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib_test.py | 38
5 files changed, 147 insertions(+), 9 deletions(-)
diff --git a/tensorflow/contrib/lite/python/interpreter.py b/tensorflow/contrib/lite/python/interpreter.py
index 5700bf7892..6300552cbe 100644
--- a/tensorflow/contrib/lite/python/interpreter.py
+++ b/tensorflow/contrib/lite/python/interpreter.py
@@ -129,6 +129,23 @@ class Interpreter(object):
return details
+ def get_tensor_details(self):
+ """Gets tensor details for every tensor with valid tensor details.
+
+ Tensors where required information about the tensor is not found are not
+ added to the list. This includes temporary tensors without a name.
+
+ Returns:
+ A list of dictionaries containing tensor information.
+ """
+ tensor_details = []
+ for idx in range(self._interpreter.NumTensors()):
+ try:
+ tensor_details.append(self._get_tensor_details(idx))
+ except ValueError:
+ pass
+ return tensor_details
+
def get_input_details(self):
"""Gets model input details.
diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
index 418f19a179..1e2384b6d2 100644
--- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
+++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.cc
@@ -277,13 +277,20 @@ PyObject* InterpreterWrapper::ResizeInputTensor(int i, PyObject* value) {
Py_RETURN_NONE;
}
+int InterpreterWrapper::NumTensors() const {
+ if (!interpreter_) {
+ return 0;
+ }
+ return interpreter_->tensors_size();
+}
+
std::string InterpreterWrapper::TensorName(int i) const {
if (!interpreter_ || i >= interpreter_->tensors_size() || i < 0) {
return "";
}
const TfLiteTensor* tensor = interpreter_->tensor(i);
- return tensor->name;
+ return tensor->name ? tensor->name : "";
}
PyObject* InterpreterWrapper::TensorType(int i) const {
@@ -291,6 +298,11 @@ PyObject* InterpreterWrapper::TensorType(int i) const {
TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
const TfLiteTensor* tensor = interpreter_->tensor(i);
+ if (tensor->type == kTfLiteNoType) {
+ PyErr_Format(PyExc_ValueError, "Tensor with no type found.");
+ return nullptr;
+ }
+
int code = TfLiteTypeToPyArrayType(tensor->type);
if (code == -1) {
PyErr_Format(PyExc_ValueError, "Invalid tflite type code %d", code);
@@ -302,7 +314,12 @@ PyObject* InterpreterWrapper::TensorType(int i) const {
PyObject* InterpreterWrapper::TensorSize(int i) const {
TFLITE_PY_ENSURE_VALID_INTERPRETER();
TFLITE_PY_TENSOR_BOUNDS_CHECK(i);
+
const TfLiteTensor* tensor = interpreter_->tensor(i);
+ if (tensor->dims == nullptr) {
+ PyErr_Format(PyExc_ValueError, "Tensor with no shape found.");
+ return nullptr;
+ }
PyObject* np_array =
PyArrayFromIntVector(tensor->dims->data, tensor->dims->size);
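At the Python layer, these new guards surface as a ValueError rather than a crash, which is exactly what get_tensor_details() above relies on. A hedged sketch of probing a single tensor defensively; the index is illustrative and _get_tensor_details is the private helper seen in interpreter.py:

# Probe one tensor; incomplete tensors (no type or no dims) now raise
# ValueError from the wrapper instead of dereferencing null pointers.
try:
  detail = interpreter._get_tensor_details(0)
except ValueError:
  detail = None  # e.g. an unnamed temporary tensor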
diff --git a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h
index f5ca81e62a..b98046fe8a 100644
--- a/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h
+++ b/tensorflow/contrib/lite/python/interpreter_wrapper/interpreter_wrapper.h
@@ -59,6 +59,7 @@ class InterpreterWrapper {
PyObject* OutputIndices() const;
PyObject* ResizeInputTensor(int i, PyObject* value);
+ int NumTensors() const;
std::string TensorName(int i) const;
PyObject* TensorType(int i) const;
PyObject* TensorSize(int i) const;
diff --git a/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib.py b/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib.py
index 5ca57d083d..72029ed03c 100644
--- a/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib.py
+++ b/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib.py
@@ -35,9 +35,9 @@ def _convert(converter, **kwargs):
"""Converts the model.
Args:
- converter: TocoConverter object.
+ converter: TFLiteConverter object.
**kwargs: Additional arguments to be passed into the converter. Supported
- flags are {"converter_mode", "post_training_quant"}.
+ flags are {"converter_mode", "post_training_quantize"}.
Returns:
The converted TFLite model in serialized format.
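For reference, a hedged sketch of what passing post_training_quantize=True through _convert amounts to in terms of the public converter API; the frozen-graph path and tensor names are placeholders:

from tensorflow.contrib.lite.python import lite

converter = lite.TFLiteConverter.from_frozen_graph(
    '/tmp/frozen.pb', ['input'], ['output'])
# Equivalent of _convert(converter, post_training_quantize=True):
converter.post_training_quantize = True
tflite_model = converter.convert()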
@@ -174,7 +174,7 @@ def compare_models_random_data(tflite_model, tf_eval_func, tolerance=5):
tflite_model: Serialized TensorFlow Lite model.
tf_eval_func: Lambda function that takes in input data and outputs the
results of the TensorFlow model ([np.ndarray data] : [np.ndarray result]).
- tolerance: Decimal place to check accuracy to.
+ tolerance: Decimal place to check accuracy to. (default 5)
"""
input_data = _generate_random_input_data(tflite_model)
tf_results = tf_eval_func(input_data)
@@ -183,6 +183,71 @@ def compare_models_random_data(tflite_model, tf_eval_func, tolerance=5):
np.testing.assert_almost_equal(tf_result, tflite_result, tolerance)
+def test_frozen_graph_quant(filename,
+ input_arrays,
+ output_arrays,
+ input_shapes=None,
+ **kwargs):
+ """Sanity check to validate post quantize flag alters the graph.
+
+ This test does not check correctness of the converted model. It converts the
+ TensorFlow frozen graph to TFLite with and without the post_training_quantized
+ flag. It ensures some tensors have different types between the float and
+ quantized models in the case of an all TFLite model or mix-and-match model.
+ It ensures tensor types do not change in the case of an all Flex model.
+
+ Args:
+ filename: Full filepath of file containing frozen GraphDef.
+ input_arrays: List of input tensors to freeze graph with.
+ output_arrays: List of output tensors to freeze graph with.
+ input_shapes: Dict of strings representing input tensor names to list of
+ integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
+ Automatically determined when input shapes is None (e.g., {"foo" : None}).
+ (default None)
+ **kwargs: Additional arguments to be passed into the converter.
+
+ Raises:
+ ValueError: post_training_quantize flag doesn't act as intended.
+ """
+ # Convert and load the float model.
+ converter = _lite.TFLiteConverter.from_frozen_graph(
+ filename, input_arrays, output_arrays, input_shapes)
+ tflite_model_float = _convert(converter, **kwargs)
+
+ interpreter_float = _lite.Interpreter(model_content=tflite_model_float)
+ interpreter_float.allocate_tensors()
+ float_tensors = interpreter_float.get_tensor_details()
+
+  # Convert and load the quantized model, using the same input shapes as the
+  # float model so the tensor-by-tensor comparison below is meaningful.
+  converter = _lite.TFLiteConverter.from_frozen_graph(
+      filename, input_arrays, output_arrays, input_shapes)
+ tflite_model_quant = _convert(
+ converter, post_training_quantize=True, **kwargs)
+
+ interpreter_quant = _lite.Interpreter(model_content=tflite_model_quant)
+ interpreter_quant.allocate_tensors()
+ quant_tensors = interpreter_quant.get_tensor_details()
+ quant_tensors_map = {
+ tensor_detail["name"]: tensor_detail for tensor_detail in quant_tensors
+ }
+
+ # Check if weights are of different types in the float and quantized models.
+ num_tensors_float = len(float_tensors)
+ num_tensors_same_dtypes = sum(
+ float_tensor["dtype"] == quant_tensors_map[float_tensor["name"]]["dtype"]
+ for float_tensor in float_tensors)
+ has_quant_tensor = num_tensors_float != num_tensors_same_dtypes
+
+ if ("converter_mode" in kwargs and
+ kwargs["converter_mode"] == _lite.ConverterMode.TOCO_FLEX_ALL):
+ if has_quant_tensor:
+ raise ValueError("--post_training_quantize flag unexpectedly altered the "
+ "full Flex mode graph.")
+ elif not has_quant_tensor:
+ raise ValueError("--post_training_quantize flag was unable to quantize the "
+ "graph as expected in TFLite and mix-and-match mode.")
+
+
def test_frozen_graph(filename,
input_arrays,
output_arrays,
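The dtype comparison at the heart of test_frozen_graph_quant can be illustrated standalone. In this sketch the tensor names and dtypes are invented, with 'weights' standing in for a tensor rewritten to uint8 by post-training quantization:

import numpy as np

float_tensors = [{'name': 'weights', 'dtype': np.float32},
                 {'name': 'bias', 'dtype': np.float32}]
quant_tensors_map = {
    'weights': {'name': 'weights', 'dtype': np.uint8},
    'bias': {'name': 'bias', 'dtype': np.float32},
}

# Same logic as above: quantization happened iff at least one tensor
# changed dtype between the float and quantized models.
num_same = sum(t['dtype'] == quant_tensors_map[t['name']]['dtype']
               for t in float_tensors)
has_quant_tensor = num_same != len(float_tensors)
assert has_quant_tensor  # 'weights' went float32 -> uint8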
@@ -203,8 +268,8 @@ def test_frozen_graph(filename,
(default None)
**kwargs: Additional arguments to be passed into the converter.
"""
- converter = _lite.TocoConverter.from_frozen_graph(filename, input_arrays,
- output_arrays, input_shapes)
+ converter = _lite.TFLiteConverter.from_frozen_graph(
+ filename, input_arrays, output_arrays, input_shapes)
tflite_model = _convert(converter, **kwargs)
tf_eval_func = evaluate_frozen_graph(filename, input_arrays, output_arrays)
@@ -224,8 +289,8 @@ def test_saved_model(directory, tag_set=None, signature_key=None, **kwargs):
signature_key: Key identifying SignatureDef containing inputs and outputs.
**kwargs: Additional arguments to be passed into the converter.
"""
- converter = _lite.TocoConverter.from_saved_model(directory, tag_set,
- signature_key)
+ converter = _lite.TFLiteConverter.from_saved_model(directory, tag_set,
+ signature_key)
tflite_model = _convert(converter, **kwargs)
tf_eval_func = evaluate_saved_model(directory, tag_set, signature_key)
@@ -242,7 +307,7 @@ def test_keras_model(filename, **kwargs):
filename: Full filepath of HDF5 file containing the tf.keras model.
**kwargs: Additional arguments to be passed into the converter.
"""
- converter = _lite.TocoConverter.from_keras_model_file(filename)
+ converter = _lite.TFLiteConverter.from_keras_model_file(filename)
tflite_model = _convert(converter, **kwargs)
tf_eval_func = evaluate_keras_model(filename)
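Taken together, the renamed entry points are driven as in this sketch; the file paths are placeholders, and the extra kwargs are forwarded to _convert as documented above:

from tensorflow.contrib.lite.testing.model_coverage import model_coverage_lib as model_coverage

# Frozen GraphDef: converts via TFLiteConverter.from_frozen_graph and
# compares TensorFlow and TFLite outputs on random input data.
model_coverage.test_frozen_graph('/tmp/frozen.pb', ['input'], ['output'])

# tf.keras HDF5 model, additionally exercising post-training quantization.
model_coverage.test_keras_model('/tmp/model.h5', post_training_quantize=True)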
diff --git a/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib_test.py b/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib_test.py
index 1498f86c6f..e07202b1a6 100644
--- a/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib_test.py
+++ b/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import os
import tempfile
+import numpy as np
from tensorflow.contrib.lite.python import lite
from tensorflow.contrib.lite.testing.model_coverage import model_coverage_lib as model_coverage
@@ -66,6 +67,43 @@ class EvaluateFrozenGraph(test.TestCase):
model_coverage.test_frozen_graph(filename, ['inputA', 'inputB'],
['add', 'Mean'])
+ def _getQuantizedModel(self):
+ np.random.seed(0)
+ with session.Session().as_default() as sess:
+ # The tensor needs to have more than 1024 elements for quantize_weights to
+ # kick in. Thus, the [33, 33] shape.
+ in_tensor_1 = array_ops.placeholder(
+ shape=[33, 33], dtype=dtypes.float32, name='inputA')
+ in_tensor_2 = constant_op.constant(
+ np.random.uniform(low=-10., high=10., size=(33, 33)),
+ shape=[33, 33],
+ dtype=dtypes.float32,
+ name='inputB')
+ _ = math_ops.matmul(in_tensor_1, in_tensor_2, name='output')
+
+ filename = self._saveFrozenGraph(sess)
+ return filename
+
+ def testQuantized(self):
+ filename = self._getQuantizedModel()
+ model_coverage.test_frozen_graph_quant(filename, ['inputA', 'inputB'],
+ ['output'])
+
+ def testQuantizedInputShapes(self):
+ filename = self._getQuantizedModel()
+ model_coverage.test_frozen_graph_quant(
+ filename, ['inputA', 'inputB'], ['output'],
+ input_shapes={
+ 'inputA': [33, 33],
+ 'inputB': [33, 33],
+ })
+
+ def testQuantizedFlexAll(self):
+ filename = self._getQuantizedModel()
+ model_coverage.test_frozen_graph_quant(
+ filename, ['inputA', 'inputB'], ['output'],
+ converter_mode=lite.ConverterMode.TOCO_FLEX_ALL)
+
class EvaluateSavedModel(test.TestCase):