diff options
Diffstat (limited to 'tensorflow/g3doc/tutorials/mnist/mnist_softmax.py')
-rw-r--r-- | tensorflow/g3doc/tutorials/mnist/mnist_softmax.py | 33 |
1 files changed, 33 insertions, 0 deletions
diff --git a/tensorflow/g3doc/tutorials/mnist/mnist_softmax.py b/tensorflow/g3doc/tutorials/mnist/mnist_softmax.py new file mode 100644 index 0000000000..640ea29dac --- /dev/null +++ b/tensorflow/g3doc/tutorials/mnist/mnist_softmax.py @@ -0,0 +1,33 @@ +"""A very simple MNIST classifer. + +See extensive documentation at ??????? (insert public URL) +""" + +# Import data +import input_data +mnist = input_data.read_data_sets("/tmp/data/", one_hot=True) + +import tensorflow as tf +sess = tf.InteractiveSession() + +# Create the model +x = tf.placeholder("float", [None, 784]) +W = tf.Variable(tf.zeros([784,10])) +b = tf.Variable(tf.zeros([10])) +y = tf.nn.softmax(tf.matmul(x,W) + b) + +# Define loss and optimizer +y_ = tf.placeholder("float", [None,10]) +cross_entropy = -tf.reduce_sum(y_*tf.log(y)) +train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy) + +# Train +tf.initialize_all_variables().run() +for i in range(1000): + batch_xs, batch_ys = mnist.train.next_batch(100) + train_step.run({x: batch_xs, y_: batch_ys}) + +# Test trained model +correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1)) +accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float")) +print accuracy.eval({x: mnist.test.images, y_: mnist.test.labels}) |