From 3e8af7ea6b70104b05be22797451d0218c9e5262 Mon Sep 17 00:00:00 2001 From: Nupur Garg Date: Tue, 9 Oct 2018 10:58:03 -0700 Subject: Internal change. PiperOrigin-RevId: 216385202 --- .../contrib/lite/testing/model_coverage/model_coverage_lib.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) (limited to 'tensorflow/contrib') diff --git a/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib.py b/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib.py index 72029ed03c..ab29f71138 100644 --- a/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib.py +++ b/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib.py @@ -297,7 +297,7 @@ def test_saved_model(directory, tag_set=None, signature_key=None, **kwargs): compare_models_random_data(tflite_model, tf_eval_func) -def test_keras_model(filename, **kwargs): +def test_keras_model(filename, input_arrays=None, input_shapes=None, **kwargs): """Validates the tf.keras model converts to a TFLite model. Converts the tf.keras model to TFLite and checks the accuracy of the model on @@ -305,9 +305,15 @@ def test_keras_model(filename, **kwargs): Args: filename: Full filepath of HDF5 file containing the tf.keras model. + input_arrays: List of input tensors to freeze graph with. + input_shapes: Dict of strings representing input tensor names to list of + integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}). + Automatically determined when input shapes is None (e.g., {"foo" : None}). + (default None) **kwargs: Additional arguments to be passed into the converter. """ - converter = _lite.TFLiteConverter.from_keras_model_file(filename) + converter = _lite.TFLiteConverter.from_keras_model_file( + filename, input_arrays=input_arrays, input_shapes=input_shapes) tflite_model = _convert(converter, **kwargs) tf_eval_func = evaluate_keras_model(filename) -- cgit v1.2.3