aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/lite/tutorials
diff options
context:
space:
mode:
authorGravatar Raghuraman Krishnamoorthi <raghuramank@google.com>2018-09-06 15:39:41 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-06 15:44:05 -0700
commit5aca2604651e3532aa5304b6aabf51f630e62084 (patch)
tree39afbeae2948acdf20ce09972f10dac14538e064 /tensorflow/contrib/lite/tutorials
parent3142d94dd2258b4b04ac9857341a6736ed1f4442 (diff)
Python example for tutorial on post training quantization for mnist.
PiperOrigin-RevId: 211882134
Diffstat (limited to 'tensorflow/contrib/lite/tutorials')
-rw-r--r--tensorflow/contrib/lite/tutorials/BUILD20
-rw-r--r--tensorflow/contrib/lite/tutorials/dataset.py122
-rw-r--r--tensorflow/contrib/lite/tutorials/mnist_tflite.py87
3 files changed, 229 insertions, 0 deletions
diff --git a/tensorflow/contrib/lite/tutorials/BUILD b/tensorflow/contrib/lite/tutorials/BUILD
new file mode 100644
index 0000000000..67ff1ea124
--- /dev/null
+++ b/tensorflow/contrib/lite/tutorials/BUILD
@@ -0,0 +1,20 @@
+# Example Estimator model
+
+package(
+ default_visibility = ["//visibility:public"],
+)
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+py_binary(
+ name = "mnist_tflite",
+ srcs = [
+ "dataset.py",
+ "mnist_tflite.py",
+ ],
+ deps = [
+ "//tensorflow:tensorflow_py",
+ ],
+)
diff --git a/tensorflow/contrib/lite/tutorials/dataset.py b/tensorflow/contrib/lite/tutorials/dataset.py
new file mode 100644
index 0000000000..ba49dfcc9b
--- /dev/null
+++ b/tensorflow/contrib/lite/tutorials/dataset.py
@@ -0,0 +1,122 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""tf.data.Dataset interface to the MNIST dataset.
+
+ This is cloned from
+ https://github.com/tensorflow/models/blob/master/official/mnist/dataset.py
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gzip
+import os
+import shutil
+import tempfile
+
+import numpy as np
+from six.moves import urllib
+import tensorflow as tf
+
+
+def read32(bytestream):
+ """Read 4 bytes from bytestream as an unsigned 32-bit integer."""
+ dt = np.dtype(np.uint32).newbyteorder('>')
+ return np.frombuffer(bytestream.read(4), dtype=dt)[0]
+
+
+def check_image_file_header(filename):
+ """Validate that filename corresponds to images for the MNIST dataset."""
+ with tf.gfile.Open(filename, 'rb') as f:
+ magic = read32(f)
+ read32(f) # num_images, unused
+ rows = read32(f)
+ cols = read32(f)
+ if magic != 2051:
+ raise ValueError('Invalid magic number %d in MNIST file %s' % (magic,
+ f.name))
+ if rows != 28 or cols != 28:
+ raise ValueError(
+ 'Invalid MNIST file %s: Expected 28x28 images, found %dx%d' %
+ (f.name, rows, cols))
+
+
+def check_labels_file_header(filename):
+ """Validate that filename corresponds to labels for the MNIST dataset."""
+ with tf.gfile.Open(filename, 'rb') as f:
+ magic = read32(f)
+ read32(f) # num_items, unused
+ if magic != 2049:
+ raise ValueError('Invalid magic number %d in MNIST file %s' % (magic,
+ f.name))
+
+
+def download(directory, filename):
+ """Download (and unzip) a file from the MNIST dataset if not already done."""
+ filepath = os.path.join(directory, filename)
+ if tf.gfile.Exists(filepath):
+ return filepath
+ if not tf.gfile.Exists(directory):
+ tf.gfile.MakeDirs(directory)
+ # CVDF mirror of http://yann.lecun.com/exdb/mnist/
+ url = 'https://storage.googleapis.com/cvdf-datasets/mnist/' + filename + '.gz'
+ _, zipped_filepath = tempfile.mkstemp(suffix='.gz')
+ print('Downloading %s to %s' % (url, zipped_filepath))
+ urllib.request.urlretrieve(url, zipped_filepath)
+ with gzip.open(zipped_filepath, 'rb') as f_in, \
+ tf.gfile.Open(filepath, 'wb') as f_out:
+ shutil.copyfileobj(f_in, f_out)
+ os.remove(zipped_filepath)
+ return filepath
+
+
+def dataset(directory, images_file, labels_file):
+ """Download and parse MNIST dataset."""
+
+ images_file = download(directory, images_file)
+ labels_file = download(directory, labels_file)
+
+ check_image_file_header(images_file)
+ check_labels_file_header(labels_file)
+
+ def decode_image(image):
+ # Normalize from [0, 255] to [0.0, 1.0]
+ image = tf.decode_raw(image, tf.uint8)
+ image = tf.cast(image, tf.float32)
+ image = tf.reshape(image, [784])
+ return image / 255.0
+
+ def decode_label(label):
+ label = tf.decode_raw(label, tf.uint8) # tf.string -> [tf.uint8]
+ label = tf.reshape(label, []) # label is a scalar
+ return tf.to_int32(label)
+
+ images = tf.data.FixedLengthRecordDataset(
+ images_file, 28 * 28, header_bytes=16).map(decode_image)
+ labels = tf.data.FixedLengthRecordDataset(
+ labels_file, 1, header_bytes=8).map(decode_label)
+ return tf.data.Dataset.zip((images, labels))
+
+
+def train(directory):
+ """tf.data.Dataset object for MNIST training data."""
+ return dataset(directory, 'train-images-idx3-ubyte',
+ 'train-labels-idx1-ubyte')
+
+
+def test(directory):
+ """tf.data.Dataset object for MNIST test data."""
+ return dataset(directory, 't10k-images-idx3-ubyte', 't10k-labels-idx1-ubyte')
diff --git a/tensorflow/contrib/lite/tutorials/mnist_tflite.py b/tensorflow/contrib/lite/tutorials/mnist_tflite.py
new file mode 100644
index 0000000000..7b8bf5b5db
--- /dev/null
+++ b/tensorflow/contrib/lite/tutorials/mnist_tflite.py
@@ -0,0 +1,87 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Script to evaluate accuracy of TFLite flatbuffer model on mnist dataset."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+import numpy as np
+import tensorflow as tf # pylint: disable=g-bad-import-order
+from tensorflow.contrib.lite.tutorials import dataset
+flags = tf.app.flags
+
+flags.DEFINE_string('data_dir', '/tmp/data_dir',
+ 'Directory where data is stored.')
+flags.DEFINE_string('model_file', '',
+ 'The path to the TFLite flatbuffer model file.')
+
+
+flags = flags.FLAGS
+
+
+def test_image_generator():
+ # Generates an iterator over images
+ with tf.Session() as sess:
+ input_data = dataset.test(
+ flags.data_dir).make_one_shot_iterator().get_next()
+ try:
+ while True:
+ yield sess.run(input_data)
+ except tf.errors.OutOfRangeError:
+ pass
+
+
+def run_eval(interpreter, input_image):
+ """Performs evaluation for input image over specified model.
+
+ Args:
+ interpreter: TFLite interpreter initialized with model to execute.
+ input_image: Image input to the model.
+
+ Returns:
+ output: output tensor of model being executed.
+ """
+
+ # Get input and output tensors.
+ input_details = interpreter.get_input_details()
+ output_details = interpreter.get_output_details()
+
+ # Test model on the input images.
+ input_image = np.reshape(input_image, input_details[0]['shape'])
+ interpreter.set_tensor(input_details[0]['index'], input_image)
+
+ interpreter.invoke()
+ output_data = interpreter.get_tensor(output_details[0]['index'])
+ output = np.squeeze(output_data)
+ return output
+
+
+def main(_):
+ interpreter = tf.contrib.lite.Interpreter(model_path=flags.model_file)
+ interpreter.allocate_tensors()
+ num_correct, total = 0, 0
+ for input_data in test_image_generator():
+ output = run_eval(interpreter, input_data[0])
+ total += 1
+ if output == input_data[1]:
+ num_correct += 1
+ if total % 500 == 0:
+ print('Accuracy after %i images: %f' %
+ (total, float(num_correct) / float(total)))
+
+
+if __name__ == '__main__':
+ tf.logging.set_verbosity(tf.logging.INFO)
+ tf.app.run(main)