diff options
author | 2018-10-03 16:36:23 -0700 | |
---|---|---|
committer | 2018-10-03 16:43:42 -0700 | |
commit | 9801b8810e07859141d4417746317cc3dbebc227 (patch) | |
tree | c8891fc08ce368a32af7fc45443114be73696cda /tensorflow/contrib/eager | |
parent | 207bea0e35ab635e66137520963761a6e94354ea (diff) |
Reduce batch sizes for some eager tests to prevert OOMs in OSS runs
PiperOrigin-RevId: 215651413
Diffstat (limited to 'tensorflow/contrib/eager')
-rw-r--r-- | tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py | 10 | ||||
-rw-r--r-- | tensorflow/contrib/eager/python/examples/revnet/revnet_test.py | 3 |
2 files changed, 10 insertions, 3 deletions
diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py index 551c76b0df..f3bb978875 100644 --- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py +++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py @@ -51,7 +51,9 @@ def random_batch(batch_size): class ResNet50GraphTest(tf.test.TestCase): def testApply(self): - batch_size = 64 + # Use small batches for tests because the OSS version runs + # in constrained GPU environment with 1-2GB of memory. + batch_size = 8 with tf.Graph().as_default(): images = tf.placeholder(tf.float32, image_shape(None)) model = resnet50.ResNet50(data_format()) @@ -63,7 +65,7 @@ class ResNet50GraphTest(tf.test.TestCase): sess.run(init) np_images, _ = random_batch(batch_size) out = sess.run(predictions, feed_dict={images: np_images}) - self.assertAllEqual([64, 1000], out.shape) + self.assertAllEqual([batch_size, 1000], out.shape) def testTrainWithSummary(self): with tf.Graph().as_default(): @@ -87,7 +89,9 @@ class ResNet50GraphTest(tf.test.TestCase): init = tf.global_variables_initializer() self.assertEqual(321, len(tf.global_variables())) - batch_size = 32 + # Use small batches for tests because the OSS version runs + # in constrained GPU environment with 1-2GB of memory. + batch_size = 2 with tf.Session() as sess: sess.run(init) sess.run(tf.contrib.summary.summary_writer_initializer_op()) diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py index 6a921e1997..4f4cc3af6f 100644 --- a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py +++ b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py @@ -50,6 +50,9 @@ class RevNetTest(tf.test.TestCase): # Reconstruction could cause numerical error, use double precision for tests config.dtype = tf.float64 config.fused = False # Fused batch norm does not support tf.float64 + # Reduce the batch size for tests because the OSS version runs + # in constrained GPU environment with 1-2GB of memory. + config.batch_size = 2 shape = (config.batch_size,) + config.input_shape self.model = revnet.RevNet(config=config) self.x = tf.random_normal(shape=shape, dtype=tf.float64) |