aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/eager
diff options
context:
space:
mode:
authorGravatar Todd Wang <toddw@google.com>2018-10-03 16:36:23 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-03 16:43:42 -0700
commit9801b8810e07859141d4417746317cc3dbebc227 (patch)
treec8891fc08ce368a32af7fc45443114be73696cda /tensorflow/contrib/eager
parent207bea0e35ab635e66137520963761a6e94354ea (diff)
Reduce batch sizes for some eager tests to prevert OOMs in OSS runs
PiperOrigin-RevId: 215651413
Diffstat (limited to 'tensorflow/contrib/eager')
-rw-r--r--tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py10
-rw-r--r--tensorflow/contrib/eager/python/examples/revnet/revnet_test.py3
2 files changed, 10 insertions, 3 deletions
diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py
index 551c76b0df..f3bb978875 100644
--- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py
+++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_graph_test.py
@@ -51,7 +51,9 @@ def random_batch(batch_size):
class ResNet50GraphTest(tf.test.TestCase):
def testApply(self):
- batch_size = 64
+ # Use small batches for tests because the OSS version runs
+ # in constrained GPU environment with 1-2GB of memory.
+ batch_size = 8
with tf.Graph().as_default():
images = tf.placeholder(tf.float32, image_shape(None))
model = resnet50.ResNet50(data_format())
@@ -63,7 +65,7 @@ class ResNet50GraphTest(tf.test.TestCase):
sess.run(init)
np_images, _ = random_batch(batch_size)
out = sess.run(predictions, feed_dict={images: np_images})
- self.assertAllEqual([64, 1000], out.shape)
+ self.assertAllEqual([batch_size, 1000], out.shape)
def testTrainWithSummary(self):
with tf.Graph().as_default():
@@ -87,7 +89,9 @@ class ResNet50GraphTest(tf.test.TestCase):
init = tf.global_variables_initializer()
self.assertEqual(321, len(tf.global_variables()))
- batch_size = 32
+ # Use small batches for tests because the OSS version runs
+ # in constrained GPU environment with 1-2GB of memory.
+ batch_size = 2
with tf.Session() as sess:
sess.run(init)
sess.run(tf.contrib.summary.summary_writer_initializer_op())
diff --git a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py
index 6a921e1997..4f4cc3af6f 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/revnet_test.py
@@ -50,6 +50,9 @@ class RevNetTest(tf.test.TestCase):
# Reconstruction could cause numerical error, use double precision for tests
config.dtype = tf.float64
config.fused = False # Fused batch norm does not support tf.float64
+ # Reduce the batch size for tests because the OSS version runs
+ # in constrained GPU environment with 1-2GB of memory.
+ config.batch_size = 2
shape = (config.batch_size,) + config.input_shape
self.model = revnet.RevNet(config=config)
self.x = tf.random_normal(shape=shape, dtype=tf.float64)