aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples/how_tos
diff options
context:
space:
mode:
authorGravatar Vijay Vasudevan <vrv@google.com>2015-12-07 15:48:00 -0800
committerGravatar Vijay Vasudevan <vrv@google.com>2015-12-07 15:48:00 -0800
commitcd53f3c3302c9312c1840389a9988a879b8b9dd5 (patch)
treed5bc93bb81e66fa638f9ecde328d7684e870e3ff /tensorflow/examples/how_tos
parent11e3d0faf251f4b6ed8f37a27fe5f2509b9a2457 (diff)
TensorFlow: upstrea changes from git.
Change 109628097 Fix gcc 4.8.1 compile Modified from patch by @assolini here: https://github.com/tensorflow/tensorflow/issues/405 Change 109624275 Make preview frame ImageReader global so that it does not get GC'd. This may fix an issue with connecting to the camera on some devices where the underlying Surface is prematurely cleaned up (http://stackoverflow.com/questions/33437961/android-camera-2-api-bufferqueue-has-been-abandoned). Change 109620599 - improved test a little to make it easier to understand as it serves as an example for users Change 109614953 TensorFlow: update tutorials/howtos to point to correct location of files, show python example in addition to bazel. Change 109612732 TensorFlow: move reading_data into examples, change data dir to /tmp/data. Validated that they all run, but these probably need a selftest at some point. Change 109608695 Apply 'gate_gradients' only when there is more than one real gradients. Change 109605014 There are 3 obvious places to start using TensorFlow. 2/3 of the starting points do not have a link to the installation instructions. Change 109604287 Make the `tf.reshape` shape function more restrictive. Previously, it did not raise a construction-time error if the input shape and the new shape were incompatible; now it detects this and raises a `ValueError`. Change 109603375 TensorFlow: Move word2vec_basic.py from g3doc/ to examples/ There are no additional libraries this uses, so nothing else needs to be done Change 109601289 TensorBoard tag 3 Change 109600908 Decrease number of scalar values stored by TensorBoard. 10k is more than displays nicely. Change 109599464 Fix "smart restart" functionality in TensorBoard (it throws away dead data) After restarts, a file_version event is created that always has step 0. We need to ignore this. Change 109597667 Switch to using /dev/urandom for TensorFlow randomness. Using /dev/random leads to slowdown when running in an environment with poor access to an entropy source (such as some VMs). /dev/urandom has more predictable performance, and we don't require cryptographically secure random number generation, so a PRNG is good enough. Also removes the use of the RNG in DirectSession construction. This was being used to generate a session handle, which is not necessary (since a DirectSession owns its devices, we don't need a unique handle to key the OpSegment objects registered with the various devices). This addresses bugs that have been reported on the mailing list and Stack Overflow. Change 109596906 Add an is_unsigned property to dtype Change 109596830 Remove unnecessary fill in clip_by_value Change 109591880 Remove Android demo's libpthread.so dummy file (required by protobuf) from repo and generate it at compile-time. This makes the Android demo more portable, as the generated file will now always be the correct archictecture for linking. Change 109589028 Isolating out the RTTI part of TensorFlow and add non-RTTI backups for Android. This saves about 400KB of the compiled library, when compiling the Android tensorflow target with -fno-rtti. Change 109589018 Internal reworking of LSTMCell. Change 109588229 Allow bool-valued tensors to be persisted. Change 109577175 TensorBoard host defaults to 0.0.0.0 Change 109551438 TensorFlow: move mnist g3doc tutorials into tensorflow/examples. Update examples to point to the correct location. Adds tests to make sure they don't regress, do some lint cleanup. Base CL: 109630240
Diffstat (limited to 'tensorflow/examples/how_tos')
-rw-r--r--tensorflow/examples/how_tos/reading_data/BUILD68
-rw-r--r--tensorflow/examples/how_tos/reading_data/__init__.py0
-rw-r--r--tensorflow/examples/how_tos/reading_data/convert_to_records.py105
-rw-r--r--tensorflow/examples/how_tos/reading_data/fully_connected_preloaded.py158
-rw-r--r--tensorflow/examples/how_tos/reading_data/fully_connected_preloaded_var.py169
-rw-r--r--tensorflow/examples/how_tos/reading_data/fully_connected_reader.py197
6 files changed, 697 insertions, 0 deletions
diff --git a/tensorflow/examples/how_tos/reading_data/BUILD b/tensorflow/examples/how_tos/reading_data/BUILD
new file mode 100644
index 0000000000..c1e773d905
--- /dev/null
+++ b/tensorflow/examples/how_tos/reading_data/BUILD
@@ -0,0 +1,68 @@
+# Description:
+# Example MNIST TensorFlow models for demonstrating data reading.
+
+package(default_visibility = ["//tensorflow:internal"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+py_binary(
+ name = "convert_to_records",
+ srcs = ["convert_to_records.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/examples/tutorials/mnist:input_data",
+ ],
+)
+
+py_binary(
+ name = "fully_connected_reader",
+ srcs = [
+ "fully_connected_reader.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/examples/tutorials/mnist",
+ ],
+)
+
+py_binary(
+ name = "fully_connected_preloaded",
+ srcs = [
+ "fully_connected_preloaded.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/examples/tutorials/mnist",
+ "//tensorflow/examples/tutorials/mnist:input_data",
+ ],
+)
+
+py_binary(
+ name = "fully_connected_preloaded_var",
+ srcs = [
+ "fully_connected_preloaded_var.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/examples/tutorials/mnist",
+ "//tensorflow/examples/tutorials/mnist:input_data",
+ ],
+)
+
+filegroup(
+ name = "all_files",
+ srcs = glob(
+ ["**/*"],
+ exclude = [
+ "**/METADATA",
+ "**/OWNERS",
+ ],
+ ),
+ visibility = ["//tensorflow:__subpackages__"],
+)
diff --git a/tensorflow/examples/how_tos/reading_data/__init__.py b/tensorflow/examples/how_tos/reading_data/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/tensorflow/examples/how_tos/reading_data/__init__.py
diff --git a/tensorflow/examples/how_tos/reading_data/convert_to_records.py b/tensorflow/examples/how_tos/reading_data/convert_to_records.py
new file mode 100644
index 0000000000..30b5a384a8
--- /dev/null
+++ b/tensorflow/examples/how_tos/reading_data/convert_to_records.py
@@ -0,0 +1,105 @@
+# Copyright 2015 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Converts MNIST data to TFRecords file format with Example protos."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+import tensorflow.python.platform
+
+import numpy
+import tensorflow as tf
+from tensorflow.examples.tutorials.mnist import input_data
+
+
+TRAIN_IMAGES = 'train-images-idx3-ubyte.gz' # MNIST filenames
+TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
+TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
+TEST_LABELS = 't10k-labels-idx1-ubyte.gz'
+
+
+tf.app.flags.DEFINE_string('directory', '/tmp/data',
+ 'Directory to download data files and write the '
+ 'converted result')
+tf.app.flags.DEFINE_integer('validation_size', 5000,
+ 'Number of examples to separate from the training '
+ 'data for the validation set.')
+FLAGS = tf.app.flags.FLAGS
+
+
+def _int64_feature(value):
+ return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
+
+
+def _bytes_feature(value):
+ return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
+
+
+def convert_to(images, labels, name):
+ num_examples = labels.shape[0]
+ if images.shape[0] != num_examples:
+ raise ValueError("Images size %d does not match label size %d." %
+ (images.shape[0], num_examples))
+ rows = images.shape[1]
+ cols = images.shape[2]
+ depth = images.shape[3]
+
+ filename = os.path.join(FLAGS.directory, name + '.tfrecords')
+ print('Writing', filename)
+ writer = tf.python_io.TFRecordWriter(filename)
+ for index in range(num_examples):
+ image_raw = images[index].tostring()
+ example = tf.train.Example(features=tf.train.Features(feature={
+ 'height': _int64_feature(rows),
+ 'width': _int64_feature(cols),
+ 'depth': _int64_feature(depth),
+ 'label': _int64_feature(int(labels[index])),
+ 'image_raw': _bytes_feature(image_raw)}))
+ writer.write(example.SerializeToString())
+
+
+def main(argv):
+ # Get the data.
+ train_images_filename = input_data.maybe_download(
+ TRAIN_IMAGES, FLAGS.directory)
+ train_labels_filename = input_data.maybe_download(
+ TRAIN_LABELS, FLAGS.directory)
+ test_images_filename = input_data.maybe_download(
+ TEST_IMAGES, FLAGS.directory)
+ test_labels_filename = input_data.maybe_download(
+ TEST_LABELS, FLAGS.directory)
+
+ # Extract it into numpy arrays.
+ train_images = input_data.extract_images(train_images_filename)
+ train_labels = input_data.extract_labels(train_labels_filename)
+ test_images = input_data.extract_images(test_images_filename)
+ test_labels = input_data.extract_labels(test_labels_filename)
+
+ # Generate a validation set.
+ validation_images = train_images[:FLAGS.validation_size, :, :, :]
+ validation_labels = train_labels[:FLAGS.validation_size]
+ train_images = train_images[FLAGS.validation_size:, :, :, :]
+ train_labels = train_labels[FLAGS.validation_size:]
+
+ # Convert to Examples and write the result to TFRecords.
+ convert_to(train_images, train_labels, 'train')
+ convert_to(validation_images, validation_labels, 'validation')
+ convert_to(test_images, test_labels, 'test')
+
+
+if __name__ == '__main__':
+ tf.app.run()
diff --git a/tensorflow/examples/how_tos/reading_data/fully_connected_preloaded.py b/tensorflow/examples/how_tos/reading_data/fully_connected_preloaded.py
new file mode 100644
index 0000000000..39ce1a759b
--- /dev/null
+++ b/tensorflow/examples/how_tos/reading_data/fully_connected_preloaded.py
@@ -0,0 +1,158 @@
+# Copyright 2015 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Trains the MNIST network using preloaded data in a constant.
+
+Run using bazel:
+
+bazel run -c opt \
+ <...>/tensorflow/examples/how_tos/reading_data:fully_connected_preloaded
+
+or, if installed via pip:
+
+cd tensorflow/examples/how_tos/reading_data
+python fully_connected_preloaded.py
+
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os.path
+import time
+
+import tensorflow.python.platform
+import numpy
+import tensorflow as tf
+
+from tensorflow.examples.tutorials.mnist import input_data
+from tensorflow.examples.tutorials.mnist import mnist
+
+
+# Basic model parameters as external flags.
+flags = tf.app.flags
+FLAGS = flags.FLAGS
+flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')
+flags.DEFINE_integer('num_epochs', 2, 'Number of epochs to run trainer.')
+flags.DEFINE_integer('hidden1', 128, 'Number of units in hidden layer 1.')
+flags.DEFINE_integer('hidden2', 32, 'Number of units in hidden layer 2.')
+flags.DEFINE_integer('batch_size', 100, 'Batch size. '
+ 'Must divide evenly into the dataset sizes.')
+flags.DEFINE_string('train_dir', '/tmp/data',
+ 'Directory to put the training data.')
+flags.DEFINE_boolean('fake_data', False, 'If true, uses fake data '
+ 'for unit testing.')
+
+
+def run_training():
+ """Train MNIST for a number of epochs."""
+ # Get the sets of images and labels for training, validation, and
+ # test on MNIST.
+ data_sets = input_data.read_data_sets(FLAGS.train_dir, FLAGS.fake_data)
+
+ # Tell TensorFlow that the model will be built into the default Graph.
+ with tf.Graph().as_default():
+ with tf.name_scope('input'):
+ # Input data
+ input_images = tf.constant(data_sets.train.images)
+ input_labels = tf.constant(data_sets.train.labels)
+
+ image, label = tf.train.slice_input_producer(
+ [input_images, input_labels], num_epochs=FLAGS.num_epochs)
+ label = tf.cast(label, tf.int32)
+ images, labels = tf.train.batch(
+ [image, label], batch_size=FLAGS.batch_size)
+
+ # Build a Graph that computes predictions from the inference model.
+ logits = mnist.inference(images, FLAGS.hidden1, FLAGS.hidden2)
+
+ # Add to the Graph the Ops for loss calculation.
+ loss = mnist.loss(logits, labels)
+
+ # Add to the Graph the Ops that calculate and apply gradients.
+ train_op = mnist.training(loss, FLAGS.learning_rate)
+
+ # Add the Op to compare the logits to the labels during evaluation.
+ eval_correct = mnist.evaluation(logits, labels)
+
+ # Build the summary operation based on the TF collection of Summaries.
+ summary_op = tf.merge_all_summaries()
+
+ # Create a saver for writing training checkpoints.
+ saver = tf.train.Saver()
+
+ # Create the op for initializing variables.
+ init_op = tf.initialize_all_variables()
+
+ # Create a session for running Ops on the Graph.
+ sess = tf.Session()
+
+ # Run the Op to initialize the variables.
+ sess.run(init_op)
+
+ # Instantiate a SummaryWriter to output summaries and the Graph.
+ summary_writer = tf.train.SummaryWriter(FLAGS.train_dir,
+ graph_def=sess.graph_def)
+
+ # Start input enqueue threads.
+ coord = tf.train.Coordinator()
+ threads = tf.train.start_queue_runners(sess=sess, coord=coord)
+
+ # And then after everything is built, start the training loop.
+ try:
+ step = 0
+ while not coord.should_stop():
+ start_time = time.time()
+
+ # Run one step of the model.
+ _, loss_value = sess.run([train_op, loss])
+
+ duration = time.time() - start_time
+
+ # Write the summaries and print an overview fairly often.
+ if step % 100 == 0:
+ # Print status to stdout.
+ print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value,
+ duration))
+ # Update the events file.
+ summary_str = sess.run(summary_op)
+ summary_writer.add_summary(summary_str, step)
+ step += 1
+
+ # Save a checkpoint periodically.
+ if (step + 1) % 1000 == 0:
+ print('Saving')
+ saver.save(sess, FLAGS.train_dir, global_step=step)
+
+ step += 1
+ except tf.errors.OutOfRangeError:
+ print('Saving')
+ saver.save(sess, FLAGS.train_dir, global_step=step)
+ print('Done training for %d epochs, %d steps.' % (FLAGS.num_epochs, step))
+ finally:
+ # When done, ask the threads to stop.
+ coord.request_stop()
+
+ # Wait for threads to finish.
+ coord.join(threads)
+ sess.close()
+
+
+def main(_):
+ run_training()
+
+
+if __name__ == '__main__':
+ tf.app.run()
diff --git a/tensorflow/examples/how_tos/reading_data/fully_connected_preloaded_var.py b/tensorflow/examples/how_tos/reading_data/fully_connected_preloaded_var.py
new file mode 100644
index 0000000000..9a7e4e8e81
--- /dev/null
+++ b/tensorflow/examples/how_tos/reading_data/fully_connected_preloaded_var.py
@@ -0,0 +1,169 @@
+# Copyright 2015 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Trains the MNIST network using preloaded data stored in a variable.
+
+Run using bazel:
+
+bazel run -c opt \
+ <...>/tensorflow/examples/how_tos/reading_data:fully_connected_preloaded_var
+
+or, if installed via pip:
+
+cd tensorflow/examples/how_tos/reading_data
+python fully_connected_preloaded_var.py
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os.path
+import time
+
+import tensorflow.python.platform
+import numpy
+import tensorflow as tf
+
+from tensorflow.examples.tutorials.mnist import input_data
+from tensorflow.examples.tutorials.mnist import mnist
+
+
+# Basic model parameters as external flags.
+flags = tf.app.flags
+FLAGS = flags.FLAGS
+flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')
+flags.DEFINE_integer('num_epochs', 2, 'Number of epochs to run trainer.')
+flags.DEFINE_integer('hidden1', 128, 'Number of units in hidden layer 1.')
+flags.DEFINE_integer('hidden2', 32, 'Number of units in hidden layer 2.')
+flags.DEFINE_integer('batch_size', 100, 'Batch size. '
+ 'Must divide evenly into the dataset sizes.')
+flags.DEFINE_string('train_dir', '/tmp/data',
+ 'Directory to put the training data.')
+flags.DEFINE_boolean('fake_data', False, 'If true, uses fake data '
+ 'for unit testing.')
+
+
+def run_training():
+ """Train MNIST for a number of epochs."""
+ # Get the sets of images and labels for training, validation, and
+ # test on MNIST.
+ data_sets = input_data.read_data_sets(FLAGS.train_dir, FLAGS.fake_data)
+
+ # Tell TensorFlow that the model will be built into the default Graph.
+ with tf.Graph().as_default():
+ with tf.name_scope('input'):
+ # Input data
+ images_initializer = tf.placeholder(
+ dtype=data_sets.train.images.dtype,
+ shape=data_sets.train.images.shape)
+ labels_initializer = tf.placeholder(
+ dtype=data_sets.train.labels.dtype,
+ shape=data_sets.train.labels.shape)
+ input_images = tf.Variable(
+ images_initializer, trainable=False, collections=[])
+ input_labels = tf.Variable(
+ labels_initializer, trainable=False, collections=[])
+
+ image, label = tf.train.slice_input_producer(
+ [input_images, input_labels], num_epochs=FLAGS.num_epochs)
+ label = tf.cast(label, tf.int32)
+ images, labels = tf.train.batch(
+ [image, label], batch_size=FLAGS.batch_size)
+
+ # Build a Graph that computes predictions from the inference model.
+ logits = mnist.inference(images, FLAGS.hidden1, FLAGS.hidden2)
+
+ # Add to the Graph the Ops for loss calculation.
+ loss = mnist.loss(logits, labels)
+
+ # Add to the Graph the Ops that calculate and apply gradients.
+ train_op = mnist.training(loss, FLAGS.learning_rate)
+
+ # Add the Op to compare the logits to the labels during evaluation.
+ eval_correct = mnist.evaluation(logits, labels)
+
+ # Build the summary operation based on the TF collection of Summaries.
+ summary_op = tf.merge_all_summaries()
+
+ # Create a saver for writing training checkpoints.
+ saver = tf.train.Saver()
+
+ # Create the op for initializing variables.
+ init_op = tf.initialize_all_variables()
+
+ # Create a session for running Ops on the Graph.
+ sess = tf.Session()
+
+ # Run the Op to initialize the variables.
+ sess.run(init_op)
+ sess.run(input_images.initializer,
+ feed_dict={images_initializer: data_sets.train.images})
+ sess.run(input_labels.initializer,
+ feed_dict={labels_initializer: data_sets.train.labels})
+
+ # Instantiate a SummaryWriter to output summaries and the Graph.
+ summary_writer = tf.train.SummaryWriter(FLAGS.train_dir,
+ graph_def=sess.graph_def)
+
+ # Start input enqueue threads.
+ coord = tf.train.Coordinator()
+ threads = tf.train.start_queue_runners(sess=sess, coord=coord)
+
+ # And then after everything is built, start the training loop.
+ try:
+ step = 0
+ while not coord.should_stop():
+ start_time = time.time()
+
+ # Run one step of the model.
+ _, loss_value = sess.run([train_op, loss])
+
+ duration = time.time() - start_time
+
+ # Write the summaries and print an overview fairly often.
+ if step % 100 == 0:
+ # Print status to stdout.
+ print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value,
+ duration))
+ # Update the events file.
+ summary_str = sess.run(summary_op)
+ summary_writer.add_summary(summary_str, step)
+ step += 1
+
+ # Save a checkpoint periodically.
+ if (step + 1) % 1000 == 0:
+ print('Saving')
+ saver.save(sess, FLAGS.train_dir, global_step=step)
+
+ step += 1
+ except tf.errors.OutOfRangeError:
+ print('Saving')
+ saver.save(sess, FLAGS.train_dir, global_step=step)
+ print('Done training for %d epochs, %d steps.' % (FLAGS.num_epochs, step))
+ finally:
+ # When done, ask the threads to stop.
+ coord.request_stop()
+
+ # Wait for threads to finish.
+ coord.join(threads)
+ sess.close()
+
+
+def main(_):
+ run_training()
+
+
+if __name__ == '__main__':
+ tf.app.run()
diff --git a/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py b/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py
new file mode 100644
index 0000000000..bf1ef08c60
--- /dev/null
+++ b/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py
@@ -0,0 +1,197 @@
+# Copyright 2015 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Train and Eval the MNIST network.
+
+This version is like fully_connected_feed.py but uses data converted
+to a TFRecords file containing tf.train.Example protocol buffers.
+See tensorflow/g3doc/how_tos/reading_data.md#reading-from-files
+for context.
+
+YOU MUST run convert_to_records before running this (but you only need to
+run it once).
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os.path
+import time
+
+import tensorflow.python.platform
+import numpy
+import tensorflow as tf
+
+from tensorflow.examples.tutorials.mnist import mnist
+
+
+# Basic model parameters as external flags.
+flags = tf.app.flags
+FLAGS = flags.FLAGS
+flags.DEFINE_float('learning_rate', 0.01, 'Initial learning rate.')
+flags.DEFINE_integer('num_epochs', 2, 'Number of epochs to run trainer.')
+flags.DEFINE_integer('hidden1', 128, 'Number of units in hidden layer 1.')
+flags.DEFINE_integer('hidden2', 32, 'Number of units in hidden layer 2.')
+flags.DEFINE_integer('batch_size', 100, 'Batch size.')
+flags.DEFINE_string('train_dir', '/tmp/data',
+ 'Directory with the training data.')
+
+# Constants used for dealing with the files, matches convert_to_records.
+TRAIN_FILE = 'train.tfrecords'
+VALIDATION_FILE = 'validation.tfrecords'
+
+
+def read_and_decode(filename_queue):
+ reader = tf.TFRecordReader()
+ _, serialized_example = reader.read(filename_queue)
+ features = tf.parse_single_example(
+ serialized_example,
+ dense_keys=['image_raw', 'label'],
+ # Defaults are not specified since both keys are required.
+ dense_types=[tf.string, tf.int64])
+
+ # Convert from a scalar string tensor (whose single string has
+ # length mnist.IMAGE_PIXELS) to a uint8 tensor with shape
+ # [mnist.IMAGE_PIXELS].
+ image = tf.decode_raw(features['image_raw'], tf.uint8)
+ image.set_shape([mnist.IMAGE_PIXELS])
+
+ # OPTIONAL: Could reshape into a 28x28 image and apply distortions
+ # here. Since we are not applying any distortions in this
+ # example, and the next step expects the image to be flattened
+ # into a vector, we don't bother.
+
+ # Convert from [0, 255] -> [-0.5, 0.5] floats.
+ image = tf.cast(image, tf.float32) * (1. / 255) - 0.5
+
+ # Convert label from a scalar uint8 tensor to an int32 scalar.
+ label = tf.cast(features['label'], tf.int32)
+
+ return image, label
+
+
+def inputs(train, batch_size, num_epochs):
+ """Reads input data num_epochs times.
+
+ Args:
+ train: Selects between the training (True) and validation (False) data.
+ batch_size: Number of examples per returned batch.
+ num_epochs: Number of times to read the input data, or 0/None to
+ train forever.
+
+ Returns:
+ A tuple (images, labels), where:
+ * images is a float tensor with shape [batch_size, mnist.IMAGE_PIXELS]
+ in the range [-0.5, 0.5].
+ * labels is an int32 tensor with shape [batch_size] with the true label,
+ a number in the range [0, mnist.NUM_CLASSES).
+ Note that an tf.train.QueueRunner is added to the graph, which
+ must be run using e.g. tf.train.start_queue_runners().
+ """
+ if not num_epochs: num_epochs = None
+ filename = os.path.join(FLAGS.train_dir,
+ TRAIN_FILE if train else VALIDATION_FILE)
+
+ with tf.name_scope('input'):
+ filename_queue = tf.train.string_input_producer(
+ [filename], num_epochs=num_epochs)
+
+ # Even when reading in multiple threads, share the filename
+ # queue.
+ image, label = read_and_decode(filename_queue)
+
+ # Shuffle the examples and collect them into batch_size batches.
+ # (Internally uses a RandomShuffleQueue.)
+ # We run this in two threads to avoid being a bottleneck.
+ images, sparse_labels = tf.train.shuffle_batch(
+ [image, label], batch_size=batch_size, num_threads=2,
+ capacity=1000 + 3 * batch_size,
+ # Ensures a minimum amount of shuffling of examples.
+ min_after_dequeue=1000)
+
+ return images, sparse_labels
+
+
+def run_training():
+ """Train MNIST for a number of steps."""
+
+ # Tell TensorFlow that the model will be built into the default Graph.
+ with tf.Graph().as_default():
+ # Input images and labels.
+ images, labels = inputs(train=True, batch_size=FLAGS.batch_size,
+ num_epochs=FLAGS.num_epochs)
+
+ # Build a Graph that computes predictions from the inference model.
+ logits = mnist.inference(images,
+ FLAGS.hidden1,
+ FLAGS.hidden2)
+
+ # Add to the Graph the loss calculation.
+ loss = mnist.loss(logits, labels)
+
+ # Add to the Graph operations that train the model.
+ train_op = mnist.training(loss, FLAGS.learning_rate)
+
+ # The op for initializing the variables.
+ init_op = tf.initialize_all_variables()
+
+ # Create a session for running operations in the Graph.
+ sess = tf.Session()
+
+ # Initialize the variables (the trained variables and the
+ # epoch counter).
+ sess.run(init_op)
+
+ # Start input enqueue threads.
+ coord = tf.train.Coordinator()
+ threads = tf.train.start_queue_runners(sess=sess, coord=coord)
+
+ try:
+ step = 0
+ while not coord.should_stop():
+ start_time = time.time()
+
+ # Run one step of the model. The return values are
+ # the activations from the `train_op` (which is
+ # discarded) and the `loss` op. To inspect the values
+ # of your ops or variables, you may include them in
+ # the list passed to sess.run() and the value tensors
+ # will be returned in the tuple from the call.
+ _, loss_value = sess.run([train_op, loss])
+
+ duration = time.time() - start_time
+
+ # Print an overview fairly often.
+ if step % 100 == 0:
+ print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value,
+ duration))
+ step += 1
+ except tf.errors.OutOfRangeError:
+ print('Done training for %d epochs, %d steps.' % (FLAGS.num_epochs, step))
+ finally:
+ # When done, ask the threads to stop.
+ coord.request_stop()
+
+ # Wait for threads to finish.
+ coord.join(threads)
+ sess.close()
+
+
+def main(_):
+ run_training()
+
+
+if __name__ == '__main__':
+ tf.app.run()