author     Nupur Garg <nupurgarg@google.com>              2018-09-25 14:46:09 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-09-25 14:50:12 -0700
commit     20c71535c5f1ed1d918d6cc6e327ffbba49ecbd6 (patch)
tree       16bdfe0e55215f64df518e3fbc1bcf9c01319c42
parent     038d15d8e2037d4a45e60e076429d67ec7d5ace1 (diff)
Add a model coverage test library for TFLite: a gen_full_model_test Bazel macro that generates one py_test per (model, conversion mode) pair, and a model_coverage_lib module that converts frozen graphs, SavedModels, and tf.keras models to TFLite and compares their outputs against TensorFlow on random input data.
PiperOrigin-RevId: 214507546
-rw-r--r--  tensorflow/contrib/lite/build_def.bzl                                        38
-rw-r--r--  tensorflow/contrib/lite/python/BUILD                                          2
-rw-r--r--  tensorflow/contrib/lite/python/convert_saved_model.py                        12
-rw-r--r--  tensorflow/contrib/lite/testing/BUILD                                        27
-rw-r--r--  tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib.py        241
-rw-r--r--  tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib_test.py   130
6 files changed, 443 insertions(+), 7 deletions(-)
diff --git a/tensorflow/contrib/lite/build_def.bzl b/tensorflow/contrib/lite/build_def.bzl
index fc4d9b4f17..7f5c6bdc2f 100644
--- a/tensorflow/contrib/lite/build_def.bzl
+++ b/tensorflow/contrib/lite/build_def.bzl
@@ -391,3 +391,41 @@ def gen_selected_ops(name, model):
(tool, model, out, tflite_path[2:]),
tools = [tool],
)
+
+def gen_full_model_test(conversion_modes, models, data, test_suite_tag):
+ """Generates Python test targets for testing TFLite models.
+
+ Args:
+ conversion_modes: List of conversion modes to test the models on.
+ models: List of models to test.
+ data: List of BUILD targets linking the data.
+ test_suite_tag: Tag identifying the model test suite.
+ """
+ options = [
+ (conversion_mode, model)
+ for model in models
+ for conversion_mode in conversion_modes
+ ]
+
+ for conversion_mode, model_name in options:
+ native.py_test(
+ name = "model_coverage_test_%s_%s" % (model_name, conversion_mode.lower()),
+ srcs = ["model_coverage_test.py"],
+ main = "model_coverage_test.py",
+ args = [
+ "--model_name=%s" % model_name,
+ "--converter_mode=%s" % conversion_mode,
+ ],
+ data = data,
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_windows",
+ "notap",
+ ] + [test_suite_tag],
+ deps = [
+ "//tensorflow/contrib/lite/testing:model_coverage_lib",
+ "//tensorflow/contrib/lite/python:lite",
+ "//tensorflow/python:client_testlib",
+ ],
+ )
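A sketch of how this macro might be invoked from a model-test BUILD file (Starlark, which shares Python syntax). The model name, data target, conversion modes, and tag below are hypothetical, and the model_coverage_test.py runner the generated targets point at is not part of this change.

load("//tensorflow/contrib/lite:build_def.bzl", "gen_full_model_test")

# Generates one py_test per (model, conversion mode) pair, e.g.
# model_coverage_test_mobilenet_v1_default and
# model_coverage_test_mobilenet_v1_toco_extended.
gen_full_model_test(
    conversion_modes = ["DEFAULT", "TOCO_EXTENDED"],  # hypothetical mode names
    models = ["mobilenet_v1"],                        # hypothetical model
    data = [":mobilenet_v1_test_files"],              # hypothetical data target
    test_suite_tag = "model_coverage",                # hypothetical suite tag
)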
diff --git a/tensorflow/contrib/lite/python/BUILD b/tensorflow/contrib/lite/python/BUILD
index 57e1290e07..916788f215 100644
--- a/tensorflow/contrib/lite/python/BUILD
+++ b/tensorflow/contrib/lite/python/BUILD
@@ -144,7 +144,7 @@ py_library(
name = "convert_saved_model",
srcs = ["convert_saved_model.py"],
srcs_version = "PY2AND3",
- visibility = ["//visibility:public"],
+ visibility = ["//tensorflow/contrib/lite:__subpackages__"],
deps = [
":convert",
"//tensorflow/contrib/saved_model:saved_model_py",
diff --git a/tensorflow/contrib/lite/python/convert_saved_model.py b/tensorflow/contrib/lite/python/convert_saved_model.py
index 1553464b9f..d18b60d0ea 100644
--- a/tensorflow/contrib/lite/python/convert_saved_model.py
+++ b/tensorflow/contrib/lite/python/convert_saved_model.py
@@ -44,7 +44,7 @@ def _log_tensor_details(tensor_info):
dtype)
-def _get_meta_graph_def(saved_model_dir, tag_set):
+def get_meta_graph_def(saved_model_dir, tag_set):
"""Validate saved_model and extract MetaGraphDef.
Args:
@@ -61,7 +61,7 @@ def _get_meta_graph_def(saved_model_dir, tag_set):
return loader.load(sess, tag_set, saved_model_dir)
-def _get_signature_def(meta_graph, signature_key):
+def get_signature_def(meta_graph, signature_key):
"""Get the signature def from meta_graph with given signature_key.
Args:
@@ -86,7 +86,7 @@ def _get_signature_def(meta_graph, signature_key):
return signature_def_map[signature_key]
-def _get_inputs_outputs(signature_def):
+def get_inputs_outputs(signature_def):
"""Get inputs and outputs from SignatureDef.
Args:
@@ -236,9 +236,9 @@ def freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
input_arrays or output_arrays are not valid.
"""
# Read SignatureDef.
- meta_graph = _get_meta_graph_def(saved_model_dir, tag_set)
- signature_def = _get_signature_def(meta_graph, signature_key)
- inputs, outputs = _get_inputs_outputs(signature_def)
+ meta_graph = get_meta_graph_def(saved_model_dir, tag_set)
+ signature_def = get_signature_def(meta_graph, signature_key)
+ inputs, outputs = get_inputs_outputs(signature_def)
# Check SavedModel for assets directory.
collection_def = meta_graph.collection_def
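Dropping the leading underscore makes these SavedModel helpers usable from the new testing library. A minimal Python sketch of the resulting call pattern, mirroring evaluate_saved_model in model_coverage_lib.py added below; the SavedModel path is hypothetical.

from tensorflow.contrib.lite.python import convert_saved_model
from tensorflow.python.client import session
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants

saved_model_dir = "/tmp/simple_savedmodel"  # hypothetical path
with session.Session().as_default():
  # Load the MetaGraphDef, pick a SignatureDef, and read its input/output tensors.
  meta_graph = convert_saved_model.get_meta_graph_def(
      saved_model_dir, set([tag_constants.SERVING]))
  signature_def = convert_saved_model.get_signature_def(
      meta_graph, signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY)
  inputs, outputs = convert_saved_model.get_inputs_outputs(signature_def)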
diff --git a/tensorflow/contrib/lite/testing/BUILD b/tensorflow/contrib/lite/testing/BUILD
index a4736bfee9..c4a2b03444 100644
--- a/tensorflow/contrib/lite/testing/BUILD
+++ b/tensorflow/contrib/lite/testing/BUILD
@@ -13,6 +13,7 @@ load("//tensorflow/contrib/lite:special_rules.bzl", "tflite_portable_test_suite"
load(
"//tensorflow:tensorflow.bzl",
"tf_cc_test",
+ "py_test",
)
[gen_zip_test(
@@ -362,4 +363,30 @@ cc_binary(
],
)
+py_binary(
+ name = "model_coverage_lib",
+ srcs = ["//tensorflow/contrib/lite/testing:model_coverage/model_coverage_lib.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//tensorflow/contrib/lite:__subpackages__"],
+ deps = [
+ "//tensorflow/contrib/lite/python:lite",
+ "//tensorflow/python:platform",
+ ],
+)
+
+py_test(
+ name = "model_coverage_lib_test",
+ srcs = ["//tensorflow/contrib/lite/testing:model_coverage/model_coverage_lib_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_windows",
+ "notap",
+ ],
+ deps = [
+ ":model_coverage_lib",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
tflite_portable_test_suite()
diff --git a/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib.py b/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib.py
new file mode 100644
index 0000000000..f8ab394c60
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib.py
@@ -0,0 +1,241 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functions to test TFLite models."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.lite.python import convert_saved_model as _convert_saved_model
+from tensorflow.contrib.lite.python import lite as _lite
+from tensorflow.core.framework import graph_pb2 as _graph_pb2
+from tensorflow.python import keras as _keras
+from tensorflow.python.client import session as _session
+from tensorflow.python.framework.importer import import_graph_def as _import_graph_def
+from tensorflow.python.lib.io import file_io as _file_io
+from tensorflow.python.saved_model import signature_constants as _signature_constants
+from tensorflow.python.saved_model import tag_constants as _tag_constants
+
+
+def _convert(converter, **kwargs):
+ """Converts the model.
+
+ Args:
+ converter: TocoConverter object.
+ **kwargs: Additional arguments to be passed into the converter. Supported
+ flags are {"converter_mode", "post_training_quant"}.
+
+ Returns:
+ The converted TFLite model in serialized format.
+ """
+ if "converter_mode" in kwargs:
+ converter.converter_mode = kwargs["converter_mode"]
+ if "post_training_quantize" in kwargs:
+ converter.post_training_quantize = kwargs["post_training_quantize"]
+ return converter.convert()
+
+
+def _generate_random_input_data(tflite_model, seed=None):
+ """Generates input data based on the input tensors in the TFLite model.
+
+ Args:
+ tflite_model: Serialized TensorFlow Lite model.
+ seed: Integer seed for the random generator. (default None)
+
+ Returns:
+ List of np.ndarray.
+ """
+ interpreter = _lite.Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+ input_details = interpreter.get_input_details()
+
+ if seed:
+ np.random.seed(seed=seed)
+ return [
+ np.array(
+ np.random.random_sample(input_tensor["shape"]),
+ dtype=input_tensor["dtype"]) for input_tensor in input_details
+ ]
+
+
+def _evaluate_tflite_model(tflite_model, input_data):
+ """Returns evaluation of input data on TFLite model.
+
+ Args:
+ tflite_model: Serialized TensorFlow Lite model.
+ input_data: List of np.ndarray.
+
+ Returns:
+ List of np.ndarray.
+ """
+ interpreter = _lite.Interpreter(model_content=tflite_model)
+ interpreter.allocate_tensors()
+
+ input_details = interpreter.get_input_details()
+ output_details = interpreter.get_output_details()
+
+ for input_tensor, tensor_data in zip(input_details, input_data):
+ interpreter.set_tensor(input_tensor["index"], tensor_data)
+
+ interpreter.invoke()
+ output_data = [
+ interpreter.get_tensor(output_tensor["index"])
+ for output_tensor in output_details
+ ]
+ return output_data
+
+
+def evaluate_frozen_graph(filename, input_arrays, output_arrays):
+ """Returns a function that evaluates the frozen graph on input data.
+
+ Args:
+ filename: Full filepath of file containing frozen GraphDef.
+ input_arrays: List of input tensors to freeze graph with.
+ output_arrays: List of output tensors to freeze graph with.
+
+ Returns:
+ Lambda function ([np.ndarray data] : [np.ndarray result]).
+ """
+ with _session.Session().as_default() as sess:
+ with _file_io.FileIO(filename, "rb") as f:
+ file_content = f.read()
+
+ graph_def = _graph_pb2.GraphDef()
+ graph_def.ParseFromString(file_content)
+ _import_graph_def(graph_def, name="")
+
+ inputs = _convert_saved_model.get_tensors_from_tensor_names(
+ sess.graph, input_arrays)
+ outputs = _convert_saved_model.get_tensors_from_tensor_names(
+ sess.graph, output_arrays)
+
+ return lambda input_data: sess.run(outputs, dict(zip(inputs, input_data)))
+
+
+def evaluate_saved_model(directory, tag_set, signature_key):
+ """Returns a function that evaluates the SavedModel on input data.
+
+ Args:
+ directory: SavedModel directory to convert.
+ tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to
+ analyze. All tags in the tag set must be present.
+ signature_key: Key identifying SignatureDef containing inputs and outputs.
+
+ Returns:
+ Lambda function ([np.ndarray data] : [np.ndarray result]).
+ """
+ with _session.Session().as_default() as sess:
+ if tag_set is None:
+ tag_set = set([_tag_constants.SERVING])
+ if signature_key is None:
+ signature_key = _signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
+
+ meta_graph = _convert_saved_model.get_meta_graph_def(directory, tag_set)
+ signature_def = _convert_saved_model.get_signature_def(
+ meta_graph, signature_key)
+ inputs, outputs = _convert_saved_model.get_inputs_outputs(signature_def)
+
+ return lambda input_data: sess.run(outputs, dict(zip(inputs, input_data)))
+
+
+def evaluate_keras_model(filename):
+ """Returns a function that evaluates the tf.keras model on input data.
+
+ Args:
+ filename: Full filepath of HDF5 file containing the tf.keras model.
+
+ Returns:
+ Lambda function ([np.ndarray data] : [np.ndarray result]).
+ """
+ keras_model = _keras.models.load_model(filename)
+ return lambda input_data: [keras_model.predict(input_data)]
+
+
+# TODO(nupurgarg): Make this function a parameter to test_frozen_graph (and
+# related functions) in order to make it easy to use different data generators.
+def compare_models_random_data(tflite_model, tf_eval_func, tolerance=5):
+ """Compares TensorFlow and TFLite models with random data.
+
+ Args:
+ tflite_model: Serialized TensorFlow Lite model.
+ tf_eval_func: Lambda function that takes in input data and outputs the
+ results of the TensorFlow model ([np.ndarray data] : [np.ndarray result]).
+ tolerance: Decimal place to check accuracy to.
+ """
+ input_data = _generate_random_input_data(tflite_model)
+ tf_results = tf_eval_func(input_data)
+ tflite_results = _evaluate_tflite_model(tflite_model, input_data)
+ for tf_result, tflite_result in zip(tf_results, tflite_results):
+ np.testing.assert_almost_equal(tf_result, tflite_result, tolerance)
+
+
+def test_frozen_graph(filename, input_arrays, output_arrays, **kwargs):
+ """Validates the TensorFlow frozen graph converts to a TFLite model.
+
+ Converts the TensorFlow frozen graph to TFLite and checks the accuracy of the
+ model on random data.
+
+ Args:
+ filename: Full filepath of file containing frozen GraphDef.
+ input_arrays: List of input tensors to freeze graph with.
+ output_arrays: List of output tensors to freeze graph with.
+ **kwargs: Additional arguments to be passed into the converter.
+ """
+ converter = _lite.TocoConverter.from_frozen_graph(filename, input_arrays,
+ output_arrays)
+ tflite_model = _convert(converter, **kwargs)
+
+ tf_eval_func = evaluate_frozen_graph(filename, input_arrays, output_arrays)
+ compare_models_random_data(tflite_model, tf_eval_func)
+
+
+def test_saved_model(directory, tag_set=None, signature_key=None, **kwargs):
+ """Validates the TensorFlow SavedModel converts to a TFLite model.
+
+ Converts the TensorFlow SavedModel to TFLite and checks the accuracy of the
+ model on random data.
+
+ Args:
+ directory: SavedModel directory to convert.
+ tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to
+ analyze. All tags in the tag set must be present.
+ signature_key: Key identifying SignatureDef containing inputs and outputs.
+ **kwargs: Additional arguments to be passed into the converter.
+ """
+ converter = _lite.TocoConverter.from_saved_model(directory, tag_set,
+ signature_key)
+ tflite_model = _convert(converter, **kwargs)
+
+ tf_eval_func = evaluate_saved_model(directory, tag_set, signature_key)
+ compare_models_random_data(tflite_model, tf_eval_func)
+
+
+def test_keras_model(filename, **kwargs):
+ """Validates the tf.keras model converts to a TFLite model.
+
+ Converts the tf.keras model to TFLite and checks the accuracy of the model on
+ random data.
+
+ Args:
+ filename: Full filepath of HDF5 file containing the tf.keras model.
+ **kwargs: Additional arguments to be passed into the converter.
+ """
+ converter = _lite.TocoConverter.from_keras_model_file(filename)
+ tflite_model = _convert(converter, **kwargs)
+
+ tf_eval_func = evaluate_keras_model(filename)
+ compare_models_random_data(tflite_model, tf_eval_func)
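A minimal usage sketch of the library's public entry points, based on the tests added below; the file paths are hypothetical. Each call converts the model to TFLite, forwarding any converter flags, and asserts that the TFLite and TensorFlow outputs agree on random input data.

from tensorflow.contrib.lite.python import lite
from tensorflow.contrib.lite.testing.model_coverage import model_coverage_lib as model_coverage

# Frozen GraphDef: the input/output tensor names must match the graph.
model_coverage.test_frozen_graph('/tmp/model.pb', ['Placeholder'], ['add'])

# SavedModel: defaults to the SERVING tag set and the default signature key.
model_coverage.test_saved_model('/tmp/simple_savedmodel')

# tf.keras HDF5 model, exercising the converter flags handled by _convert.
model_coverage.test_keras_model('/tmp/model.h5', post_training_quantize=True)
model_coverage.test_keras_model(
    '/tmp/model.h5', converter_mode=lite.ConverterMode.TOCO_EXTENDED)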
diff --git a/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib_test.py b/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib_test.py
new file mode 100644
index 0000000000..5f3355e734
--- /dev/null
+++ b/tensorflow/contrib/lite/testing/model_coverage/model_coverage_lib_test.py
@@ -0,0 +1,130 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for model_coverage_lib.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import tempfile
+
+from tensorflow.contrib.lite.python import lite
+from tensorflow.contrib.lite.testing.model_coverage import model_coverage_lib as model_coverage
+from tensorflow.python import keras
+from tensorflow.python.client import session
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+from tensorflow.python.saved_model import saved_model
+from tensorflow.python.training.training_util import write_graph
+
+
+class EvaluateFrozenGraph(test.TestCase):
+
+ def _saveFrozenGraph(self, sess):
+ graph_def_file = os.path.join(self.get_temp_dir(), 'model.pb')
+ write_graph(sess.graph_def, '', graph_def_file, False)
+ return graph_def_file
+
+ def testFloat(self):
+ with session.Session().as_default() as sess:
+ in_tensor = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32)
+ _ = in_tensor + in_tensor
+ filename = self._saveFrozenGraph(sess)
+
+ model_coverage.test_frozen_graph(filename, ['Placeholder'], ['add'])
+
+ def testMultipleOutputs(self):
+ with session.Session().as_default() as sess:
+ in_tensor_1 = array_ops.placeholder(
+ shape=[1, 16], dtype=dtypes.float32, name='inputA')
+ in_tensor_2 = array_ops.placeholder(
+ shape=[1, 16], dtype=dtypes.float32, name='inputB')
+
+ weight = constant_op.constant(-1.0, shape=[16, 16])
+ bias = constant_op.constant(-1.0, shape=[16])
+ layer = math_ops.matmul(in_tensor_1, weight) + bias
+ _ = math_ops.reduce_mean(math_ops.square(layer - in_tensor_2))
+ filename = self._saveFrozenGraph(sess)
+
+ model_coverage.test_frozen_graph(filename, ['inputA', 'inputB'],
+ ['add', 'Mean'])
+
+
+class EvaluateSavedModel(test.TestCase):
+
+ def testFloat(self):
+ saved_model_dir = os.path.join(self.get_temp_dir(), 'simple_savedmodel')
+ with session.Session().as_default() as sess:
+ in_tensor_1 = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputB')
+ in_tensor_2 = array_ops.placeholder(
+ shape=[1, 16, 16, 3], dtype=dtypes.float32, name='inputA')
+ out_tensor = in_tensor_1 + in_tensor_2
+
+ inputs = {'x': in_tensor_1, 'y': in_tensor_2}
+ outputs = {'z': out_tensor}
+ saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
+ model_coverage.test_saved_model(saved_model_dir)
+
+
+class EvaluateKerasModel(test.TestCase):
+
+ def _getSingleInputKerasModel(self):
+ """Returns single input Sequential tf.keras model."""
+ keras.backend.clear_session()
+
+ xs = [-1, 0, 1, 2, 3, 4]
+ ys = [-3, -1, 1, 3, 5, 7]
+
+ model = keras.Sequential([keras.layers.Dense(units=1, input_shape=[1])])
+ model.compile(optimizer='sgd', loss='mean_squared_error')
+ model.train_on_batch(xs, ys)
+ return model
+
+ def _saveKerasModel(self, model):
+ try:
+ fd, keras_file = tempfile.mkstemp('.h5')
+ keras.models.save_model(model, keras_file)
+ finally:
+ os.close(fd)
+ return keras_file
+
+ def testFloat(self):
+ model = self._getSingleInputKerasModel()
+ keras_file = self._saveKerasModel(model)
+
+ model_coverage.test_keras_model(keras_file)
+
+ def testPostTrainingQuantize(self):
+ model = self._getSingleInputKerasModel()
+ keras_file = self._saveKerasModel(model)
+
+ model_coverage.test_keras_model(keras_file, post_training_quantize=True)
+
+ def testConverterMode(self):
+ model = self._getSingleInputKerasModel()
+ keras_file = self._saveKerasModel(model)
+
+ model_coverage.test_keras_model(
+ keras_file, converter_mode=lite.ConverterMode.TOCO_EXTENDED)
+
+
+if __name__ == '__main__':
+ test.main()