diff options
author | 2018-09-27 11:11:34 -0700 | |
---|---|---|
committer | 2018-09-27 11:16:47 -0700 | |
commit | db3e59a545f06780583ad839da9e19d847dfd392 (patch) | |
tree | 04e32a287af30dbc4b3530115e633407bd756286 /tensorflow/contrib/lite/testing | |
parent | 50b94fa1d50a916eaf7a5a46d93260e9b0f93554 (diff) |
Internal change.
PiperOrigin-RevId: 214804105
Diffstat (limited to 'tensorflow/contrib/lite/testing')
-rw-r--r-- | tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib.py | 12 |
1 files changed, 10 insertions, 2 deletions
diff --git a/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib.py b/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib.py index f8ab394c60..5ca57d083d 100644 --- a/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib.py +++ b/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib.py @@ -183,7 +183,11 @@ def compare_models_random_data(tflite_model, tf_eval_func, tolerance=5): np.testing.assert_almost_equal(tf_result, tflite_result, tolerance) -def test_frozen_graph(filename, input_arrays, output_arrays, **kwargs): +def test_frozen_graph(filename, + input_arrays, + output_arrays, + input_shapes=None, + **kwargs): """Validates the TensorFlow frozen graph converts to a TFLite model. Converts the TensorFlow frozen graph to TFLite and checks the accuracy of the @@ -193,10 +197,14 @@ def test_frozen_graph(filename, input_arrays, output_arrays, **kwargs): filename: Full filepath of file containing frozen GraphDef. input_arrays: List of input tensors to freeze graph with. output_arrays: List of output tensors to freeze graph with. + input_shapes: Dict of strings representing input tensor names to list of + integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}). + Automatically determined when input shapes is None (e.g., {"foo" : None}). + (default None) **kwargs: Additional arguments to be passed into the converter. """ converter = _lite.TocoConverter.from_frozen_graph(filename, input_arrays, - output_arrays) + output_arrays, input_shapes) tflite_model = _convert(converter, **kwargs) tf_eval_func = evaluate_frozen_graph(filename, input_arrays, output_arrays) |