aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples/how_tos
diff options
context:
space:
mode:
authorGravatar Yifei Feng <yifeif@google.com>2018-01-25 12:02:36 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-25 12:07:22 -0800
commit351c0a533a111636333b4ebeede16485cf679ca9 (patch)
treea0786bc9a8fe7432d69d8095b10586e3ef515b93 /tensorflow/examples/how_tos
parenta8c4e8d96de7c0978851a5f9718bbd6b8056d862 (diff)
Add C0330 bad-continuation check to pylint.
PiperOrigin-RevId: 183270896
Diffstat (limited to 'tensorflow/examples/how_tos')
-rw-r--r--tensorflow/examples/how_tos/reading_data/fully_connected_reader.py49
1 files changed, 21 insertions, 28 deletions
diff --git a/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py b/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py
index fa4c1c0da5..461fb1c517 100644
--- a/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py
+++ b/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-
"""Train and Eval the MNIST network.
This version is like fully_connected_feed.py but uses data converted
@@ -65,6 +64,7 @@ def decode(serialized_example):
return image, label
+
def augment(image, label):
# OPTIONAL: Could reshape into a 28x28 image and apply distortions
# here. Since we are not applying any distortions in this
@@ -72,12 +72,14 @@ def augment(image, label):
# into a vector, we don't bother.
return image, label
+
def normalize(image, label):
# Convert from [0, 255] -> [-0.5, 0.5] floats.
image = tf.cast(image, tf.float32) * (1. / 255) - 0.5
return image, label
+
def inputs(train, batch_size, num_epochs):
"""Reads input data num_epochs times.
@@ -98,9 +100,10 @@ def inputs(train, batch_size, num_epochs):
over the dataset once. On the other hand there is no special initialization
required.
"""
- if not num_epochs: num_epochs = None
- filename = os.path.join(FLAGS.train_dir,
- TRAIN_FILE if train else VALIDATION_FILE)
+ if not num_epochs:
+ num_epochs = None
+ filename = os.path.join(FLAGS.train_dir, TRAIN_FILE
+ if train else VALIDATION_FILE)
with tf.name_scope('input'):
# TFRecordDataset opens a protobuf and reads entries line by line
@@ -127,13 +130,11 @@ def run_training():
# Tell TensorFlow that the model will be built into the default Graph.
with tf.Graph().as_default():
# Input images and labels.
- image_batch, label_batch = inputs(train=True, batch_size=FLAGS.batch_size,
- num_epochs=FLAGS.num_epochs)
+ image_batch, label_batch = inputs(
+ train=True, batch_size=FLAGS.batch_size, num_epochs=FLAGS.num_epochs)
# Build a Graph that computes predictions from the inference model.
- logits = mnist.inference(image_batch,
- FLAGS.hidden1,
- FLAGS.hidden2)
+ logits = mnist.inference(image_batch, FLAGS.hidden1, FLAGS.hidden2)
# Add to the Graph the loss calculation.
loss = mnist.loss(logits, label_batch)
@@ -152,7 +153,7 @@ def run_training():
sess.run(init_op)
try:
step = 0
- while True: #train until OutOfRangeError
+ while True: #train until OutOfRangeError
start_time = time.time()
# Run one step of the model. The return values are
@@ -168,10 +169,12 @@ def run_training():
# Print an overview fairly often.
if step % 100 == 0:
print('Step %d: loss = %.2f (%.3f sec)' % (step, loss_value,
- duration))
+ duration))
step += 1
except tf.errors.OutOfRangeError:
- print('Done training for %d epochs, %d steps.' % (FLAGS.num_epochs, step))
+ print('Done training for %d epochs, %d steps.' % (FLAGS.num_epochs,
+ step))
+
def main(_):
run_training()
@@ -183,37 +186,27 @@ if __name__ == '__main__':
'--learning_rate',
type=float,
default=0.01,
- help='Initial learning rate.'
- )
+ help='Initial learning rate.')
parser.add_argument(
'--num_epochs',
type=int,
default=2,
- help='Number of epochs to run trainer.'
- )
+ help='Number of epochs to run trainer.')
parser.add_argument(
'--hidden1',
type=int,
default=128,
- help='Number of units in hidden layer 1.'
- )
+ help='Number of units in hidden layer 1.')
parser.add_argument(
'--hidden2',
type=int,
default=32,
- help='Number of units in hidden layer 2.'
- )
- parser.add_argument(
- '--batch_size',
- type=int,
- default=100,
- help='Batch size.'
- )
+ help='Number of units in hidden layer 2.')
+ parser.add_argument('--batch_size', type=int, default=100, help='Batch size.')
parser.add_argument(
'--train_dir',
type=str,
default='/tmp/data',
- help='Directory with the training data.'
- )
+ help='Directory with the training data.')
FLAGS, unparsed = parser.parse_known_args()
tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)