aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples
diff options
context:
space:
mode:
authorGravatar Derek Murray <mrry@google.com>2018-02-21 16:20:27 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-21 16:24:06 -0800
commit5a474209be6a1db6dc81080b9a5f965b28dfb88e (patch)
treefdf606241f87436da06f4dbdb860a606e3beecf4 /tensorflow/examples
parent83486cb183099c3dc2dcfd036ded4e6526761918 (diff)
Fix lint errors and improve docs in fully_connected_reader.py.
PiperOrigin-RevId: 186537109
Diffstat (limited to 'tensorflow/examples')
-rw-r--r--tensorflow/examples/how_tos/reading_data/fully_connected_reader.py24
1 files changed, 15 insertions, 9 deletions
diff --git a/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py b/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py
index 461fb1c517..307eede5c0 100644
--- a/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py
+++ b/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py
@@ -1,4 +1,4 @@
-# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -45,6 +45,7 @@ VALIDATION_FILE = 'validation.tfrecords'
def decode(serialized_example):
+ """Parses an image and label from the given `serialized_example`."""
features = tf.parse_single_example(
serialized_example,
# Defaults are not specified since both keys are required.
@@ -66,6 +67,7 @@ def decode(serialized_example):
def augment(image, label):
+ """Placeholder for data augmentation."""
# OPTIONAL: Could reshape into a 28x28 image and apply distortions
# here. Since we are not applying any distortions in this
# example, and the next step expects the image to be flattened
@@ -74,9 +76,8 @@ def augment(image, label):
def normalize(image, label):
- # Convert from [0, 255] -> [-0.5, 0.5] floats.
+ """Convert `image` from [0, 255] -> [-0.5, 0.5] floats."""
image = tf.cast(image, tf.float32) * (1. / 255) - 0.5
-
return image, label
@@ -106,18 +107,23 @@ def inputs(train, batch_size, num_epochs):
if train else VALIDATION_FILE)
with tf.name_scope('input'):
- # TFRecordDataset opens a protobuf and reads entries line by line
- # could also be [list, of, filenames]
+ # TFRecordDataset opens a binary file and reads one record at a time.
+ # `filename` could also be a list of filenames, which will be read in order.
dataset = tf.data.TFRecordDataset(filename)
- dataset = dataset.repeat(num_epochs)
- # map takes a python function and applies it to every sample
+ # The map transformation takes a function and applies it to every element
+ # of the dataset.
dataset = dataset.map(decode)
dataset = dataset.map(augment)
dataset = dataset.map(normalize)
- #the parameter is the queue size
+ # The shuffle transformation uses a finite-sized buffer to shuffle elements
+ # in memory. The parameter is the number of elements in the buffer. For
+ # completely uniform shuffling, set the parameter to be the same as the
+ # number of elements in the dataset.
dataset = dataset.shuffle(1000 + 3 * batch_size)
+
+ dataset = dataset.repeat(num_epochs)
dataset = dataset.batch(batch_size)
iterator = dataset.make_one_shot_iterator()
@@ -153,7 +159,7 @@ def run_training():
sess.run(init_op)
try:
step = 0
- while True: #train until OutOfRangeError
+ while True: # Train until OutOfRangeError
start_time = time.time()
# Run one step of the model. The return values are