diff options
author | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-20 15:18:54 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-20 15:18:54 -0700 |
commit | a3d814cd9100556a6e2e1468f1a9981820b4203c (patch) | |
tree | 78efca7a340d99559cf3ebb710b16ef7979b7ccb /tensorflow/examples | |
parent | 88757f6422cf14b2bb6beec14da390a092d66151 (diff) | |
parent | 9f3bd2cf1eccdc76ed1934ade96c6cd4464bb8b2 (diff) |
Merge pull request #15469 from thisisrandy:master
PiperOrigin-RevId: 205461271
Diffstat (limited to 'tensorflow/examples')
-rw-r--r-- | tensorflow/examples/tutorials/mnist/mnist_deep.py | 13 |
1 files changed, 11 insertions, 2 deletions
diff --git a/tensorflow/examples/tutorials/mnist/mnist_deep.py b/tensorflow/examples/tutorials/mnist/mnist_deep.py index 1e0294db27..5d8d8d84fe 100644 --- a/tensorflow/examples/tutorials/mnist/mnist_deep.py +++ b/tensorflow/examples/tutorials/mnist/mnist_deep.py @@ -34,6 +34,8 @@ from tensorflow.examples.tutorials.mnist import input_data import tensorflow as tf +import numpy + FLAGS = None @@ -164,8 +166,15 @@ def main(_): print('step %d, training accuracy %g' % (i, train_accuracy)) train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5}) - print('test accuracy %g' % accuracy.eval(feed_dict={ - x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0})) + # compute in batches to avoid OOM on GPUs + accuracy_l = [] + for _ in range(20): + batch = mnist.test.next_batch(500, shuffle=False) + accuracy_l.append(accuracy.eval(feed_dict={x: batch[0], + y_: batch[1], + keep_prob: 1.0})) + print('test accuracy %g' % numpy.mean(accuracy_l)) + if __name__ == '__main__': parser = argparse.ArgumentParser() |