aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples
diff options
context:
space:
mode:
authorGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-20 15:18:54 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-20 15:18:54 -0700
commita3d814cd9100556a6e2e1468f1a9981820b4203c (patch)
tree78efca7a340d99559cf3ebb710b16ef7979b7ccb /tensorflow/examples
parent88757f6422cf14b2bb6beec14da390a092d66151 (diff)
parent9f3bd2cf1eccdc76ed1934ade96c6cd4464bb8b2 (diff)
Merge pull request #15469 from thisisrandy:master
PiperOrigin-RevId: 205461271
Diffstat (limited to 'tensorflow/examples')
-rw-r--r--tensorflow/examples/tutorials/mnist/mnist_deep.py13
1 files changed, 11 insertions, 2 deletions
diff --git a/tensorflow/examples/tutorials/mnist/mnist_deep.py b/tensorflow/examples/tutorials/mnist/mnist_deep.py
index 1e0294db27..5d8d8d84fe 100644
--- a/tensorflow/examples/tutorials/mnist/mnist_deep.py
+++ b/tensorflow/examples/tutorials/mnist/mnist_deep.py
@@ -34,6 +34,8 @@ from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
+import numpy
+
FLAGS = None
@@ -164,8 +166,15 @@ def main(_):
print('step %d, training accuracy %g' % (i, train_accuracy))
train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})
- print('test accuracy %g' % accuracy.eval(feed_dict={
- x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))
+ # compute in batches to avoid OOM on GPUs
+ accuracy_l = []
+ for _ in range(20):
+ batch = mnist.test.next_batch(500, shuffle=False)
+ accuracy_l.append(accuracy.eval(feed_dict={x: batch[0],
+ y_: batch[1],
+ keep_prob: 1.0}))
+ print('test accuracy %g' % numpy.mean(accuracy_l))
+
if __name__ == '__main__':
parser = argparse.ArgumentParser()