diff options
author | Derek Murray <mrry@google.com> | 2018-02-21 16:20:27 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-02-21 16:24:06 -0800 |
commit | 5a474209be6a1db6dc81080b9a5f965b28dfb88e (patch) | |
tree | fdf606241f87436da06f4dbdb860a606e3beecf4 /tensorflow/examples | |
parent | 83486cb183099c3dc2dcfd036ded4e6526761918 (diff) |
Fix lint errors and improve docs in fully_connected_reader.py.
PiperOrigin-RevId: 186537109
Diffstat (limited to 'tensorflow/examples')
-rw-r--r-- | tensorflow/examples/how_tos/reading_data/fully_connected_reader.py | 24 |
1 files changed, 15 insertions, 9 deletions
diff --git a/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py b/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py index 461fb1c517..307eede5c0 100644 --- a/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py +++ b/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py @@ -1,4 +1,4 @@ -# Copyright 2015 The TensorFlow Authors. All Rights Reserved. +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -45,6 +45,7 @@ VALIDATION_FILE = 'validation.tfrecords' def decode(serialized_example): + """Parses an image and label from the given `serialized_example`.""" features = tf.parse_single_example( serialized_example, # Defaults are not specified since both keys are required. @@ -66,6 +67,7 @@ def decode(serialized_example): def augment(image, label): + """Placeholder for data augmentation.""" # OPTIONAL: Could reshape into a 28x28 image and apply distortions # here. Since we are not applying any distortions in this # example, and the next step expects the image to be flattened @@ -74,9 +76,8 @@ def augment(image, label): def normalize(image, label): - # Convert from [0, 255] -> [-0.5, 0.5] floats. + """Convert `image` from [0, 255] -> [-0.5, 0.5] floats.""" image = tf.cast(image, tf.float32) * (1. / 255) - 0.5 - return image, label @@ -106,18 +107,23 @@ def inputs(train, batch_size, num_epochs): if train else VALIDATION_FILE) with tf.name_scope('input'): - # TFRecordDataset opens a protobuf and reads entries line by line - # could also be [list, of, filenames] + # TFRecordDataset opens a binary file and reads one record at a time. + # `filename` could also be a list of filenames, which will be read in order. dataset = tf.data.TFRecordDataset(filename) - dataset = dataset.repeat(num_epochs) - # map takes a python function and applies it to every sample + # The map transformation takes a function and applies it to every element + # of the dataset. dataset = dataset.map(decode) dataset = dataset.map(augment) dataset = dataset.map(normalize) - #the parameter is the queue size + # The shuffle transformation uses a finite-sized buffer to shuffle elements + # in memory. The parameter is the number of elements in the buffer. For + # completely uniform shuffling, set the parameter to be the same as the + # number of elements in the dataset. dataset = dataset.shuffle(1000 + 3 * batch_size) + + dataset = dataset.repeat(num_epochs) dataset = dataset.batch(batch_size) iterator = dataset.make_one_shot_iterator() @@ -153,7 +159,7 @@ def run_training(): sess.run(init_op) try: step = 0 - while True: #train until OutOfRangeError + while True: # Train until OutOfRangeError start_time = time.time() # Run one step of the model. The return values are |