aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Nupur Garg <nupurgarg@google.com>2018-10-09 10:58:03 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-09 11:02:26 -0700
commit3e8af7ea6b70104b05be22797451d0218c9e5262 (patch)
treeb11d127d46ffc1847e8a94ab500ca11f0b4624b1
parentaa8f428a9310b3fd8371bddf612e480b27618b2e (diff)
Internal change.
PiperOrigin-RevId: 216385202
-rw-r--r--tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib.py10
1 files changed, 8 insertions, 2 deletions
diff --git a/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib.py b/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib.py
index 72029ed03c..ab29f71138 100644
--- a/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib.py
+++ b/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib.py
@@ -297,7 +297,7 @@ def test_saved_model(directory, tag_set=None, signature_key=None, **kwargs):
compare_models_random_data(tflite_model, tf_eval_func)
-def test_keras_model(filename, **kwargs):
+def test_keras_model(filename, input_arrays=None, input_shapes=None, **kwargs):
"""Validates the tf.keras model converts to a TFLite model.
Converts the tf.keras model to TFLite and checks the accuracy of the model on
@@ -305,9 +305,15 @@ def test_keras_model(filename, **kwargs):
Args:
filename: Full filepath of HDF5 file containing the tf.keras model.
+ input_arrays: List of input tensors to freeze graph with.
+ input_shapes: Dict of strings representing input tensor names to list of
+ integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
+ Automatically determined when input shapes is None (e.g., {"foo" : None}).
+ (default None)
**kwargs: Additional arguments to be passed into the converter.
"""
- converter = _lite.TFLiteConverter.from_keras_model_file(filename)
+ converter = _lite.TFLiteConverter.from_keras_model_file(
+ filename, input_arrays=input_arrays, input_shapes=input_shapes)
tflite_model = _convert(converter, **kwargs)
tf_eval_func = evaluate_keras_model(filename)