aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/examples
diff options
context:
space:
mode:
authorGravatar Randy West <randywest55@gmail.com>2017-12-18 18:22:03 -0500
committerGravatar Randy West <randywest55@gmail.com>2017-12-18 19:42:24 -0500
commit2c858368c8c4b7e82c8d134786026a62a72d2676 (patch)
tree3f58ff0a287e79c8f6c382e2e6fd969fade02757 /tensorflow/examples
parentacaabdfe587de35ee66a612b3bbcbafef2dcca89 (diff)
Compute test accuracy in batches to avoid OOM on GPUs.
Diffstat (limited to 'tensorflow/examples')
-rw-r--r--tensorflow/examples/tutorials/mnist/mnist_deep.py11
1 files changed, 9 insertions, 2 deletions
diff --git a/tensorflow/examples/tutorials/mnist/mnist_deep.py b/tensorflow/examples/tutorials/mnist/mnist_deep.py
index 1e0294db27..2699738735 100644
--- a/tensorflow/examples/tutorials/mnist/mnist_deep.py
+++ b/tensorflow/examples/tutorials/mnist/mnist_deep.py
@@ -34,6 +34,8 @@ from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
+import numpy
+
FLAGS = None
@@ -164,8 +166,13 @@ def main(_):
print('step %d, training accuracy %g' % (i, train_accuracy))
train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})
- print('test accuracy %g' % accuracy.eval(feed_dict={
- x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))
+ # compute in batches to avoid OOM on GPUs
+ accuracy_l = []
+ for i in range(50):
+ batch = mnist.test.next_batch(500, shuffle=False)
+ accuracy_l.append(accuracy.eval(feed_dict={x: batch[0], y_: batch[1], keep_prob: 1.0}))
+ print('test accuracy %g' % numpy.mean(accuracy_l))
+
if __name__ == '__main__':
parser = argparse.ArgumentParser()