aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/gan
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-01-10 12:58:02 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-01-10 13:02:28 -0800
commitd2082810c4e6fdcd501fce7abc86e9be4c36cb3a (patch)
treecece52d2b44e245dd78579838a7b81186839202d /tensorflow/contrib/gan
parent0d5fb1036e0e8d99a942e7a5234feaef021607f5 (diff)
Fix flaky training tests. Reenable the tests.
PiperOrigin-RevId: 181505090
Diffstat (limited to 'tensorflow/contrib/gan')
-rw-r--r--tensorflow/contrib/gan/BUILD6
-rw-r--r--tensorflow/contrib/gan/python/train_test.py6
2 files changed, 3 insertions, 9 deletions
diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD
index de6c672504..b355a79b1a 100644
--- a/tensorflow/contrib/gan/BUILD
+++ b/tensorflow/contrib/gan/BUILD
@@ -55,12 +55,6 @@ py_test(
name = "train_test",
srcs = ["python/train_test.py"],
srcs_version = "PY2AND3",
- tags = [
- "manual", # b/71801546
- "no_oss",
- "notap",
- "notsan",
- ],
deps = [
":features",
":namedtuples",
diff --git a/tensorflow/contrib/gan/python/train_test.py b/tensorflow/contrib/gan/python/train_test.py
index 3411657cad..58704e6859 100644
--- a/tensorflow/contrib/gan/python/train_test.py
+++ b/tensorflow/contrib/gan/python/train_test.py
@@ -455,7 +455,7 @@ class GANLossTest(test.TestCase):
new_model = train._tensor_pool_adjusted_model(model, None)
# 'Generator/dummy_g:0' and 'Discriminator/dummy_d:0'
- self.assertLen(ops.get_collection(ops.GraphKeys.VARIABLES), 2)
+ self.assertEqual(2, len(ops.get_collection(ops.GraphKeys.VARIABLES)))
self.assertIs(new_model.discriminator_gen_outputs,
model.discriminator_gen_outputs)
@@ -477,7 +477,7 @@ class GANLossTest(test.TestCase):
new_model = train._tensor_pool_adjusted_model(
model, get_tensor_pool_fn_for_infogan(pool_size=pool_size))
# 'Generator/dummy_g:0' and 'Discriminator/dummy_d:0'
- self.assertLen(ops.get_collection(ops.GraphKeys.VARIABLES), 2)
+ self.assertEqual(2, len(ops.get_collection(ops.GraphKeys.VARIABLES)))
self.assertIsNot(new_model.discriminator_gen_outputs,
model.discriminator_gen_outputs)
self.assertIsNot(new_model.predicted_distributions,
@@ -495,7 +495,7 @@ class GANLossTest(test.TestCase):
new_model = train._tensor_pool_adjusted_model(
model, get_tensor_pool_fn(pool_size=pool_size))
# 'Generator/dummy_g:0' and 'Discriminator/dummy_d:0'
- self.assertLen(ops.get_collection(ops.GraphKeys.VARIABLES), 2)
+ self.assertEqual(2, len(ops.get_collection(ops.GraphKeys.VARIABLES)))
self.assertIsNot(new_model.discriminator_gen_outputs,
model.discriminator_gen_outputs)
self.assertIsNot(new_model.discriminator_gen_classification_logits,