aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/gan
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-22 00:16:05 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-22 00:18:50 -0700
commit88e560d6fadc1cf23519b00a9de5ed7c973536fd (patch)
tree9ec24d27a22fe046bdfb01e5928df3b900571763 /tensorflow/contrib/gan
parentf31939d24e3c544933b98ef48fac9ccac5679e05 (diff)
Use paramaterized tests in `train_test.py`.
PiperOrigin-RevId: 205555784
Diffstat (limited to 'tensorflow/contrib/gan')
-rw-r--r--tensorflow/contrib/gan/BUILD2
-rw-r--r--tensorflow/contrib/gan/python/train_test.py571
2 files changed, 219 insertions, 354 deletions
diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD
index c8c2af49d4..781e4ae4d7 100644
--- a/tensorflow/contrib/gan/BUILD
+++ b/tensorflow/contrib/gan/BUILD
@@ -57,6 +57,7 @@ py_library(
py_test(
name = "train_test",
srcs = ["python/train_test.py"],
+ shard_count = 50,
srcs_version = "PY2AND3",
tags = ["notsan"],
deps = [
@@ -80,6 +81,7 @@ py_test(
"//tensorflow/python:variables",
"//tensorflow/python/ops/distributions",
"//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
],
)
diff --git a/tensorflow/contrib/gan/python/train_test.py b/tensorflow/contrib/gan/python/train_test.py
index 93a12af944..cd99a33c03 100644
--- a/tensorflow/contrib/gan/python/train_test.py
+++ b/tensorflow/contrib/gan/python/train_test.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from absl.testing import parameterized
import numpy as np
from tensorflow.contrib import layers
@@ -296,38 +297,24 @@ def get_tensor_pool_fn_for_infogan(pool_size):
return tensor_pool_fn_impl
-class GANModelTest(test.TestCase):
+class GANModelTest(test.TestCase, parameterized.TestCase):
"""Tests for `gan_model`."""
- def _test_output_type_helper(self, create_fn, tuple_type):
- self.assertTrue(isinstance(create_fn(), tuple_type))
-
- def test_output_type_gan(self):
- self._test_output_type_helper(get_gan_model, namedtuples.GANModel)
-
- def test_output_type_callable_gan(self):
- self._test_output_type_helper(get_callable_gan_model, namedtuples.GANModel)
-
- def test_output_type_infogan(self):
- self._test_output_type_helper(get_infogan_model, namedtuples.InfoGANModel)
-
- def test_output_type_callable_infogan(self):
- self._test_output_type_helper(get_callable_infogan_model,
- namedtuples.InfoGANModel)
-
- def test_output_type_acgan(self):
- self._test_output_type_helper(get_acgan_model, namedtuples.ACGANModel)
-
- def test_output_type_callable_acgan(self):
- self._test_output_type_helper(get_callable_acgan_model,
- namedtuples.ACGANModel)
-
- def test_output_type_cyclegan(self):
- self._test_output_type_helper(get_cyclegan_model, namedtuples.CycleGANModel)
-
- def test_output_type_callable_cyclegan(self):
- self._test_output_type_helper(get_callable_cyclegan_model,
- namedtuples.CycleGANModel)
+ @parameterized.named_parameters(
+ ('gan', get_gan_model, namedtuples.GANModel),
+ ('callable_gan', get_callable_gan_model, namedtuples.GANModel),
+ ('infogan', get_infogan_model, namedtuples.InfoGANModel),
+ ('callable_infogan', get_callable_infogan_model,
+ namedtuples.InfoGANModel),
+ ('acgan', get_acgan_model, namedtuples.ACGANModel),
+ ('callable_acgan', get_callable_acgan_model, namedtuples.ACGANModel),
+ ('cyclegan', get_cyclegan_model, namedtuples.CycleGANModel),
+ ('callable_cyclegan', get_callable_cyclegan_model,
+ namedtuples.CycleGANModel),
+ )
+ def test_output_type(self, create_fn, expected_tuple_type):
+ """Test that output type is as expected."""
+ self.assertIsInstance(create_fn(), expected_tuple_type)
def test_no_shape_check(self):
@@ -484,53 +471,55 @@ class StarGANModelTest(test.TestCase):
disc_gen_label.shape)
-class GANLossTest(test.TestCase):
+class GANLossTest(test.TestCase, parameterized.TestCase):
"""Tests for `gan_loss`."""
- # Test output type.
- def _test_output_type_helper(self, get_gan_model_fn):
+ @parameterized.named_parameters(
+ ('gan', get_gan_model),
+ ('callable_gan', get_callable_gan_model),
+ ('infogan', get_infogan_model),
+ ('callable_infogan', get_callable_infogan_model),
+ ('acgan', get_acgan_model),
+ ('callable_acgan', get_callable_acgan_model),
+ )
+ def test_output_type(self, get_gan_model_fn):
+ """Test output type."""
loss = train.gan_loss(get_gan_model_fn(), add_summaries=True)
- self.assertTrue(isinstance(loss, namedtuples.GANLoss))
- self.assertGreater(len(ops.get_collection(ops.GraphKeys.SUMMARIES)), 0)
-
- def test_output_type_gan(self):
- self._test_output_type_helper(get_gan_model)
-
- def test_output_type_callable_gan(self):
- self._test_output_type_helper(get_callable_gan_model)
-
- def test_output_type_infogan(self):
- self._test_output_type_helper(get_infogan_model)
-
- def test_output_type_callable_infogan(self):
- self._test_output_type_helper(get_callable_infogan_model)
-
- def test_output_type_acgan(self):
- self._test_output_type_helper(get_acgan_model)
-
- def test_output_type_callable_acgan(self):
- self._test_output_type_helper(get_callable_acgan_model)
-
- def test_output_type_cyclegan(self):
- loss = train.cyclegan_loss(create_cyclegan_model(), add_summaries=True)
- self.assertIsInstance(loss, namedtuples.CycleGANLoss)
+ self.assertIsInstance(loss, namedtuples.GANLoss)
self.assertGreater(len(ops.get_collection(ops.GraphKeys.SUMMARIES)), 0)
- def test_output_type_callable_cyclegan(self):
- loss = train.cyclegan_loss(
- create_callable_cyclegan_model(), add_summaries=True)
+ @parameterized.named_parameters(
+ ('cyclegan', create_cyclegan_model),
+ ('callable_cyclegan', create_callable_cyclegan_model),
+ )
+ def test_cyclegan_output_type(self, get_gan_model_fn):
+ loss = train.cyclegan_loss(get_gan_model_fn(), add_summaries=True)
self.assertIsInstance(loss, namedtuples.CycleGANLoss)
self.assertGreater(len(ops.get_collection(ops.GraphKeys.SUMMARIES)), 0)
- # Test gradient penalty option.
- def _test_grad_penalty_helper(self, create_gan_model_fn, one_sided=False):
+ @parameterized.named_parameters(
+ ('gan', create_gan_model, False),
+ ('gan_one_sided', create_gan_model, True),
+ ('callable_gan', create_callable_gan_model, False),
+ ('callable_gan_one_sided', create_callable_gan_model, True),
+ ('infogan', create_infogan_model, False),
+ ('infogan_one_sided', create_infogan_model, True),
+ ('callable_infogan', create_callable_infogan_model, False),
+ ('callable_infogan_one_sided', create_callable_infogan_model, True),
+ ('acgan', create_acgan_model, False),
+ ('acgan_one_sided', create_acgan_model, True),
+ ('callable_acgan', create_callable_acgan_model, False),
+ ('callable_acgan_one_sided', create_callable_acgan_model, True),
+ )
+ def test_grad_penalty(self, create_gan_model_fn, one_sided):
+ """Test gradient penalty option."""
model = create_gan_model_fn()
loss = train.gan_loss(model)
loss_gp = train.gan_loss(
model,
gradient_penalty_weight=1.0,
gradient_penalty_one_sided=one_sided)
- self.assertTrue(isinstance(loss_gp, namedtuples.GANLoss))
+ self.assertIsInstance(loss_gp, namedtuples.GANLoss)
# Check values.
with self.test_session(use_gpu=True) as sess:
@@ -541,59 +530,28 @@ class GANLossTest(test.TestCase):
[loss.discriminator_loss, loss_gp.discriminator_loss])
self.assertEqual(loss_gen_np, loss_gen_gp_np)
- self.assertTrue(loss_dis_np < loss_dis_gp_np)
-
- def test_grad_penalty_gan(self):
- self._test_grad_penalty_helper(create_gan_model)
-
- def test_grad_penalty_callable_gan(self):
- self._test_grad_penalty_helper(create_callable_gan_model)
-
- def test_grad_penalty_infogan(self):
- self._test_grad_penalty_helper(create_infogan_model)
-
- def test_grad_penalty_callable_infogan(self):
- self._test_grad_penalty_helper(create_callable_infogan_model)
-
- def test_grad_penalty_acgan(self):
- self._test_grad_penalty_helper(create_acgan_model)
-
- def test_grad_penalty_callable_acgan(self):
- self._test_grad_penalty_helper(create_callable_acgan_model)
-
- def test_grad_penalty_one_sided_gan(self):
- self._test_grad_penalty_helper(create_gan_model, one_sided=True)
-
- def test_grad_penalty_one_sided_callable_gan(self):
- self._test_grad_penalty_helper(create_callable_gan_model, one_sided=True)
-
- def test_grad_penalty_one_sided_infogan(self):
- self._test_grad_penalty_helper(create_infogan_model, one_sided=True)
-
- def test_grad_penalty_one_sided_callable_infogan(self):
- self._test_grad_penalty_helper(
- create_callable_infogan_model, one_sided=True)
-
- def test_grad_penalty_one_sided_acgan(self):
- self._test_grad_penalty_helper(create_acgan_model, one_sided=True)
-
- def test_grad_penalty_one_sided_callable_acgan(self):
- self._test_grad_penalty_helper(create_callable_acgan_model, one_sided=True)
-
- # Test mutual information penalty option.
- def _test_mutual_info_penalty_helper(self, create_gan_model_fn):
+ self.assertLess(loss_dis_np, loss_dis_gp_np)
+
+ @parameterized.named_parameters(
+ ('infogan', get_infogan_model),
+ ('callable_infogan', get_callable_infogan_model),
+ )
+ def test_mutual_info_penalty(self, create_gan_model_fn):
+ """Test mutual information penalty option."""
train.gan_loss(
create_gan_model_fn(),
mutual_information_penalty_weight=constant_op.constant(1.0))
- def test_mutual_info_penalty_infogan(self):
- self._test_mutual_info_penalty_helper(get_infogan_model)
-
- def test_mutual_info_penalty_callable_infogan(self):
- self._test_mutual_info_penalty_helper(get_callable_infogan_model)
-
- # Test regularization loss.
- def _test_regularization_helper(self, get_gan_model_fn):
+ @parameterized.named_parameters(
+ ('gan', get_gan_model),
+ ('callable_gan', get_callable_gan_model),
+ ('infogan', get_infogan_model),
+ ('callable_infogan', get_callable_infogan_model),
+ ('acgan', get_acgan_model),
+ ('callable_acgan', get_callable_acgan_model),
+ )
+ def test_regularization_helper(self, get_gan_model_fn):
+ """Test regularization loss."""
# Evaluate losses without regularization.
no_reg_loss = train.gan_loss(get_gan_model_fn())
with self.test_session(use_gpu=True):
@@ -616,33 +574,19 @@ class GANLossTest(test.TestCase):
self.assertEqual(3.0, reg_loss_gen_np - no_reg_loss_gen_np)
self.assertEqual(2.0, reg_loss_dis_np - no_reg_loss_dis_np)
- def test_regularization_gan(self):
- self._test_regularization_helper(get_gan_model)
-
- def test_regularization_callable_gan(self):
- self._test_regularization_helper(get_callable_gan_model)
-
- def test_regularization_infogan(self):
- self._test_regularization_helper(get_infogan_model)
-
- def test_regularization_callable_infogan(self):
- self._test_regularization_helper(get_callable_infogan_model)
-
- def test_regularization_acgan(self):
- self._test_regularization_helper(get_acgan_model)
-
- def test_regularization_callable_acgan(self):
- self._test_regularization_helper(get_callable_acgan_model)
-
- # Test that ACGan models work.
- def _test_acgan_helper(self, create_gan_model_fn):
+ @parameterized.named_parameters(
+ ('notcallable', create_acgan_model),
+ ('callable', create_callable_acgan_model),
+ )
+ def test_acgan(self, create_gan_model_fn):
+ """Test that ACGAN models work."""
model = create_gan_model_fn()
loss = train.gan_loss(model)
loss_ac_gen = train.gan_loss(model, aux_cond_generator_weight=1.0)
loss_ac_dis = train.gan_loss(model, aux_cond_discriminator_weight=1.0)
- self.assertTrue(isinstance(loss, namedtuples.GANLoss))
- self.assertTrue(isinstance(loss_ac_gen, namedtuples.GANLoss))
- self.assertTrue(isinstance(loss_ac_dis, namedtuples.GANLoss))
+ self.assertIsInstance(loss, namedtuples.GANLoss)
+ self.assertIsInstance(loss_ac_gen, namedtuples.GANLoss)
+ self.assertIsInstance(loss_ac_dis, namedtuples.GANLoss)
# Check values.
with self.test_session(use_gpu=True) as sess:
@@ -656,20 +600,18 @@ class GANLossTest(test.TestCase):
loss_ac_dis.discriminator_loss
])
- self.assertTrue(loss_gen_np < loss_dis_np)
+ self.assertLess(loss_gen_np, loss_dis_np)
self.assertTrue(np.isscalar(loss_ac_gen_gen_np))
self.assertTrue(np.isscalar(loss_ac_dis_gen_np))
self.assertTrue(np.isscalar(loss_ac_gen_dis_np))
self.assertTrue(np.isscalar(loss_ac_dis_dis_np))
- def test_acgan(self):
- self._test_acgan_helper(create_acgan_model)
-
- def test_callable_acgan(self):
- self._test_acgan_helper(create_callable_acgan_model)
-
- # Test that CycleGan models work.
- def _test_cyclegan_helper(self, create_gan_model_fn):
+ @parameterized.named_parameters(
+ ('notcallable', create_cyclegan_model),
+ ('callable', create_callable_cyclegan_model),
+ )
+ def test_cyclegan(self, create_gan_model_fn):
+ """Test that CycleGan models work."""
model = create_gan_model_fn()
loss = train.cyclegan_loss(model)
self.assertIsInstance(loss, namedtuples.CycleGANLoss)
@@ -690,11 +632,46 @@ class GANLossTest(test.TestCase):
self.assertTrue(np.isscalar(loss_y2x_gen_np))
self.assertTrue(np.isscalar(loss_y2x_dis_np))
- def test_cyclegan(self):
- self._test_cyclegan_helper(create_cyclegan_model)
+ @parameterized.named_parameters(
+ ('gan', create_gan_model),
+ ('callable_gan', create_callable_gan_model),
+ ('infogan', create_infogan_model),
+ ('callable_infogan', create_callable_infogan_model),
+ ('acgan', create_acgan_model),
+ ('callable_acgan', create_callable_acgan_model),
+ )
+ def test_tensor_pool(self, create_gan_model_fn):
+ """Test tensor pool option."""
+ model = create_gan_model_fn()
+ if isinstance(model, namedtuples.InfoGANModel):
+ tensor_pool_fn = get_tensor_pool_fn_for_infogan(pool_size=5)
+ else:
+ tensor_pool_fn = get_tensor_pool_fn(pool_size=5)
+ loss = train.gan_loss(model, tensor_pool_fn=tensor_pool_fn)
+ self.assertIsInstance(loss, namedtuples.GANLoss)
+
+ # Check values.
+ with self.test_session(use_gpu=True) as sess:
+ variables.global_variables_initializer().run()
+ for _ in range(10):
+ sess.run([loss.generator_loss, loss.discriminator_loss])
+
+ def test_doesnt_crash_when_in_nested_scope(self):
+ with variable_scope.variable_scope('outer_scope'):
+ gan_model = train.gan_model(
+ generator_model,
+ discriminator_model,
+ real_data=array_ops.zeros([1, 2]),
+ generator_inputs=random_ops.random_normal([1, 2]))
+
+ # This should work inside a scope.
+ train.gan_loss(gan_model, gradient_penalty_weight=1.0)
+
+ # This should also work outside a scope.
+ train.gan_loss(gan_model, gradient_penalty_weight=1.0)
+
- def test_callable_cyclegan(self):
- self._test_cyclegan_helper(create_callable_cyclegan_model)
+class TensorPoolAdjusteModelTest(test.TestCase):
def _check_tensor_pool_adjusted_model_outputs(self, tensor1, tensor2,
pool_size):
@@ -714,115 +691,77 @@ class GANLossTest(test.TestCase):
# pool).
self.assertTrue(any([(v == t2).all() for v in history_values]))
- # Test `_tensor_pool_adjusted_model` for gan model.
- def test_tensor_pool_adjusted_model_gan(self):
- model = create_gan_model()
-
- new_model = train._tensor_pool_adjusted_model(model, None)
+ def _make_new_model_and_check(self, model, pool_size,
+ pool_fn=get_tensor_pool_fn):
+ new_model = train._tensor_pool_adjusted_model(
+ model, pool_fn(pool_size=pool_size))
# 'Generator/dummy_g:0' and 'Discriminator/dummy_d:0'
self.assertEqual(2, len(ops.get_collection(ops.GraphKeys.VARIABLES)))
+ self.assertIsNot(new_model.discriminator_gen_outputs,
+ model.discriminator_gen_outputs)
+
+ return new_model
+
+ def test_tensor_pool_adjusted_model_no_pool(self):
+ """Test `_tensor_pool_adjusted_model` for no pool size."""
+ model = create_gan_model()
+ new_model = train._tensor_pool_adjusted_model(model, None)
+
+ # Check values.
self.assertIs(new_model.discriminator_gen_outputs,
model.discriminator_gen_outputs)
+ def test_tensor_pool_adjusted_model_gan(self):
+ """Test `_tensor_pool_adjusted_model` for gan model."""
pool_size = 5
- new_model = train._tensor_pool_adjusted_model(
- model, get_tensor_pool_fn(pool_size=pool_size))
- self.assertIsNot(new_model.discriminator_gen_outputs,
- model.discriminator_gen_outputs)
+ model = create_gan_model()
+ new_model = self._make_new_model_and_check(model, pool_size)
+
# Check values.
self._check_tensor_pool_adjusted_model_outputs(
model.discriminator_gen_outputs, new_model.discriminator_gen_outputs,
pool_size)
- # Test _tensor_pool_adjusted_model for infogan model.
def test_tensor_pool_adjusted_model_infogan(self):
+ """Test _tensor_pool_adjusted_model for infogan model."""
+ pool_size = 5
model = create_infogan_model()
+ new_model = self._make_new_model_and_check(
+ model, pool_size, pool_fn=get_tensor_pool_fn_for_infogan)
- pool_size = 5
- new_model = train._tensor_pool_adjusted_model(
- model, get_tensor_pool_fn_for_infogan(pool_size=pool_size))
- # 'Generator/dummy_g:0' and 'Discriminator/dummy_d:0'
- self.assertEqual(2, len(ops.get_collection(ops.GraphKeys.VARIABLES)))
- self.assertIsNot(new_model.discriminator_gen_outputs,
- model.discriminator_gen_outputs)
+ # Check values.
self.assertIsNot(new_model.predicted_distributions,
model.predicted_distributions)
- # Check values.
self._check_tensor_pool_adjusted_model_outputs(
model.discriminator_gen_outputs, new_model.discriminator_gen_outputs,
pool_size)
- # Test _tensor_pool_adjusted_model for acgan model.
def test_tensor_pool_adjusted_model_acgan(self):
+ """Test _tensor_pool_adjusted_model for acgan model."""
+ pool_size = 5
model = create_acgan_model()
+ new_model = self._make_new_model_and_check(model, pool_size)
- pool_size = 5
- new_model = train._tensor_pool_adjusted_model(
- model, get_tensor_pool_fn(pool_size=pool_size))
- # 'Generator/dummy_g:0' and 'Discriminator/dummy_d:0'
- self.assertEqual(2, len(ops.get_collection(ops.GraphKeys.VARIABLES)))
- self.assertIsNot(new_model.discriminator_gen_outputs,
- model.discriminator_gen_outputs)
+ # Check values.
self.assertIsNot(new_model.discriminator_gen_classification_logits,
model.discriminator_gen_classification_logits)
- # Check values.
self._check_tensor_pool_adjusted_model_outputs(
model.discriminator_gen_outputs, new_model.discriminator_gen_outputs,
pool_size)
- # Test tensor pool.
- def _test_tensor_pool_helper(self, create_gan_model_fn):
- model = create_gan_model_fn()
- if isinstance(model, namedtuples.InfoGANModel):
- tensor_pool_fn = get_tensor_pool_fn_for_infogan(pool_size=5)
- else:
- tensor_pool_fn = get_tensor_pool_fn(pool_size=5)
- loss = train.gan_loss(model, tensor_pool_fn=tensor_pool_fn)
- self.assertTrue(isinstance(loss, namedtuples.GANLoss))
-
- # Check values.
- with self.test_session(use_gpu=True) as sess:
- variables.global_variables_initializer().run()
- for _ in range(10):
- sess.run([loss.generator_loss, loss.discriminator_loss])
-
- def test_tensor_pool_gan(self):
- self._test_tensor_pool_helper(create_gan_model)
-
- def test_tensor_pool_callable_gan(self):
- self._test_tensor_pool_helper(create_callable_gan_model)
-
- def test_tensor_pool_infogan(self):
- self._test_tensor_pool_helper(create_infogan_model)
-
- def test_tensor_pool_callable_infogan(self):
- self._test_tensor_pool_helper(create_callable_infogan_model)
-
- def test_tensor_pool_acgan(self):
- self._test_tensor_pool_helper(create_acgan_model)
-
- def test_tensor_pool_callable_acgan(self):
- self._test_tensor_pool_helper(create_callable_acgan_model)
-
- def test_doesnt_crash_when_in_nested_scope(self):
- with variable_scope.variable_scope('outer_scope'):
- gan_model = train.gan_model(
- generator_model,
- discriminator_model,
- real_data=array_ops.zeros([1, 2]),
- generator_inputs=random_ops.random_normal([1, 2]))
-
- # This should work inside a scope.
- train.gan_loss(gan_model, gradient_penalty_weight=1.0)
-
- # This should also work outside a scope.
- train.gan_loss(gan_model, gradient_penalty_weight=1.0)
-
-class GANTrainOpsTest(test.TestCase):
+class GANTrainOpsTest(test.TestCase, parameterized.TestCase):
"""Tests for `gan_train_ops`."""
- def _test_output_type_helper(self, create_gan_model_fn):
+ @parameterized.named_parameters(
+ ('gan', create_gan_model),
+ ('callable_gan', create_callable_gan_model),
+ ('infogan', create_infogan_model),
+ ('callable_infogan', create_callable_infogan_model),
+ ('acgan', create_acgan_model),
+ ('callable_acgan', create_callable_acgan_model),
+ )
+ def test_output_type(self, create_gan_model_fn):
model = create_gan_model_fn()
loss = train.gan_loss(model)
@@ -836,28 +775,24 @@ class GANTrainOpsTest(test.TestCase):
summarize_gradients=True,
colocate_gradients_with_ops=True)
- self.assertTrue(isinstance(train_ops, namedtuples.GANTrainOps))
-
- def test_output_type_gan(self):
- self._test_output_type_helper(create_gan_model)
-
- def test_output_type_callable_gan(self):
- self._test_output_type_helper(create_callable_gan_model)
-
- def test_output_type_infogan(self):
- self._test_output_type_helper(create_infogan_model)
-
- def test_output_type_callable_infogan(self):
- self._test_output_type_helper(create_callable_infogan_model)
-
- def test_output_type_acgan(self):
- self._test_output_type_helper(create_acgan_model)
-
- def test_output_type_callable_acgan(self):
- self._test_output_type_helper(create_callable_acgan_model)
+ self.assertIsInstance(train_ops, namedtuples.GANTrainOps)
# TODO(joelshor): Add a test to check that custom update op is run.
- def _test_unused_update_ops(self, create_gan_model_fn, provide_update_ops):
+ @parameterized.named_parameters(
+ ('gan', create_gan_model, False),
+ ('gan_provideupdates', create_gan_model, True),
+ ('callable_gan', create_callable_gan_model, False),
+ ('callable_gan_provideupdates', create_callable_gan_model, True),
+ ('infogan', create_infogan_model, False),
+ ('infogan_provideupdates', create_infogan_model, True),
+ ('callable_infogan', create_callable_infogan_model, False),
+ ('callable_infogan_provideupdates', create_callable_infogan_model, True),
+ ('acgan', create_acgan_model, False),
+ ('acgan_provideupdates', create_acgan_model, True),
+ ('callable_acgan', create_callable_acgan_model, False),
+ ('callable_acgan_provideupdates', create_callable_acgan_model, True),
+ )
+ def test_unused_update_ops(self, create_gan_model_fn, provide_update_ops):
model = create_gan_model_fn()
loss = train.gan_loss(model)
@@ -904,45 +839,16 @@ class GANTrainOpsTest(test.TestCase):
self.assertEqual(1, gen_update_count.eval())
self.assertEqual(1, dis_update_count.eval())
- def test_unused_update_ops_gan(self):
- self._test_unused_update_ops(create_gan_model, False)
-
- def test_unused_update_ops_gan_provideupdates(self):
- self._test_unused_update_ops(create_gan_model, True)
-
- def test_unused_update_ops_callable_gan(self):
- self._test_unused_update_ops(create_callable_gan_model, False)
-
- def test_unused_update_ops_callable_gan_provideupdates(self):
- self._test_unused_update_ops(create_callable_gan_model, True)
-
- def test_unused_update_ops_infogan(self):
- self._test_unused_update_ops(create_infogan_model, False)
-
- def test_unused_update_ops_infogan_provideupdates(self):
- self._test_unused_update_ops(create_infogan_model, True)
-
- def test_unused_update_ops_callable_infogan(self):
- self._test_unused_update_ops(create_callable_infogan_model, False)
-
- def test_unused_update_ops_callable_infogan_provideupdates(self):
- self._test_unused_update_ops(create_callable_infogan_model, True)
-
- def test_unused_update_ops_acgan(self):
- self._test_unused_update_ops(create_acgan_model, False)
-
- def test_unused_update_ops_acgan_provideupdates(self):
- self._test_unused_update_ops(create_acgan_model, True)
-
- def test_unused_update_ops_callable_acgan(self):
- self._test_unused_update_ops(create_callable_acgan_model, False)
-
- def test_unused_update_ops_callable_acgan_provideupdates(self):
- self._test_unused_update_ops(create_callable_acgan_model, True)
-
- def _test_sync_replicas_helper(self,
- create_gan_model_fn,
- create_global_step=False):
+ @parameterized.named_parameters(
+ ('gan', create_gan_model, False),
+ ('callable_gan', create_callable_gan_model, False),
+ ('infogan', create_infogan_model, False),
+ ('callable_infogan', create_callable_infogan_model, False),
+ ('acgan', create_acgan_model, False),
+ ('callable_acgan', create_callable_acgan_model, False),
+ ('gan_canbeint32', create_gan_model, True),
+ )
+ def test_sync_replicas(self, create_gan_model_fn, create_global_step):
model = create_gan_model_fn()
loss = train.gan_loss(model)
num_trainable_vars = len(variables_lib.get_trainable_variables())
@@ -956,7 +862,7 @@ class GANTrainOpsTest(test.TestCase):
d_opt = get_sync_optimizer()
train_ops = train.gan_train_ops(
model, loss, generator_optimizer=g_opt, discriminator_optimizer=d_opt)
- self.assertTrue(isinstance(train_ops, namedtuples.GANTrainOps))
+ self.assertIsInstance(train_ops, namedtuples.GANTrainOps)
# No new trainable variables should have been added.
self.assertEqual(num_trainable_vars,
len(variables_lib.get_trainable_variables()))
@@ -994,29 +900,8 @@ class GANTrainOpsTest(test.TestCase):
coord.request_stop()
coord.join(g_threads + d_threads)
- def test_sync_replicas_gan(self):
- self._test_sync_replicas_helper(create_gan_model)
-
- def test_sync_replicas_callable_gan(self):
- self._test_sync_replicas_helper(create_callable_gan_model)
-
- def test_sync_replicas_infogan(self):
- self._test_sync_replicas_helper(create_infogan_model)
-
- def test_sync_replicas_callable_infogan(self):
- self._test_sync_replicas_helper(create_callable_infogan_model)
-
- def test_sync_replicas_acgan(self):
- self._test_sync_replicas_helper(create_acgan_model)
-
- def test_sync_replicas_callable_acgan(self):
- self._test_sync_replicas_helper(create_callable_acgan_model)
- def test_global_step_can_be_int32(self):
- self._test_sync_replicas_helper(create_gan_model, create_global_step=True)
-
-
-class GANTrainTest(test.TestCase):
+class GANTrainTest(test.TestCase, parameterized.TestCase):
"""Tests for `gan_train`."""
def _gan_train_ops(self, generator_add, discriminator_add):
@@ -1032,7 +917,15 @@ class GANTrainTest(test.TestCase):
global_step_inc_op=step.assign_add(1))
return train_ops
- def _test_run_helper(self, create_gan_model_fn):
+ @parameterized.named_parameters(
+ ('gan', create_gan_model),
+ ('callable_gan', create_callable_gan_model),
+ ('infogan', create_infogan_model),
+ ('callable_infogan', create_callable_infogan_model),
+ ('acgan', create_acgan_model),
+ ('callable_acgan', create_callable_acgan_model),
+ )
+ def test_run_helper(self, create_gan_model_fn):
random_seed.set_random_seed(1234)
model = create_gan_model_fn()
loss = train.gan_loss(model)
@@ -1048,26 +941,12 @@ class GANTrainTest(test.TestCase):
self.assertTrue(np.isscalar(final_step))
self.assertEqual(2, final_step)
- def test_run_gan(self):
- self._test_run_helper(create_gan_model)
-
- def test_run_callable_gan(self):
- self._test_run_helper(create_callable_gan_model)
-
- def test_run_infogan(self):
- self._test_run_helper(create_infogan_model)
-
- def test_run_callable_infogan(self):
- self._test_run_helper(create_callable_infogan_model)
-
- def test_run_acgan(self):
- self._test_run_helper(create_acgan_model)
-
- def test_run_callable_acgan(self):
- self._test_run_helper(create_callable_acgan_model)
-
- # Test multiple train steps.
- def _test_multiple_steps_helper(self, get_hooks_fn_fn):
+ @parameterized.named_parameters(
+ ('seq_train_steps', train.get_sequential_train_hooks),
+ ('efficient_seq_train_steps', train.get_joint_train_hooks),
+ )
+ def test_multiple_steps(self, get_hooks_fn_fn):
+ """Test multiple train steps."""
train_ops = self._gan_train_ops(generator_add=10, discriminator_add=100)
train_steps = namedtuples.GANTrainSteps(
generator_train_steps=3, discriminator_train_steps=4)
@@ -1080,12 +959,6 @@ class GANTrainTest(test.TestCase):
self.assertTrue(np.isscalar(final_step))
self.assertEqual(1 + 3 * 10 + 4 * 100, final_step)
- def test_multiple_steps_seq_train_steps(self):
- self._test_multiple_steps_helper(train.get_sequential_train_hooks)
-
- def test_multiple_steps_efficient_seq_train_steps(self):
- self._test_multiple_steps_helper(train.get_joint_train_hooks)
-
def test_supervisor_run_gan_model_train_ops_multiple_steps(self):
step = training_util.create_global_step()
train_ops = namedtuples.GANTrainOps(
@@ -1105,10 +978,18 @@ class GANTrainTest(test.TestCase):
self.assertEqual(17.0, final_loss)
-class PatchGANTest(test.TestCase):
+class PatchGANTest(test.TestCase, parameterized.TestCase):
"""Tests that functions work on PatchGAN style output."""
- def _test_patchgan_helper(self, create_gan_model_fn):
+ @parameterized.named_parameters(
+ ('gan', create_gan_model),
+ ('callable_gan', create_callable_gan_model),
+ ('infogan', create_infogan_model),
+ ('callable_infogan', create_callable_infogan_model),
+ ('acgan', create_acgan_model),
+ ('callable_acgan', create_callable_acgan_model),
+ )
+ def test_patchgan(self, create_gan_model_fn):
"""Ensure that patch-based discriminators work end-to-end."""
random_seed.set_random_seed(1234)
model = create_gan_model_fn()
@@ -1125,24 +1006,6 @@ class PatchGANTest(test.TestCase):
self.assertTrue(np.isscalar(final_step))
self.assertEqual(2, final_step)
- def test_patchgan_gan(self):
- self._test_patchgan_helper(create_gan_model)
-
- def test_patchgan_callable_gan(self):
- self._test_patchgan_helper(create_callable_gan_model)
-
- def test_patchgan_infogan(self):
- self._test_patchgan_helper(create_infogan_model)
-
- def test_patchgan_callable_infogan(self):
- self._test_patchgan_helper(create_callable_infogan_model)
-
- def test_patchgan_acgan(self):
- self._test_patchgan_helper(create_acgan_model)
-
- def test_patchgan_callable_acgan(self):
- self._test_patchgan_helper(create_callable_acgan_model)
-
if __name__ == '__main__':
test.main()