diff options
author | FirefoxMetzger <S.Wallkoetter@gmx.de> | 2017-12-11 01:58:36 +0100 |
---|---|---|
committer | Shanqing Cai <cais@google.com> | 2017-12-10 19:58:36 -0500 |
commit | 6b2f1d6394f5333fae172f898406f92a86874828 (patch) | |
tree | 8f5b001f1e874b2048ad55c223698af5b61cfe88 /tensorflow/examples | |
parent | ec57ca65c46a45a24454c6a2d33bcce379d56627 (diff) |
update how_tos/reading_data to use Dataset API (#14751)
* updated reading_data to use Dataset
Since the Dataset API moved from tf.contrib.data into tf.data (core), update
the MNIST example to use the Dataset API instead of input queues.
Diffstat (limited to 'tensorflow/examples')
-rw-r--r-- | tensorflow/examples/how_tos/reading_data/fully_connected_reader.py | 125 |
1 files changed, 57 insertions, 68 deletions
diff --git a/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py b/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py index a9ed02dd1a..9db8835d92 100644 --- a/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py +++ b/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py @@ -45,9 +45,7 @@ TRAIN_FILE = 'train.tfrecords' VALIDATION_FILE = 'validation.tfrecords' -def read_and_decode(filename_queue): - reader = tf.TFRecordReader() - _, serialized_example = reader.read(filename_queue) +def decode(serialized_example): features = tf.parse_single_example( serialized_example, # Defaults are not specified since both keys are required. @@ -60,22 +58,26 @@ def read_and_decode(filename_queue): # length mnist.IMAGE_PIXELS) to a uint8 tensor with shape # [mnist.IMAGE_PIXELS]. image = tf.decode_raw(features['image_raw'], tf.uint8) - image.set_shape([mnist.IMAGE_PIXELS]) + image.set_shape((mnist.IMAGE_PIXELS)) + # Convert label from a scalar uint8 tensor to an int32 scalar. + label = tf.cast(features['label'], tf.int32) + + return image, label + +def augment(image, label): # OPTIONAL: Could reshape into a 28x28 image and apply distortions # here. Since we are not applying any distortions in this # example, and the next step expects the image to be flattened # into a vector, we don't bother. + return image, label +def normalize(image, label): # Convert from [0, 255] -> [-0.5, 0.5] floats. image = tf.cast(image, tf.float32) * (1. / 255) - 0.5 - # Convert label from a scalar uint8 tensor to an int32 scalar. - label = tf.cast(features['label'], tf.int32) - return image, label - def inputs(train, batch_size, num_epochs): """Reads input data num_epochs times. @@ -91,31 +93,32 @@ def inputs(train, batch_size, num_epochs): in the range [-0.5, 0.5]. * labels is an int32 tensor with shape [batch_size] with the true label, a number in the range [0, mnist.NUM_CLASSES). 
- Note that an tf.train.QueueRunner is added to the graph, which - must be run using e.g. tf.train.start_queue_runners(). + + This function creates a one_shot_iterator, meaning that it will only iterate + over the dataset once. On the other hand there is no special initialization + required. """ if not num_epochs: num_epochs = None filename = os.path.join(FLAGS.train_dir, TRAIN_FILE if train else VALIDATION_FILE) with tf.name_scope('input'): - filename_queue = tf.train.string_input_producer( - [filename], num_epochs=num_epochs) + # TFRecordDataset opens a protobuf and reads entries line by line + # could also be [list, of, filenames] + dataset = tf.data.TFRecordDataset(filename) + dataset = dataset.repeat(num_epochs) - # Even when reading in multiple threads, share the filename - # queue. - image, label = read_and_decode(filename_queue) + # map takes a python function and applies it to every sample + dataset = dataset.map(decode) + dataset = dataset.map(augment) + dataset = dataset.map(normalize) - # Shuffle the examples and collect them into batch_size batches. - # (Internally uses a RandomShuffleQueue.) - # We run this in two threads to avoid being a bottleneck. - images, sparse_labels = tf.train.shuffle_batch( - [image, label], batch_size=batch_size, num_threads=2, - capacity=1000 + 3 * batch_size, - # Ensures a minimum amount of shuffling of examples. - min_after_dequeue=1000) + #the parameter is the queue size + dataset = dataset.shuffle(1000 + 3 * batch_size) + dataset = dataset.batch(batch_size) - return images, sparse_labels + iterator = dataset.make_one_shot_iterator() + return iterator.get_next() def run_training(): @@ -124,16 +127,16 @@ def run_training(): # Tell TensorFlow that the model will be built into the default Graph. with tf.Graph().as_default(): # Input images and labels. 
- images, labels = inputs(train=True, batch_size=FLAGS.batch_size, - num_epochs=FLAGS.num_epochs) + image_batch, label_batch = inputs(train=True, batch_size=FLAGS.batch_size, + num_epochs=FLAGS.num_epochs) # Build a Graph that computes predictions from the inference model. - logits = mnist.inference(images, + logits = mnist.inference(image_batch, FLAGS.hidden1, FLAGS.hidden2) # Add to the Graph the loss calculation. - loss = mnist.loss(logits, labels) + loss = mnist.loss(logits, label_batch) # Add to the Graph operations that train the model. train_op = mnist.training(loss, FLAGS.learning_rate) @@ -143,47 +146,33 @@ def run_training(): tf.local_variables_initializer()) # Create a session for running operations in the Graph. - sess = tf.Session() - - # Initialize the variables (the trained variables and the - # epoch counter). - sess.run(init_op) - - # Start input enqueue threads. - coord = tf.train.Coordinator() - threads = tf.train.start_queue_runners(sess=sess, coord=coord) - - try: - step = 0 - while not coord.should_stop(): - start_time = time.time() - - # Run one step of the model. The return values are - # the activations from the `train_op` (which is - # discarded) and the `loss` op. To inspect the values - # of your ops or variables, you may include them in - # the list passed to sess.run() and the value tensors - # will be returned in the tuple from the call. - _, loss_value = sess.run([train_op, loss]) - - duration = time.time() - start_time - - # Print an overview fairly often. - if step % 100 == 0: - print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, + with tf.Session() as sess: + # Initialize the variables (the trained variables and the + # epoch counter). + sess.run(init_op) + try: + step = 0 + while True: #train until OutOfRangeError + start_time = time.time() + + # Run one step of the model. The return values are + # the activations from the `train_op` (which is + # discarded) and the `loss` op. 
To inspect the values + # of your ops or variables, you may include them in + # the list passed to sess.run() and the value tensors + # will be returned in the tuple from the call. + _, loss_value = sess.run([train_op, loss]) + + duration = time.time() - start_time + + # Print an overview fairly often. + if step % 100 == 0: + print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration)) - step += 1 - except tf.errors.OutOfRangeError: - print('Done training for %d epochs, %d steps.' % (FLAGS.num_epochs, step)) - finally: - # When done, ask the threads to stop. - coord.request_stop() - - # Wait for threads to finish. - coord.join(threads) - sess.close() - - + step += 1 + except tf.errors.OutOfRangeError: + print('Done training for %d epochs, %d steps.' % (FLAGS.num_epochs, step)) + def main(_): run_training() |