From e09808e0e4d05ffe4be97f563f0f14ae2ebb20dd Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 28 Jun 2016 23:51:09 -0800 Subject: Fixing broken code in how_tos for fully_connected_reader. - Added initialize_local_variables() - Updated the code that generates the dataset since it no longer worked either. - This in turn required adding a "reshape" option that controls whether or not the imaged data is flattened. Change: 126165145 --- .../how_tos/reading_data/convert_to_records.py | 42 ++++++++-------------- .../how_tos/reading_data/fully_connected_reader.py | 3 +- 2 files changed, 17 insertions(+), 28 deletions(-) (limited to 'tensorflow/examples/how_tos') diff --git a/tensorflow/examples/how_tos/reading_data/convert_to_records.py b/tensorflow/examples/how_tos/reading_data/convert_to_records.py index ee558f5b19..2e3035731a 100644 --- a/tensorflow/examples/how_tos/reading_data/convert_to_records.py +++ b/tensorflow/examples/how_tos/reading_data/convert_to_records.py @@ -21,9 +21,11 @@ from __future__ import print_function import os import numpy import tensorflow as tf -from tensorflow.examples.tutorials.mnist import input_data +from tensorflow.contrib.learn.python.learn.datasets import mnist +SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/' + TRAIN_IMAGES = 'train-images-idx3-ubyte.gz' # MNIST filenames TRAIN_LABELS = 'train-labels-idx1-ubyte.gz' TEST_IMAGES = 't10k-images-idx3-ubyte.gz' @@ -47,10 +49,13 @@ def _bytes_feature(value): return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) -def convert_to(images, labels, name): - num_examples = labels.shape[0] +def convert_to(data_set, name): + images = data_set.images + labels = data_set.labels + num_examples = data_set.num_examples + if images.shape[0] != num_examples: - raise ValueError("Images size %d does not match label size %d." % + raise ValueError('Images size %d does not match label size %d.' % (images.shape[0], num_examples)) rows = images.shape[1] cols = images.shape[2] @@ -73,31 +78,14 @@ def convert_to(images, labels, name): def main(argv): # Get the data. - train_images_filename = input_data.maybe_download( - TRAIN_IMAGES, FLAGS.directory) - train_labels_filename = input_data.maybe_download( - TRAIN_LABELS, FLAGS.directory) - test_images_filename = input_data.maybe_download( - TEST_IMAGES, FLAGS.directory) - test_labels_filename = input_data.maybe_download( - TEST_LABELS, FLAGS.directory) - - # Extract it into numpy arrays. - train_images = input_data.extract_images(train_images_filename) - train_labels = input_data.extract_labels(train_labels_filename) - test_images = input_data.extract_images(test_images_filename) - test_labels = input_data.extract_labels(test_labels_filename) - - # Generate a validation set. - validation_images = train_images[:FLAGS.validation_size, :, :, :] - validation_labels = train_labels[:FLAGS.validation_size] - train_images = train_images[FLAGS.validation_size:, :, :, :] - train_labels = train_labels[FLAGS.validation_size:] + data_sets = mnist.read_data_sets(FLAGS.directory, + dtype=tf.uint8, + reshape=False) # Convert to Examples and write the result to TFRecords. - convert_to(train_images, train_labels, 'train') - convert_to(validation_images, validation_labels, 'validation') - convert_to(test_images, test_labels, 'test') + convert_to(data_sets.train, 'train') + convert_to(data_sets.validation, 'validation') + convert_to(data_sets.test, 'test') if __name__ == '__main__': diff --git a/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py b/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py index 648825e6e9..bdd821373f 100644 --- a/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py +++ b/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py @@ -146,7 +146,8 @@ def run_training(): train_op = mnist.training(loss, FLAGS.learning_rate) # The op for initializing the variables. - init_op = tf.initialize_all_variables() + init_op = tf.group(tf.initialize_all_variables(), + tf.initialize_local_variables()) # Create a session for running operations in the Graph. sess = tf.Session() -- cgit v1.2.3