aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib.py')
-rw-r--r--tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib.py12
1 files changed, 10 insertions, 2 deletions
diff --git a/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib.py b/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib.py
index f8ab394c60..5ca57d083d 100644
--- a/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib.py
+++ b/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib.py
@@ -183,7 +183,11 @@ def compare_models_random_data(tflite_model, tf_eval_func, tolerance=5):
np.testing.assert_almost_equal(tf_result, tflite_result, tolerance)
-def test_frozen_graph(filename, input_arrays, output_arrays, **kwargs):
+def test_frozen_graph(filename,
+ input_arrays,
+ output_arrays,
+ input_shapes=None,
+ **kwargs):
"""Validates the TensorFlow frozen graph converts to a TFLite model.
Converts the TensorFlow frozen graph to TFLite and checks the accuracy of the
@@ -193,10 +197,14 @@ def test_frozen_graph(filename, input_arrays, output_arrays, **kwargs):
filename: Full filepath of file containing frozen GraphDef.
input_arrays: List of input tensors to freeze graph with.
output_arrays: List of output tensors to freeze graph with.
+ input_shapes: Dict of strings representing input tensor names to list of
+ integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
+ Automatically determined when input shapes is None (e.g., {"foo" : None}).
+ (default None)
**kwargs: Additional arguments to be passed into the converter.
"""
converter = _lite.TocoConverter.from_frozen_graph(filename, input_arrays,
- output_arrays)
+ output_arrays, input_shapes)
tflite_model = _convert(converter, **kwargs)
tf_eval_func = evaluate_frozen_graph(filename, input_arrays, output_arrays)