From 2c858368c8c4b7e82c8d134786026a62a72d2676 Mon Sep 17 00:00:00 2001 From: Randy West Date: Mon, 18 Dec 2017 18:22:03 -0500 Subject: Compute test accuracy in batches to avoid OOM on GPUs. Reported here: https://github.com/tensorflow/tensorflow/issues/136 Alternative to this for mnist_deep.py: https://github.com/tensorflow/tensorflow/pull/157 --- tensorflow/examples/tutorials/mnist/mnist_deep.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) (limited to 'tensorflow/examples') diff --git a/tensorflow/examples/tutorials/mnist/mnist_deep.py b/tensorflow/examples/tutorials/mnist/mnist_deep.py index 1e0294db27..2699738735 100644 --- a/tensorflow/examples/tutorials/mnist/mnist_deep.py +++ b/tensorflow/examples/tutorials/mnist/mnist_deep.py @@ -34,6 +34,8 @@ from tensorflow.examples.tutorials.mnist import input_data import tensorflow as tf +import numpy + FLAGS = None @@ -164,8 +166,13 @@ def main(_): print('step %d, training accuracy %g' % (i, train_accuracy)) train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5}) - print('test accuracy %g' % accuracy.eval(feed_dict={ - x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0})) + # compute in batches to avoid OOM on GPUs + accuracy_l = [] + for i in range(50): + batch = mnist.test.next_batch(500, shuffle=False) + accuracy_l.append(accuracy.eval(feed_dict={x: batch[0], y_: batch[1], keep_prob: 1.0})) + print('test accuracy %g' % numpy.mean(accuracy_l)) + if __name__ == '__main__': parser = argparse.ArgumentParser() -- cgit v1.2.3 From 3f18817317940253e6ec0e6b412492c5add5927b Mon Sep 17 00:00:00 2001 From: Randy West Date: Mon, 18 Dec 2017 23:18:30 -0500 Subject: Fix basic arithmetic fail + make loop pythonic --- tensorflow/examples/tutorials/mnist/mnist_deep.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'tensorflow/examples') diff --git a/tensorflow/examples/tutorials/mnist/mnist_deep.py b/tensorflow/examples/tutorials/mnist/mnist_deep.py index 2699738735..47d2777813 100644 --- a/tensorflow/examples/tutorials/mnist/mnist_deep.py +++ b/tensorflow/examples/tutorials/mnist/mnist_deep.py @@ -168,7 +168,7 @@ def main(_): # compute in batches to avoid OOM on GPUs accuracy_l = [] - for i in range(50): + for _ in range(20): batch = mnist.test.next_batch(500, shuffle=False) accuracy_l.append(accuracy.eval(feed_dict={x: batch[0], y_: batch[1], keep_prob: 1.0})) print('test accuracy %g' % numpy.mean(accuracy_l)) -- cgit v1.2.3 From 9f3bd2cf1eccdc76ed1934ade96c6cd4464bb8b2 Mon Sep 17 00:00:00 2001 From: Mark Daoust Date: Fri, 6 Jul 2018 05:46:42 -0700 Subject: lint fix --- tensorflow/examples/tutorials/mnist/mnist_deep.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) (limited to 'tensorflow/examples') diff --git a/tensorflow/examples/tutorials/mnist/mnist_deep.py b/tensorflow/examples/tutorials/mnist/mnist_deep.py index 47d2777813..5d8d8d84fe 100644 --- a/tensorflow/examples/tutorials/mnist/mnist_deep.py +++ b/tensorflow/examples/tutorials/mnist/mnist_deep.py @@ -170,7 +170,9 @@ def main(_): accuracy_l = [] for _ in range(20): batch = mnist.test.next_batch(500, shuffle=False) - accuracy_l.append(accuracy.eval(feed_dict={x: batch[0], y_: batch[1], keep_prob: 1.0})) + accuracy_l.append(accuracy.eval(feed_dict={x: batch[0], + y_: batch[1], + keep_prob: 1.0})) print('test accuracy %g' % numpy.mean(accuracy_l)) -- cgit v1.2.3