diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-12-15 17:32:50 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-12-15 17:39:26 -0800 |
commit | 9648f8040a559f6cf9bbe0501ba96f2b2c2864b1 (patch) | |
tree | 57dc6e959e0a534622eaf392ee43b7691378b10e /tensorflow/examples/how_tos | |
parent | 5b5445b9a7aa2664a90c4fc946ecf268c971425b (diff) |
Automated g4 rollback of changelist 179258973
PiperOrigin-RevId: 179260538
Diffstat (limited to 'tensorflow/examples/how_tos')
-rw-r--r-- | tensorflow/examples/how_tos/reading_data/fully_connected_reader.py | 125 |
1 files changed, 68 insertions, 57 deletions
diff --git a/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py b/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py index 9db8835d92..a9ed02dd1a 100644 --- a/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py +++ b/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py @@ -45,7 +45,9 @@ TRAIN_FILE = 'train.tfrecords' VALIDATION_FILE = 'validation.tfrecords' -def decode(serialized_example): +def read_and_decode(filename_queue): + reader = tf.TFRecordReader() + _, serialized_example = reader.read(filename_queue) features = tf.parse_single_example( serialized_example, # Defaults are not specified since both keys are required. @@ -58,26 +60,22 @@ def decode(serialized_example): # length mnist.IMAGE_PIXELS) to a uint8 tensor with shape # [mnist.IMAGE_PIXELS]. image = tf.decode_raw(features['image_raw'], tf.uint8) - image.set_shape((mnist.IMAGE_PIXELS)) + image.set_shape([mnist.IMAGE_PIXELS]) - # Convert label from a scalar uint8 tensor to an int32 scalar. - label = tf.cast(features['label'], tf.int32) - - return image, label - -def augment(image, label): # OPTIONAL: Could reshape into a 28x28 image and apply distortions # here. Since we are not applying any distortions in this # example, and the next step expects the image to be flattened # into a vector, we don't bother. - return image, label -def normalize(image, label): # Convert from [0, 255] -> [-0.5, 0.5] floats. image = tf.cast(image, tf.float32) * (1. / 255) - 0.5 + # Convert label from a scalar uint8 tensor to an int32 scalar. + label = tf.cast(features['label'], tf.int32) + return image, label + def inputs(train, batch_size, num_epochs): """Reads input data num_epochs times. @@ -93,32 +91,31 @@ def inputs(train, batch_size, num_epochs): in the range [-0.5, 0.5]. * labels is an int32 tensor with shape [batch_size] with the true label, a number in the range [0, mnist.NUM_CLASSES). - - This function creates a one_shot_iterator, meaning that it will only iterate - over the dataset once. On the other hand there is no special initialization - required. + Note that an tf.train.QueueRunner is added to the graph, which + must be run using e.g. tf.train.start_queue_runners(). """ if not num_epochs: num_epochs = None filename = os.path.join(FLAGS.train_dir, TRAIN_FILE if train else VALIDATION_FILE) with tf.name_scope('input'): - # TFRecordDataset opens a protobuf and reads entries line by line - # could also be [list, of, filenames] - dataset = tf.data.TFRecordDataset(filename) - dataset = dataset.repeat(num_epochs) + filename_queue = tf.train.string_input_producer( + [filename], num_epochs=num_epochs) - # map takes a python function and applies it to every sample - dataset = dataset.map(decode) - dataset = dataset.map(augment) - dataset = dataset.map(normalize) + # Even when reading in multiple threads, share the filename + # queue. + image, label = read_and_decode(filename_queue) - #the parameter is the queue size - dataset = dataset.shuffle(1000 + 3 * batch_size) - dataset = dataset.batch(batch_size) + # Shuffle the examples and collect them into batch_size batches. + # (Internally uses a RandomShuffleQueue.) + # We run this in two threads to avoid being a bottleneck. + images, sparse_labels = tf.train.shuffle_batch( + [image, label], batch_size=batch_size, num_threads=2, + capacity=1000 + 3 * batch_size, + # Ensures a minimum amount of shuffling of examples. + min_after_dequeue=1000) - iterator = dataset.make_one_shot_iterator() - return iterator.get_next() + return images, sparse_labels def run_training(): @@ -127,16 +124,16 @@ def run_training(): # Tell TensorFlow that the model will be built into the default Graph. with tf.Graph().as_default(): # Input images and labels. - image_batch, label_batch = inputs(train=True, batch_size=FLAGS.batch_size, - num_epochs=FLAGS.num_epochs) + images, labels = inputs(train=True, batch_size=FLAGS.batch_size, + num_epochs=FLAGS.num_epochs) # Build a Graph that computes predictions from the inference model. - logits = mnist.inference(image_batch, + logits = mnist.inference(images, FLAGS.hidden1, FLAGS.hidden2) # Add to the Graph the loss calculation. - loss = mnist.loss(logits, label_batch) + loss = mnist.loss(logits, labels) # Add to the Graph operations that train the model. train_op = mnist.training(loss, FLAGS.learning_rate) @@ -146,33 +143,47 @@ def run_training(): tf.local_variables_initializer()) # Create a session for running operations in the Graph. - with tf.Session() as sess: - # Initialize the variables (the trained variables and the - # epoch counter). - sess.run(init_op) - try: - step = 0 - while True: #train until OutOfRangeError - start_time = time.time() - - # Run one step of the model. The return values are - # the activations from the `train_op` (which is - # discarded) and the `loss` op. To inspect the values - # of your ops or variables, you may include them in - # the list passed to sess.run() and the value tensors - # will be returned in the tuple from the call. - _, loss_value = sess.run([train_op, loss]) - - duration = time.time() - start_time - - # Print an overview fairly often. - if step % 100 == 0: - print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, + sess = tf.Session() + + # Initialize the variables (the trained variables and the + # epoch counter). + sess.run(init_op) + + # Start input enqueue threads. + coord = tf.train.Coordinator() + threads = tf.train.start_queue_runners(sess=sess, coord=coord) + + try: + step = 0 + while not coord.should_stop(): + start_time = time.time() + + # Run one step of the model. The return values are + # the activations from the `train_op` (which is + # discarded) and the `loss` op. To inspect the values + # of your ops or variables, you may include them in + # the list passed to sess.run() and the value tensors + # will be returned in the tuple from the call. + _, loss_value = sess.run([train_op, loss]) + + duration = time.time() - start_time + + # Print an overview fairly often. + if step % 100 == 0: + print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value, duration)) - step += 1 - except tf.errors.OutOfRangeError: - print('Done training for %d epochs, %d steps.' % (FLAGS.num_epochs, step)) - + step += 1 + except tf.errors.OutOfRangeError: + print('Done training for %d epochs, %d steps.' % (FLAGS.num_epochs, step)) + finally: + # When done, ask the threads to stop. + coord.request_stop() + + # Wait for threads to finish. + coord.join(threads) + sess.close() + + def main(_): run_training() |