aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/models/image
diff options
context:
space:
mode:
authorGravatar Martin Wicke <wicke@google.com>2016-07-25 13:48:16 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-07-25 15:02:34 -0700
commit21716d8f6e175cd6e8cd97a84e48497574268b0c (patch)
tree3345202e7a812cc9572beb24fc01732696a4140d /tensorflow/models/image
parented281973d66d0030e58a77a05821bbb88627f5bd (diff)
Merge changes from github.
Change: 128401884
Diffstat (limited to 'tensorflow/models/image')
-rw-r--r--tensorflow/models/image/mnist/convolutional.py4
1 files changed, 2 insertions, 2 deletions
diff --git a/tensorflow/models/image/mnist/convolutional.py b/tensorflow/models/image/mnist/convolutional.py
index 1893e68121..26e4a6ac8f 100644
--- a/tensorflow/models/image/mnist/convolutional.py
+++ b/tensorflow/models/image/mnist/convolutional.py
@@ -82,10 +82,10 @@ def extract_data(filename, num_images):
print('Extracting', filename)
with gzip.open(filename) as bytestream:
bytestream.read(16)
- buf = bytestream.read(IMAGE_SIZE * IMAGE_SIZE * num_images)
+ buf = bytestream.read(IMAGE_SIZE * IMAGE_SIZE * num_images * NUM_CHANNELS)
data = numpy.frombuffer(buf, dtype=numpy.uint8).astype(numpy.float32)
data = (data - (PIXEL_DEPTH / 2.0)) / PIXEL_DEPTH
- data = data.reshape(num_images, IMAGE_SIZE, IMAGE_SIZE, 1)
+ data = data.reshape(num_images, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS)
return data