aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/gan/python/losses/python/losses_impl_test.py')
-rw-r--r--tensorflow/contrib/gan/python/losses/python/losses_impl_test.py22
1 files changed, 22 insertions, 0 deletions
diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py
index dbaa624ae9..2889e93743 100644
--- a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py
+++ b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py
@@ -481,6 +481,28 @@ class GradientPenaltyTest(test.TestCase, _PenaltyTest):
})
self.assertAlmostEqual(self._expected_loss, loss, 5)
+ def test_loss_using_one_sided_mode(self):
+ generated_data = array_ops.placeholder(dtypes.float32, shape=(None, None))
+ real_data = array_ops.placeholder(dtypes.float32, shape=(None, None))
+
+ loss = tfgan_losses.wasserstein_gradient_penalty(
+ generated_data,
+ real_data,
+ self._kwargs['generator_inputs'],
+ self._kwargs['discriminator_fn'],
+ self._kwargs['discriminator_scope'],
+ one_sided=True)
+ self.assertEqual(generated_data.dtype, loss.dtype)
+
+ with self.test_session() as sess:
+ variables.global_variables_initializer().run()
+ loss = sess.run(loss,
+ feed_dict={
+ generated_data: self._generated_data_np,
+ real_data: self._real_data_np,
+ })
+ self.assertAlmostEqual(self._expected_loss, loss, 5)
+
def test_loss_with_gradient_norm_target(self):
"""Test loss value with non default gradient norm target."""
generated_data = array_ops.placeholder(dtypes.float32, shape=(None, None))