aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/testing
diff options
context:
space:
mode:
authorGravatar Nupur Garg <nupurgarg@google.com>2018-09-27 11:11:34 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-27 11:16:47 -0700
commitdb3e59a545f06780583ad839da9e19d847dfd392 (patch)
tree04e32a287af30dbc4b3530115e633407bd756286 /tensorflow/contrib/lite/testing
parent50b94fa1d50a916eaf7a5a46d93260e9b0f93554 (diff)
Internal change.
PiperOrigin-RevId: 214804105
Diffstat (limited to 'tensorflow/contrib/lite/testing')
-rw-r--r--tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib.py12
1 files changed, 10 insertions, 2 deletions
diff --git a/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib.py b/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib.py
index f8ab394c60..5ca57d083d 100644
--- a/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib.py
+++ b/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib.py
@@ -183,7 +183,11 @@ def compare_models_random_data(tflite_model, tf_eval_func, tolerance=5):
np.testing.assert_almost_equal(tf_result, tflite_result, tolerance)
-def test_frozen_graph(filename, input_arrays, output_arrays, **kwargs):
+def test_frozen_graph(filename,
+ input_arrays,
+ output_arrays,
+ input_shapes=None,
+ **kwargs):
"""Validates the TensorFlow frozen graph converts to a TFLite model.
Converts the TensorFlow frozen graph to TFLite and checks the accuracy of the
@@ -193,10 +197,14 @@ def test_frozen_graph(filename, input_arrays, output_arrays, **kwargs):
filename: Full filepath of file containing frozen GraphDef.
input_arrays: List of input tensors to freeze graph with.
output_arrays: List of output tensors to freeze graph with.
+ input_shapes: Dict of strings representing input tensor names to list of
+ integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
+ Automatically determined when input shapes is None (e.g., {"foo" : None}).
+ (default None)
**kwargs: Additional arguments to be passed into the converter.
"""
converter = _lite.TocoConverter.from_frozen_graph(filename, input_arrays,
- output_arrays)
+ output_arrays, input_shapes)
tflite_model = _convert(converter, **kwargs)
tf_eval_func = evaluate_frozen_graph(filename, input_arrays, output_arrays)