diff options
author | 2018-09-06 15:39:41 -0700 | |
---|---|---|
committer | 2018-09-06 15:44:05 -0700 | |
commit | 5aca2604651e3532aa5304b6aabf51f630e62084 (patch) | |
tree | 39afbeae2948acdf20ce09972f10dac14538e064 /tensorflow/contrib/lite/tutorials | |
parent | 3142d94dd2258b4b04ac9857341a6736ed1f4442 (diff) |
Python example for tutorial on post training quantization for mnist.
PiperOrigin-RevId: 211882134
Diffstat (limited to 'tensorflow/contrib/lite/tutorials')
-rw-r--r-- | tensorflow/contrib/lite/tutorials/BUILD | 20 | ||||
-rw-r--r-- | tensorflow/contrib/lite/tutorials/dataset.py | 122 | ||||
-rw-r--r-- | tensorflow/contrib/lite/tutorials/mnist_tflite.py | 87 |
3 files changed, 229 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/tutorials/BUILD b/tensorflow/contrib/lite/tutorials/BUILD new file mode 100644 index 0000000000..67ff1ea124 --- /dev/null +++ b/tensorflow/contrib/lite/tutorials/BUILD @@ -0,0 +1,20 @@ +# Example Estimator model + +package( + default_visibility = ["//visibility:public"], +) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +py_binary( + name = "mnist_tflite", + srcs = [ + "dataset.py", + "mnist_tflite.py", + ], + deps = [ + "//tensorflow:tensorflow_py", + ], +) diff --git a/tensorflow/contrib/lite/tutorials/dataset.py b/tensorflow/contrib/lite/tutorials/dataset.py new file mode 100644 index 0000000000..ba49dfcc9b --- /dev/null +++ b/tensorflow/contrib/lite/tutorials/dataset.py @@ -0,0 +1,122 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""tf.data.Dataset interface to the MNIST dataset. + + This is cloned from + https://github.com/tensorflow/models/blob/master/official/mnist/dataset.py +""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import gzip +import os +import shutil +import tempfile + +import numpy as np +from six.moves import urllib +import tensorflow as tf + + +def read32(bytestream): + """Read 4 bytes from bytestream as an unsigned 32-bit integer.""" + dt = np.dtype(np.uint32).newbyteorder('>') + return np.frombuffer(bytestream.read(4), dtype=dt)[0] + + +def check_image_file_header(filename): + """Validate that filename corresponds to images for the MNIST dataset.""" + with tf.gfile.Open(filename, 'rb') as f: + magic = read32(f) + read32(f) # num_images, unused + rows = read32(f) + cols = read32(f) + if magic != 2051: + raise ValueError('Invalid magic number %d in MNIST file %s' % (magic, + f.name)) + if rows != 28 or cols != 28: + raise ValueError( + 'Invalid MNIST file %s: Expected 28x28 images, found %dx%d' % + (f.name, rows, cols)) + + +def check_labels_file_header(filename): + """Validate that filename corresponds to labels for the MNIST dataset.""" + with tf.gfile.Open(filename, 'rb') as f: + magic = read32(f) + read32(f) # num_items, unused + if magic != 2049: + raise ValueError('Invalid magic number %d in MNIST file %s' % (magic, + f.name)) + + +def download(directory, filename): + """Download (and unzip) a file from the MNIST dataset if not already done.""" + filepath = os.path.join(directory, filename) + if tf.gfile.Exists(filepath): + return filepath + if not tf.gfile.Exists(directory): + tf.gfile.MakeDirs(directory) + # CVDF mirror of http://yann.lecun.com/exdb/mnist/ + url = 'https://storage.googleapis.com/cvdf-datasets/mnist/' + filename + '.gz' + _, zipped_filepath = tempfile.mkstemp(suffix='.gz') + print('Downloading %s to %s' % (url, zipped_filepath)) + urllib.request.urlretrieve(url, zipped_filepath) + with gzip.open(zipped_filepath, 'rb') as f_in, \ + tf.gfile.Open(filepath, 'wb') as f_out: + shutil.copyfileobj(f_in, f_out) + os.remove(zipped_filepath) + return filepath + + +def dataset(directory, images_file, labels_file): + """Download and parse MNIST dataset.""" + + images_file = download(directory, images_file) + labels_file = download(directory, labels_file) + + check_image_file_header(images_file) + check_labels_file_header(labels_file) + + def decode_image(image): + # Normalize from [0, 255] to [0.0, 1.0] + image = tf.decode_raw(image, tf.uint8) + image = tf.cast(image, tf.float32) + image = tf.reshape(image, [784]) + return image / 255.0 + + def decode_label(label): + label = tf.decode_raw(label, tf.uint8) # tf.string -> [tf.uint8] + label = tf.reshape(label, []) # label is a scalar + return tf.to_int32(label) + + images = tf.data.FixedLengthRecordDataset( + images_file, 28 * 28, header_bytes=16).map(decode_image) + labels = tf.data.FixedLengthRecordDataset( + labels_file, 1, header_bytes=8).map(decode_label) + return tf.data.Dataset.zip((images, labels)) + + +def train(directory): + """tf.data.Dataset object for MNIST training data.""" + return dataset(directory, 'train-images-idx3-ubyte', + 'train-labels-idx1-ubyte') + + +def test(directory): + """tf.data.Dataset object for MNIST test data.""" + return dataset(directory, 't10k-images-idx3-ubyte', 't10k-labels-idx1-ubyte') diff --git a/tensorflow/contrib/lite/tutorials/mnist_tflite.py b/tensorflow/contrib/lite/tutorials/mnist_tflite.py new file mode 100644 index 0000000000..7b8bf5b5db --- /dev/null +++ b/tensorflow/contrib/lite/tutorials/mnist_tflite.py @@ -0,0 +1,87 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Script to evaluate accuracy of TFLite flatbuffer model on mnist dataset.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import numpy as np +import tensorflow as tf # pylint: disable=g-bad-import-order +from tensorflow.contrib.lite.tutorials import dataset +flags = tf.app.flags + +flags.DEFINE_string('data_dir', '/tmp/data_dir', + 'Directory where data is stored.') +flags.DEFINE_string('model_file', '', + 'The path to the TFLite flatbuffer model file.') + + +flags = flags.FLAGS + + +def test_image_generator(): + # Generates an iterator over images + with tf.Session() as sess: + input_data = dataset.test( + flags.data_dir).make_one_shot_iterator().get_next() + try: + while True: + yield sess.run(input_data) + except tf.errors.OutOfRangeError: + pass + + +def run_eval(interpreter, input_image): + """Performs evaluation for input image over specified model. + + Args: + interpreter: TFLite interpreter initialized with model to execute. + input_image: Image input to the model. + + Returns: + output: output tensor of model being executed. + """ + + # Get input and output tensors. + input_details = interpreter.get_input_details() + output_details = interpreter.get_output_details() + + # Test model on the input images. + input_image = np.reshape(input_image, input_details[0]['shape']) + interpreter.set_tensor(input_details[0]['index'], input_image) + + interpreter.invoke() + output_data = interpreter.get_tensor(output_details[0]['index']) + output = np.squeeze(output_data) + return output + + +def main(_): + interpreter = tf.contrib.lite.Interpreter(model_path=flags.model_file) + interpreter.allocate_tensors() + num_correct, total = 0, 0 + for input_data in test_image_generator(): + output = run_eval(interpreter, input_data[0]) + total += 1 + if output == input_data[1]: + num_correct += 1 + if total % 500 == 0: + print('Accuracy after %i images: %f' % + (total, float(num_correct) / float(total))) + + +if __name__ == '__main__': + tf.logging.set_verbosity(tf.logging.INFO) + tf.app.run(main) |