aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/gan/python/train_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/gan/python/train_test.py')
-rw-r--r--tensorflow/contrib/gan/python/train_test.py25
1 files changed, 23 insertions, 2 deletions
diff --git a/tensorflow/contrib/gan/python/train_test.py b/tensorflow/contrib/gan/python/train_test.py
index f9bdaa74c9..3ebbe55d05 100644
--- a/tensorflow/contrib/gan/python/train_test.py
+++ b/tensorflow/contrib/gan/python/train_test.py
@@ -359,10 +359,12 @@ class GANLossTest(test.TestCase):
self.assertGreater(len(ops.get_collection(ops.GraphKeys.SUMMARIES)), 0)
# Test gradient penalty option.
- def _test_grad_penalty_helper(self, create_gan_model_fn):
+ def _test_grad_penalty_helper(self, create_gan_model_fn, one_sided=False):
model = create_gan_model_fn()
loss = train.gan_loss(model)
- loss_gp = train.gan_loss(model, gradient_penalty_weight=1.0)
+ loss_gp = train.gan_loss(model,
+ gradient_penalty_weight=1.0,
+ gradient_penalty_one_sided=one_sided)
self.assertTrue(isinstance(loss_gp, namedtuples.GANLoss))
# Check values.
@@ -394,6 +396,25 @@ class GANLossTest(test.TestCase):
def test_grad_penalty_callable_acgan(self):
self._test_grad_penalty_helper(create_callable_acgan_model)
+ def test_grad_penalty_one_sided_gan(self):
+ self._test_grad_penalty_helper(create_gan_model, one_sided=True)
+
+ def test_grad_penalty_one_sided_callable_gan(self):
+ self._test_grad_penalty_helper(create_callable_gan_model, one_sided=True)
+
+ def test_grad_penalty_one_sided_infogan(self):
+ self._test_grad_penalty_helper(create_infogan_model, one_sided=True)
+
+ def test_grad_penalty_one_sided_callable_infogan(self):
+ self._test_grad_penalty_helper(
+ create_callable_infogan_model, one_sided=True)
+
+ def test_grad_penalty_one_sided_acgan(self):
+ self._test_grad_penalty_helper(create_acgan_model, one_sided=True)
+
+ def test_grad_penalty_one_sided_callable_acgan(self):
+ self._test_grad_penalty_helper(create_callable_acgan_model, one_sided=True)
+
# Test mutual information penalty option.
def _test_mutual_info_penalty_helper(self, create_gan_model_fn):
train.gan_loss(create_gan_model_fn(),