aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/tutorials/mnist_tflite.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/lite/tutorials/mnist_tflite.py')
-rw-r--r--tensorflow/contrib/lite/tutorials/mnist_tflite.py87
1 files changed, 87 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/tutorials/mnist_tflite.py b/tensorflow/contrib/lite/tutorials/mnist_tflite.py
new file mode 100644
index 0000000000..7b8bf5b5db
--- /dev/null
+++ b/tensorflow/contrib/lite/tutorials/mnist_tflite.py
@@ -0,0 +1,87 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Script to evaluate accuracy of TFLite flatbuffer model on mnist dataset."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import numpy as np
+import tensorflow as tf # pylint: disable=g-bad-import-order
+from tensorflow.contrib.lite.tutorials import dataset
+flags = tf.app.flags
+
+flags.DEFINE_string('data_dir', '/tmp/data_dir',
+ 'Directory where data is stored.')
+flags.DEFINE_string('model_file', '',
+ 'The path to the TFLite flatbuffer model file.')
+
+
+flags = flags.FLAGS
+
+
+def test_image_generator():
+ # Generates an iterator over images
+ with tf.Session() as sess:
+ input_data = dataset.test(
+ flags.data_dir).make_one_shot_iterator().get_next()
+ try:
+ while True:
+ yield sess.run(input_data)
+ except tf.errors.OutOfRangeError:
+ pass
+
+
+def run_eval(interpreter, input_image):
+ """Performs evaluation for input image over specified model.
+
+ Args:
+ interpreter: TFLite interpreter initialized with model to execute.
+ input_image: Image input to the model.
+
+ Returns:
+ output: output tensor of model being executed.
+ """
+
+ # Get input and output tensors.
+ input_details = interpreter.get_input_details()
+ output_details = interpreter.get_output_details()
+
+ # Test model on the input images.
+ input_image = np.reshape(input_image, input_details[0]['shape'])
+ interpreter.set_tensor(input_details[0]['index'], input_image)
+
+ interpreter.invoke()
+ output_data = interpreter.get_tensor(output_details[0]['index'])
+ output = np.squeeze(output_data)
+ return output
+
+
+def main(_):
+ interpreter = tf.contrib.lite.Interpreter(model_path=flags.model_file)
+ interpreter.allocate_tensors()
+ num_correct, total = 0, 0
+ for input_data in test_image_generator():
+ output = run_eval(interpreter, input_data[0])
+ total += 1
+ if output == input_data[1]:
+ num_correct += 1
+ if total % 500 == 0:
+ print('Accuracy after %i images: %f' %
+ (total, float(num_correct) / float(total)))
+
+
+if __name__ == '__main__':
+ tf.logging.set_verbosity(tf.logging.INFO)
+ tf.app.run(main)