diff options
author | 2018-01-10 12:58:02 -0800 | |
---|---|---|
committer | 2018-01-10 13:02:28 -0800 | |
commit | d2082810c4e6fdcd501fce7abc86e9be4c36cb3a (patch) | |
tree | cece52d2b44e245dd78579838a7b81186839202d /tensorflow/contrib/gan | |
parent | 0d5fb1036e0e8d99a942e7a5234feaef021607f5 (diff) |
Fix flaky training tests. Reenable the tests.
PiperOrigin-RevId: 181505090
Diffstat (limited to 'tensorflow/contrib/gan')
-rw-r--r-- | tensorflow/contrib/gan/BUILD | 6 | ||||
-rw-r--r-- | tensorflow/contrib/gan/python/train_test.py | 6 |
2 files changed, 3 insertions, 9 deletions
diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD index de6c672504..b355a79b1a 100644 --- a/tensorflow/contrib/gan/BUILD +++ b/tensorflow/contrib/gan/BUILD @@ -55,12 +55,6 @@ py_test( name = "train_test", srcs = ["python/train_test.py"], srcs_version = "PY2AND3", - tags = [ - "manual", # b/71801546 - "no_oss", - "notap", - "notsan", - ], deps = [ ":features", ":namedtuples", diff --git a/tensorflow/contrib/gan/python/train_test.py b/tensorflow/contrib/gan/python/train_test.py index 3411657cad..58704e6859 100644 --- a/tensorflow/contrib/gan/python/train_test.py +++ b/tensorflow/contrib/gan/python/train_test.py @@ -455,7 +455,7 @@ class GANLossTest(test.TestCase): new_model = train._tensor_pool_adjusted_model(model, None) # 'Generator/dummy_g:0' and 'Discriminator/dummy_d:0' - self.assertLen(ops.get_collection(ops.GraphKeys.VARIABLES), 2) + self.assertEqual(2, len(ops.get_collection(ops.GraphKeys.VARIABLES))) self.assertIs(new_model.discriminator_gen_outputs, model.discriminator_gen_outputs) @@ -477,7 +477,7 @@ class GANLossTest(test.TestCase): new_model = train._tensor_pool_adjusted_model( model, get_tensor_pool_fn_for_infogan(pool_size=pool_size)) # 'Generator/dummy_g:0' and 'Discriminator/dummy_d:0' - self.assertLen(ops.get_collection(ops.GraphKeys.VARIABLES), 2) + self.assertEqual(2, len(ops.get_collection(ops.GraphKeys.VARIABLES))) self.assertIsNot(new_model.discriminator_gen_outputs, model.discriminator_gen_outputs) self.assertIsNot(new_model.predicted_distributions, @@ -495,7 +495,7 @@ class GANLossTest(test.TestCase): new_model = train._tensor_pool_adjusted_model( model, get_tensor_pool_fn(pool_size=pool_size)) # 'Generator/dummy_g:0' and 'Discriminator/dummy_d:0' - self.assertLen(ops.get_collection(ops.GraphKeys.VARIABLES), 2) + self.assertEqual(2, len(ops.get_collection(ops.GraphKeys.VARIABLES))) self.assertIsNot(new_model.discriminator_gen_outputs, model.discriminator_gen_outputs) self.assertIsNot(new_model.discriminator_gen_classification_logits, |