diff options
author | 2018-10-03 10:51:17 -0700 | |
---|---|---|
committer | 2018-10-03 10:55:31 -0700 | |
commit | 560624bff65b7b502da2c52f9b250d9181c4a3f7 (patch) | |
tree | 29d3aab2396c231223952515333ce2f2c08f8e30 /tensorflow/contrib/lite/testing | |
parent | af1458a9c1a3bc8d49a1e55386950b4941ab1815 (diff) |
Internal change.
PiperOrigin-RevId: 215589009
Diffstat (limited to 'tensorflow/contrib/lite/testing')
-rw-r--r-- | tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib.py | 81 | ||||
-rw-r--r-- | tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib_test.py | 38 |
2 files changed, 111 insertions, 8 deletions
diff --git a/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib.py b/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib.py index 5ca57d083d..72029ed03c 100644 --- a/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib.py +++ b/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib.py @@ -35,9 +35,9 @@ def _convert(converter, **kwargs): """Converts the model. Args: - converter: TocoConverter object. + converter: TFLiteConverter object. **kwargs: Additional arguments to be passed into the converter. Supported - flags are {"converter_mode", "post_training_quant"}. + flags are {"converter_mode", "post_training_quantize"}. Returns: The converted TFLite model in serialized format. @@ -174,7 +174,7 @@ def compare_models_random_data(tflite_model, tf_eval_func, tolerance=5): tflite_model: Serialized TensorFlow Lite model. tf_eval_func: Lambda function that takes in input data and outputs the results of the TensorFlow model ([np.ndarray data] : [np.ndarray result]). - tolerance: Decimal place to check accuracy to. + tolerance: Decimal place to check accuracy to. (default 5) """ input_data = _generate_random_input_data(tflite_model) tf_results = tf_eval_func(input_data) @@ -183,6 +183,71 @@ def compare_models_random_data(tflite_model, tf_eval_func, tolerance=5): np.testing.assert_almost_equal(tf_result, tflite_result, tolerance) +def test_frozen_graph_quant(filename, + input_arrays, + output_arrays, + input_shapes=None, + **kwargs): + """Sanity check to validate post quantize flag alters the graph. + + This test does not check correctness of the converted model. It converts the + TensorFlow frozen graph to TFLite with and without the post_training_quantized + flag. It ensures some tensors have different types between the float and + quantized models in the case of an all TFLite model or mix-and-match model. + It ensures tensor types do not change in the case of an all Flex model. + + Args: + filename: Full filepath of file containing frozen GraphDef. + input_arrays: List of input tensors to freeze graph with. + output_arrays: List of output tensors to freeze graph with. + input_shapes: Dict of strings representing input tensor names to list of + integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}). + Automatically determined when input shapes is None (e.g., {"foo" : None}). + (default None) + **kwargs: Additional arguments to be passed into the converter. + + Raises: + ValueError: post_training_quantize flag doesn't act as intended. + """ + # Convert and load the float model. + converter = _lite.TFLiteConverter.from_frozen_graph( + filename, input_arrays, output_arrays, input_shapes) + tflite_model_float = _convert(converter, **kwargs) + + interpreter_float = _lite.Interpreter(model_content=tflite_model_float) + interpreter_float.allocate_tensors() + float_tensors = interpreter_float.get_tensor_details() + + # Convert and load the quantized model. + converter = _lite.TFLiteConverter.from_frozen_graph(filename, input_arrays, + output_arrays) + tflite_model_quant = _convert( + converter, post_training_quantize=True, **kwargs) + + interpreter_quant = _lite.Interpreter(model_content=tflite_model_quant) + interpreter_quant.allocate_tensors() + quant_tensors = interpreter_quant.get_tensor_details() + quant_tensors_map = { + tensor_detail["name"]: tensor_detail for tensor_detail in quant_tensors + } + + # Check if weights are of different types in the float and quantized models. + num_tensors_float = len(float_tensors) + num_tensors_same_dtypes = sum( + float_tensor["dtype"] == quant_tensors_map[float_tensor["name"]]["dtype"] + for float_tensor in float_tensors) + has_quant_tensor = num_tensors_float != num_tensors_same_dtypes + + if ("converter_mode" in kwargs and + kwargs["converter_mode"] == _lite.ConverterMode.TOCO_FLEX_ALL): + if has_quant_tensor: + raise ValueError("--post_training_quantize flag unexpectedly altered the " + "full Flex mode graph.") + elif not has_quant_tensor: + raise ValueError("--post_training_quantize flag was unable to quantize the " + "graph as expected in TFLite and mix-and-match mode.") + + def test_frozen_graph(filename, input_arrays, output_arrays, @@ -203,8 +268,8 @@ def test_frozen_graph(filename, (default None) **kwargs: Additional arguments to be passed into the converter. """ - converter = _lite.TocoConverter.from_frozen_graph(filename, input_arrays, - output_arrays, input_shapes) + converter = _lite.TFLiteConverter.from_frozen_graph( + filename, input_arrays, output_arrays, input_shapes) tflite_model = _convert(converter, **kwargs) tf_eval_func = evaluate_frozen_graph(filename, input_arrays, output_arrays) @@ -224,8 +289,8 @@ def test_saved_model(directory, tag_set=None, signature_key=None, **kwargs): signature_key: Key identifying SignatureDef containing inputs and outputs. **kwargs: Additional arguments to be passed into the converter. """ - converter = _lite.TocoConverter.from_saved_model(directory, tag_set, - signature_key) + converter = _lite.TFLiteConverter.from_saved_model(directory, tag_set, + signature_key) tflite_model = _convert(converter, **kwargs) tf_eval_func = evaluate_saved_model(directory, tag_set, signature_key) @@ -242,7 +307,7 @@ def test_keras_model(filename, **kwargs): filename: Full filepath of HDF5 file containing the tf.keras model. **kwargs: Additional arguments to be passed into the converter. """ - converter = _lite.TocoConverter.from_keras_model_file(filename) + converter = _lite.TFLiteConverter.from_keras_model_file(filename) tflite_model = _convert(converter, **kwargs) tf_eval_func = evaluate_keras_model(filename) diff --git a/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib_test.py b/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib_test.py index 1498f86c6f..e07202b1a6 100644 --- a/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib_test.py +++ b/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib_test.py @@ -20,6 +20,7 @@ from __future__ import print_function import os import tempfile +import numpy as np from tensorflow.contrib.lite.python import lite from tensorflow.contrib.lite.testing.model_coverage import model_coverage_lib as model_coverage @@ -66,6 +67,43 @@ class EvaluateFrozenGraph(test.TestCase): model_coverage.test_frozen_graph(filename, ['inputA', 'inputB'], ['add', 'Mean']) + def _getQuantizedModel(self): + np.random.seed(0) + with session.Session().as_default() as sess: + # The tensor needs to have more than 1024 elements for quantize_weights to + # kick in. Thus, the [33, 33] shape. + in_tensor_1 = array_ops.placeholder( + shape=[33, 33], dtype=dtypes.float32, name='inputA') + in_tensor_2 = constant_op.constant( + np.random.uniform(low=-10., high=10., size=(33, 33)), + shape=[33, 33], + dtype=dtypes.float32, + name='inputB') + _ = math_ops.matmul(in_tensor_1, in_tensor_2, name='output') + + filename = self._saveFrozenGraph(sess) + return filename + + def testQuantized(self): + filename = self._getQuantizedModel() + model_coverage.test_frozen_graph_quant(filename, ['inputA', 'inputB'], + ['output']) + + def testQuantizedInputShapes(self): + filename = self._getQuantizedModel() + model_coverage.test_frozen_graph_quant( + filename, ['inputA', 'inputB'], ['output'], + input_shapes={ + 'inputA': [33, 33], + 'inputB': [33, 33], + }) + + def testQuantizedFlexAll(self): + filename = self._getQuantizedModel() + model_coverage.test_frozen_graph_quant( + filename, ['inputA', 'inputB'], ['output'], + converter_mode=lite.ConverterMode.TOCO_FLEX_ALL) + class EvaluateSavedModel(test.TestCase): |