diff options
Diffstat (limited to 'tensorflow/contrib/gan/python/train_test.py')
-rw-r--r-- | tensorflow/contrib/gan/python/train_test.py | 25 |
1 files changed, 23 insertions, 2 deletions
diff --git a/tensorflow/contrib/gan/python/train_test.py b/tensorflow/contrib/gan/python/train_test.py index f9bdaa74c9..3ebbe55d05 100644 --- a/tensorflow/contrib/gan/python/train_test.py +++ b/tensorflow/contrib/gan/python/train_test.py @@ -359,10 +359,12 @@ class GANLossTest(test.TestCase): self.assertGreater(len(ops.get_collection(ops.GraphKeys.SUMMARIES)), 0) # Test gradient penalty option. - def _test_grad_penalty_helper(self, create_gan_model_fn): + def _test_grad_penalty_helper(self, create_gan_model_fn, one_sided=False): model = create_gan_model_fn() loss = train.gan_loss(model) - loss_gp = train.gan_loss(model, gradient_penalty_weight=1.0) + loss_gp = train.gan_loss(model, + gradient_penalty_weight=1.0, + gradient_penalty_one_sided=one_sided) self.assertTrue(isinstance(loss_gp, namedtuples.GANLoss)) # Check values. @@ -394,6 +396,25 @@ class GANLossTest(test.TestCase): def test_grad_penalty_callable_acgan(self): self._test_grad_penalty_helper(create_callable_acgan_model) + def test_grad_penalty_one_sided_gan(self): + self._test_grad_penalty_helper(create_gan_model, one_sided=True) + + def test_grad_penalty_one_sided_callable_gan(self): + self._test_grad_penalty_helper(create_callable_gan_model, one_sided=True) + + def test_grad_penalty_one_sided_infogan(self): + self._test_grad_penalty_helper(create_infogan_model, one_sided=True) + + def test_grad_penalty_one_sided_callable_infogan(self): + self._test_grad_penalty_helper( + create_callable_infogan_model, one_sided=True) + + def test_grad_penalty_one_sided_acgan(self): + self._test_grad_penalty_helper(create_acgan_model, one_sided=True) + + def test_grad_penalty_one_sided_callable_acgan(self): + self._test_grad_penalty_helper(create_callable_acgan_model, one_sided=True) + # Test mutual information penalty option. def _test_mutual_info_penalty_helper(self, create_gan_model_fn): train.gan_loss(create_gan_model_fn(), |