about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/examples/how_tos
diff options
context:
space:
mode:
author: A. Unique TensorFlower <gardener@tensorflow.org> 2017-12-15 17:32:50 -0800
committer: TensorFlower Gardener <gardener@tensorflow.org> 2017-12-15 17:39:26 -0800
commit 9648f8040a559f6cf9bbe0501ba96f2b2c2864b1 (patch)
tree 57dc6e959e0a534622eaf392ee43b7691378b10e /tensorflow/examples/how_tos
parent 5b5445b9a7aa2664a90c4fc946ecf268c971425b (diff)
Automated g4 rollback of changelist 179258973
PiperOrigin-RevId: 179260538
Diffstat (limited to 'tensorflow/examples/how_tos')
-rw-r--r-- tensorflow/examples/how_tos/reading_data/fully_connected_reader.py 125
1 files changed, 68 insertions, 57 deletions
diff --git a/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py b/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py
index 9db8835d92..a9ed02dd1a 100644
--- a/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py
+++ b/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py
@@ -45,7 +45,9 @@ TRAIN_FILE = 'train.tfrecords'
VALIDATION_FILE = 'validation.tfrecords'
-def decode(serialized_example):
+def read_and_decode(filename_queue):
+ reader = tf.TFRecordReader()
+ _, serialized_example = reader.read(filename_queue)
features = tf.parse_single_example(
serialized_example,
# Defaults are not specified since both keys are required.
@@ -58,26 +60,22 @@ def decode(serialized_example):
# length mnist.IMAGE_PIXELS) to a uint8 tensor with shape
# [mnist.IMAGE_PIXELS].
image = tf.decode_raw(features['image_raw'], tf.uint8)
- image.set_shape((mnist.IMAGE_PIXELS))
+ image.set_shape([mnist.IMAGE_PIXELS])
- # Convert label from a scalar uint8 tensor to an int32 scalar.
- label = tf.cast(features['label'], tf.int32)
-
- return image, label
-
-def augment(image, label):
# OPTIONAL: Could reshape into a 28x28 image and apply distortions
# here. Since we are not applying any distortions in this
# example, and the next step expects the image to be flattened
# into a vector, we don't bother.
- return image, label
-def normalize(image, label):
# Convert from [0, 255] -> [-0.5, 0.5] floats.
image = tf.cast(image, tf.float32) * (1. / 255) - 0.5
+ # Convert label from a scalar uint8 tensor to an int32 scalar.
+ label = tf.cast(features['label'], tf.int32)
+
return image, label
+
def inputs(train, batch_size, num_epochs):
"""Reads input data num_epochs times.
@@ -93,32 +91,31 @@ def inputs(train, batch_size, num_epochs):
in the range [-0.5, 0.5].
* labels is an int32 tensor with shape [batch_size] with the true label,
a number in the range [0, mnist.NUM_CLASSES).
-
- This function creates a one_shot_iterator, meaning that it will only iterate
- over the dataset once. On the other hand there is no special initialization
- required.
+ Note that an tf.train.QueueRunner is added to the graph, which
+ must be run using e.g. tf.train.start_queue_runners().
"""
if not num_epochs: num_epochs = None
filename = os.path.join(FLAGS.train_dir,
TRAIN_FILE if train else VALIDATION_FILE)
with tf.name_scope('input'):
- # TFRecordDataset opens a protobuf and reads entries line by line
- # could also be [list, of, filenames]
- dataset = tf.data.TFRecordDataset(filename)
- dataset = dataset.repeat(num_epochs)
+ filename_queue = tf.train.string_input_producer(
+ [filename], num_epochs=num_epochs)
- # map takes a python function and applies it to every sample
- dataset = dataset.map(decode)
- dataset = dataset.map(augment)
- dataset = dataset.map(normalize)
+ # Even when reading in multiple threads, share the filename
+ # queue.
+ image, label = read_and_decode(filename_queue)
- #the parameter is the queue size
- dataset = dataset.shuffle(1000 + 3 * batch_size)
- dataset = dataset.batch(batch_size)
+ # Shuffle the examples and collect them into batch_size batches.
+ # (Internally uses a RandomShuffleQueue.)
+ # We run this in two threads to avoid being a bottleneck.
+ images, sparse_labels = tf.train.shuffle_batch(
+ [image, label], batch_size=batch_size, num_threads=2,
+ capacity=1000 + 3 * batch_size,
+ # Ensures a minimum amount of shuffling of examples.
+ min_after_dequeue=1000)
- iterator = dataset.make_one_shot_iterator()
- return iterator.get_next()
+ return images, sparse_labels
def run_training():
@@ -127,16 +124,16 @@ def run_training():
# Tell TensorFlow that the model will be built into the default Graph.
with tf.Graph().as_default():
# Input images and labels.
- image_batch, label_batch = inputs(train=True, batch_size=FLAGS.batch_size,
- num_epochs=FLAGS.num_epochs)
+ images, labels = inputs(train=True, batch_size=FLAGS.batch_size,
+ num_epochs=FLAGS.num_epochs)
# Build a Graph that computes predictions from the inference model.
- logits = mnist.inference(image_batch,
+ logits = mnist.inference(images,
FLAGS.hidden1,
FLAGS.hidden2)
# Add to the Graph the loss calculation.
- loss = mnist.loss(logits, label_batch)
+ loss = mnist.loss(logits, labels)
# Add to the Graph operations that train the model.
train_op = mnist.training(loss, FLAGS.learning_rate)
@@ -146,33 +143,47 @@ def run_training():
tf.local_variables_initializer())
# Create a session for running operations in the Graph.
- with tf.Session() as sess:
- # Initialize the variables (the trained variables and the
- # epoch counter).
- sess.run(init_op)
- try:
- step = 0
- while True: #train until OutOfRangeError
- start_time = time.time()
-
- # Run one step of the model. The return values are
- # the activations from the `train_op` (which is
- # discarded) and the `loss` op. To inspect the values
- # of your ops or variables, you may include them in
- # the list passed to sess.run() and the value tensors
- # will be returned in the tuple from the call.
- _, loss_value = sess.run([train_op, loss])
-
- duration = time.time() - start_time
-
- # Print an overview fairly often.
- if step % 100 == 0:
- print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value,
+ sess = tf.Session()
+
+ # Initialize the variables (the trained variables and the
+ # epoch counter).
+ sess.run(init_op)
+
+ # Start input enqueue threads.
+ coord = tf.train.Coordinator()
+ threads = tf.train.start_queue_runners(sess=sess, coord=coord)
+
+ try:
+ step = 0
+ while not coord.should_stop():
+ start_time = time.time()
+
+ # Run one step of the model. The return values are
+ # the activations from the `train_op` (which is
+ # discarded) and the `loss` op. To inspect the values
+ # of your ops or variables, you may include them in
+ # the list passed to sess.run() and the value tensors
+ # will be returned in the tuple from the call.
+ _, loss_value = sess.run([train_op, loss])
+
+ duration = time.time() - start_time
+
+ # Print an overview fairly often.
+ if step % 100 == 0:
+ print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value,
duration))
- step += 1
- except tf.errors.OutOfRangeError:
- print('Done training for %d epochs, %d steps.' % (FLAGS.num_epochs, step))
-
+ step += 1
+ except tf.errors.OutOfRangeError:
+ print('Done training for %d epochs, %d steps.' % (FLAGS.num_epochs, step))
+ finally:
+ # When done, ask the threads to stop.
+ coord.request_stop()
+
+ # Wait for threads to finish.
+ coord.join(threads)
+ sess.close()
+
+
def main(_):
run_training()