aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples/how_tos
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-06-28 23:51:09 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-06-29 01:04:01 -0700
commite09808e0e4d05ffe4be97f563f0f14ae2ebb20dd (patch)
treecc0d06cb82aba704f077d8fffbb645bd3aeb74c4 /tensorflow/examples/how_tos
parent41efdfbf7711d34e45ccdad2bbd6fdf0a1515cbe (diff)
Fixing broken code in how_tos for fully_connected_reader.
- Added initialize_local_variables() - Updated the code that generates the dataset since it no longer worked either. - This in turn required adding a "reshape" option that controls whether or not the imaged data is flattened. Change: 126165145
Diffstat (limited to 'tensorflow/examples/how_tos')
-rw-r--r--tensorflow/examples/how_tos/reading_data/convert_to_records.py42
-rw-r--r--tensorflow/examples/how_tos/reading_data/fully_connected_reader.py3
2 files changed, 17 insertions, 28 deletions
diff --git a/tensorflow/examples/how_tos/reading_data/convert_to_records.py b/tensorflow/examples/how_tos/reading_data/convert_to_records.py
index ee558f5b19..2e3035731a 100644
--- a/tensorflow/examples/how_tos/reading_data/convert_to_records.py
+++ b/tensorflow/examples/how_tos/reading_data/convert_to_records.py
@@ -21,9 +21,11 @@ from __future__ import print_function
import os
import numpy
import tensorflow as tf
-from tensorflow.examples.tutorials.mnist import input_data
+from tensorflow.contrib.learn.python.learn.datasets import mnist
+SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
+
TRAIN_IMAGES = 'train-images-idx3-ubyte.gz' # MNIST filenames
TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
@@ -47,10 +49,13 @@ def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
-def convert_to(images, labels, name):
- num_examples = labels.shape[0]
+def convert_to(data_set, name):
+ images = data_set.images
+ labels = data_set.labels
+ num_examples = data_set.num_examples
+
if images.shape[0] != num_examples:
- raise ValueError("Images size %d does not match label size %d." %
+ raise ValueError('Images size %d does not match label size %d.' %
(images.shape[0], num_examples))
rows = images.shape[1]
cols = images.shape[2]
@@ -73,31 +78,14 @@ def convert_to(images, labels, name):
def main(argv):
# Get the data.
- train_images_filename = input_data.maybe_download(
- TRAIN_IMAGES, FLAGS.directory)
- train_labels_filename = input_data.maybe_download(
- TRAIN_LABELS, FLAGS.directory)
- test_images_filename = input_data.maybe_download(
- TEST_IMAGES, FLAGS.directory)
- test_labels_filename = input_data.maybe_download(
- TEST_LABELS, FLAGS.directory)
-
- # Extract it into numpy arrays.
- train_images = input_data.extract_images(train_images_filename)
- train_labels = input_data.extract_labels(train_labels_filename)
- test_images = input_data.extract_images(test_images_filename)
- test_labels = input_data.extract_labels(test_labels_filename)
-
- # Generate a validation set.
- validation_images = train_images[:FLAGS.validation_size, :, :, :]
- validation_labels = train_labels[:FLAGS.validation_size]
- train_images = train_images[FLAGS.validation_size:, :, :, :]
- train_labels = train_labels[FLAGS.validation_size:]
+ data_sets = mnist.read_data_sets(FLAGS.directory,
+ dtype=tf.uint8,
+ reshape=False)
# Convert to Examples and write the result to TFRecords.
- convert_to(train_images, train_labels, 'train')
- convert_to(validation_images, validation_labels, 'validation')
- convert_to(test_images, test_labels, 'test')
+ convert_to(data_sets.train, 'train')
+ convert_to(data_sets.validation, 'validation')
+ convert_to(data_sets.test, 'test')
if __name__ == '__main__':
diff --git a/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py b/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py
index 648825e6e9..bdd821373f 100644
--- a/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py
+++ b/tensorflow/examples/how_tos/reading_data/fully_connected_reader.py
@@ -146,7 +146,8 @@ def run_training():
train_op = mnist.training(loss, FLAGS.learning_rate)
# The op for initializing the variables.
- init_op = tf.initialize_all_variables()
+ init_op = tf.group(tf.initialize_all_variables(),
+ tf.initialize_local_variables())
# Create a session for running operations in the Graph.
sess = tf.Session()