aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/tutorials/mnist_tflite.py
blob: 7b8bf5b5dbc8462d859c189af16c461244bfc374 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Script to evaluate accuracy of TFLite flatbuffer model on mnist dataset."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf  # pylint: disable=g-bad-import-order
from tensorflow.contrib.lite.tutorials import dataset
# Alias the TensorFlow flags module so command-line flags can be declared below.
flags = tf.app.flags

# `data_dir` is passed to dataset.test() when the test examples are loaded.
flags.DEFINE_string('data_dir', '/tmp/data_dir',
                    'Directory where data is stored.')
# `model_file` is handed to the TFLite Interpreter as its model_path. The
# default is an empty string, so callers are expected to supply a real path.
flags.DEFINE_string('model_file', '',
                    'The path to the TFLite flatbuffer model file.')


# NOTE: this rebinds `flags` from the flags module to its FLAGS object. From
# here on `flags.<name>` reads a flag value (e.g. `flags.data_dir`) and the
# DEFINE_* functions are no longer reachable through this name.
flags = flags.FLAGS


def test_image_generator():
  """Yields examples from the test dataset one at a time.

  Opens a TensorFlow session, builds a one-shot iterator over
  `dataset.test(flags.data_dir)` and repeatedly evaluates its `get_next()` op.
  The generator finishes quietly once the iterator is exhausted.

  Yields:
      The value of one `get_next()` evaluation. The caller in this file treats
      element 0 as the model input and element 1 as the expected label.
  """
  with tf.Session() as session:
    test_dataset = dataset.test(flags.data_dir)
    next_example = test_dataset.make_one_shot_iterator().get_next()
    try:
      while True:
        yield session.run(next_example)
    except tf.errors.OutOfRangeError:
      # Raised by session.run once the one-shot iterator has no more elements;
      # this is the normal end of iteration rather than a failure.
      return


def run_eval(interpreter, input_image):
  """Runs a single image through the interpreter and returns the model output.

  Only the first input tensor and the first output tensor of the model are
  used. This function does not allocate tensors itself.

  Args:
      interpreter: TFLite interpreter initialized with the model to execute.
      input_image: Image input to the model; it is reshaped to the shape
        reported for the model's first input before being fed in.

  Returns:
      The model's first output tensor with all size-1 dimensions squeezed out.
  """
  model_inputs = interpreter.get_input_details()
  model_outputs = interpreter.get_output_details()

  # Feed the image in whatever shape the model's first input declares.
  first_input = model_inputs[0]
  reshaped_image = np.reshape(input_image, first_input['shape'])
  interpreter.set_tensor(first_input['index'], reshaped_image)

  interpreter.invoke()

  raw_result = interpreter.get_tensor(model_outputs[0]['index'])
  return np.squeeze(raw_result)


def main(_):
  """Evaluates the TFLite model on the test set and prints its accuracy.

  Loads the flatbuffer named by the `model_file` flag, runs every example
  produced by `test_image_generator()` through it, and counts a prediction as
  correct when the model output equals the example's label. Accuracy is
  printed every 500 images, and once more after the last image when the total
  is not a multiple of 500, so the final figure is always reported.

  Args:
      _: Unused positional argument supplied by `tf.app.run`.
  """
  report_every = 500
  report_template = 'Accuracy after %i images: %f'

  interpreter = tf.contrib.lite.Interpreter(model_path=flags.model_file)
  interpreter.allocate_tensors()
  num_correct, total = 0, 0
  for input_data in test_image_generator():
    # Element 0 is fed to the model; element 1 is the value its output must
    # equal to count as a correct prediction.
    output = run_eval(interpreter, input_data[0])
    total += 1
    if output == input_data[1]:
      num_correct += 1
    if total % report_every == 0:
      print(report_template % (total, float(num_correct) / float(total)))

  # Report the trailing partial window. Without this the final accuracy is
  # silently dropped whenever the dataset size is not a multiple of
  # `report_every`. When total is 0 the remainder is 0, so nothing is printed
  # and no division by zero can occur.
  if total % report_every != 0:
    print(report_template % (total, float(num_correct) / float(total)))


# Script entry point: runs only when the file is executed directly.
if __name__ == '__main__':
  # Raise TensorFlow's log verbosity to INFO before handing control to
  # tf.app.run, which invokes main().
  tf.logging.set_verbosity(tf.logging.INFO)
  tf.app.run(main)