diff options
Diffstat (limited to 'tensorflow/contrib/lite/tutorials/mnist_tflite.py')
-rw-r--r-- | tensorflow/contrib/lite/tutorials/mnist_tflite.py | 87 |
1 files changed, 87 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/tutorials/mnist_tflite.py b/tensorflow/contrib/lite/tutorials/mnist_tflite.py new file mode 100644 index 0000000000..7b8bf5b5db --- /dev/null +++ b/tensorflow/contrib/lite/tutorials/mnist_tflite.py @@ -0,0 +1,87 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Script to evaluate accuracy of TFLite flatbuffer model on mnist dataset.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import numpy as np +import tensorflow as tf # pylint: disable=g-bad-import-order +from tensorflow.contrib.lite.tutorials import dataset +flags = tf.app.flags + +flags.DEFINE_string('data_dir', '/tmp/data_dir', + 'Directory where data is stored.') +flags.DEFINE_string('model_file', '', + 'The path to the TFLite flatbuffer model file.') + + +flags = flags.FLAGS + + +def test_image_generator(): + # Generates an iterator over images + with tf.Session() as sess: + input_data = dataset.test( + flags.data_dir).make_one_shot_iterator().get_next() + try: + while True: + yield sess.run(input_data) + except tf.errors.OutOfRangeError: + pass + + +def run_eval(interpreter, input_image): + """Performs evaluation for input image over specified model. + + Args: + interpreter: TFLite interpreter initialized with model to execute. + input_image: Image input to the model. + + Returns: + output: output tensor of model being executed. + """ + + # Get input and output tensors. + input_details = interpreter.get_input_details() + output_details = interpreter.get_output_details() + + # Test model on the input images. + input_image = np.reshape(input_image, input_details[0]['shape']) + interpreter.set_tensor(input_details[0]['index'], input_image) + + interpreter.invoke() + output_data = interpreter.get_tensor(output_details[0]['index']) + output = np.squeeze(output_data) + return output + + +def main(_): + interpreter = tf.contrib.lite.Interpreter(model_path=flags.model_file) + interpreter.allocate_tensors() + num_correct, total = 0, 0 + for input_data in test_image_generator(): + output = run_eval(interpreter, input_data[0]) + total += 1 + if output == input_data[1]: + num_correct += 1 + if total % 500 == 0: + print('Accuracy after %i images: %f' % + (total, float(num_correct) / float(total))) + + +if __name__ == '__main__': + tf.logging.set_verbosity(tf.logging.INFO) + tf.app.run(main) |