author    David G. Andersen <dga@google.com>  2016-02-05 16:43:12 -0800
committer Vijay Vasudevan <vrv@google.com>   2016-02-05 17:06:25 -0800
commit    136ac3c58ea2493ea6cfc90d3f186b8bc1929873 (patch)
tree      67367945a87666e1c5f36c4d1f0ce44d9f741fd0
parent    8f80ef184034ef536efa4aee1dd15cfaf84a5c34 (diff)
Convert convolutional.py example to use new sparse_softmax_cross_entropy_with_logits op
Change: 113997793
-rw-r--r--  tensorflow/models/image/mnist/convolutional.py | 20
1 file changed, 9 insertions(+), 11 deletions(-)
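
The diff below replaces the dense 1-hot label pipeline with sparse int64 label IDs end to end. A minimal sketch of the two loss formulations, using the TF 0.x-era positional argument order that this commit targets (the names, shapes, and values here are illustrative, not taken from the diff):

    import tensorflow as tf

    BATCH_SIZE, NUM_LABELS = 4, 10
    logits = tf.placeholder(tf.float32, shape=(BATCH_SIZE, NUM_LABELS))

    # Before: dense 1-hot float labels, shape (BATCH_SIZE, NUM_LABELS).
    dense_labels = tf.placeholder(tf.float32, shape=(BATCH_SIZE, NUM_LABELS))
    dense_loss = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(logits, dense_labels))

    # After: sparse int64 class IDs, shape (BATCH_SIZE,). No 1-hot matrix is
    # ever materialized, so the label tensor is NUM_LABELS times smaller.
    sparse_labels = tf.placeholder(tf.int64, shape=(BATCH_SIZE,))
    sparse_loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(logits, sparse_labels))
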
diff --git a/tensorflow/models/image/mnist/convolutional.py b/tensorflow/models/image/mnist/convolutional.py
index df0ca22063..edceb2a1ec 100644
--- a/tensorflow/models/image/mnist/convolutional.py
+++ b/tensorflow/models/image/mnist/convolutional.py
@@ -81,14 +81,13 @@ def extract_data(filename, num_images):
def extract_labels(filename, num_images):
- """Extract the labels into a 1-hot matrix [image index, label index]."""
+ """Extract the labels into a vector of int64 label IDs."""
print('Extracting', filename)
with gzip.open(filename) as bytestream:
bytestream.read(8)
buf = bytestream.read(1 * num_images)
- labels = numpy.frombuffer(buf, dtype=numpy.uint8)
- # Convert to dense 1-hot representation.
- return (numpy.arange(NUM_LABELS) == labels[:, None]).astype(numpy.float32)
+ labels = numpy.frombuffer(buf, dtype=numpy.uint8).astype(numpy.int64)
+ return labels
def fake_data(num_images):
@@ -96,19 +95,19 @@ def fake_data(num_images):
data = numpy.ndarray(
shape=(num_images, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS),
dtype=numpy.float32)
- labels = numpy.zeros(shape=(num_images, NUM_LABELS), dtype=numpy.float32)
+ labels = numpy.zeros(shape=(num_images,), dtype=numpy.int64)
for image in xrange(num_images):
label = image % 2
data[image, :, :, 0] = label - 0.5
- labels[image, label] = 1.0
+ labels[image] = label
return data, labels
def error_rate(predictions, labels):
- """Return the error rate based on dense predictions and 1-hot labels."""
+ """Return the error rate based on dense predictions and sparse labels."""
return 100.0 - (
100.0 *
- numpy.sum(numpy.argmax(predictions, 1) == numpy.argmax(labels, 1)) /
+ numpy.sum(numpy.argmax(predictions, 1) == labels) /
predictions.shape[0])
@@ -146,8 +145,7 @@ def main(argv=None): # pylint: disable=unused-argument
train_data_node = tf.placeholder(
tf.float32,
shape=(BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS))
- train_labels_node = tf.placeholder(tf.float32,
- shape=(BATCH_SIZE, NUM_LABELS))
+ train_labels_node = tf.placeholder(tf.int64, shape=(BATCH_SIZE,))
eval_data = tf.placeholder(
tf.float32,
shape=(EVAL_BATCH_SIZE, IMAGE_SIZE, IMAGE_SIZE, NUM_CHANNELS))
@@ -222,7 +220,7 @@ def main(argv=None): # pylint: disable=unused-argument
# Training computation: logits + cross-entropy loss.
logits = model(train_data_node, True)
- loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
+ loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(
logits, train_labels_node))
# L2 regularization for the fully connected parameters.
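
As a sanity check on the error_rate change above: the argmax of a 1-hot row recovers the label ID, so the old and new formulations compute the same rate. A small numpy sketch of that equivalence (all names and sizes here are illustrative):

    import numpy

    predictions = numpy.random.rand(8, 10).astype(numpy.float32)
    sparse_labels = numpy.random.randint(0, 10, size=8).astype(numpy.int64)
    # The removed 1-hot expansion from extract_labels().
    one_hot = (numpy.arange(10) == sparse_labels[:, None]).astype(numpy.float32)

    # Old form: compare argmax of predictions against argmax of 1-hot labels.
    old_rate = 100.0 - (
        100.0 *
        numpy.sum(numpy.argmax(predictions, 1) == numpy.argmax(one_hot, 1)) /
        predictions.shape[0])
    # New form: compare argmax of predictions against the label IDs directly.
    new_rate = 100.0 - (
        100.0 *
        numpy.sum(numpy.argmax(predictions, 1) == sparse_labels) /
        predictions.shape[0])
    assert old_rate == new_rate  # argmax of a 1-hot row is the label ID
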