aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/gan
diff options
context:
space:
mode:
authorGravatar Michael Case <mikecase@google.com>2018-04-10 18:44:13 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-10 18:46:38 -0700
commit5ad9e4588874f30d0d079acc60e07f2eddc0480f (patch)
treeab800846cc505d867b2961578869aec97eeb81a3 /tensorflow/contrib/gan
parentfad74785d12ea7463e5d0474522cd7d754699656 (diff)
Merge changes from github.
PiperOrigin-RevId: 192388250
Diffstat (limited to 'tensorflow/contrib/gan')
-rw-r--r--tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py4
-rw-r--r--tensorflow/contrib/gan/python/losses/python/losses_impl.py14
-rw-r--r--tensorflow/contrib/gan/python/losses/python/losses_impl_test.py22
-rw-r--r--tensorflow/contrib/gan/python/train.py4
-rw-r--r--tensorflow/contrib/gan/python/train_test.py25
5 files changed, 61 insertions, 8 deletions
diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py
index 082c42eba1..e3fc6bf0f0 100644
--- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py
+++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py
@@ -88,8 +88,8 @@ class GANEstimator(estimator.Estimator):
discriminator_fn=discriminator_fn,
generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
- generator_optimizer=tf.train.AdamOptimizier(0.1, 0.5),
- discriminator_optimizer=tf.train.AdamOptimizier(0.1, 0.5))
+ generator_optimizer=tf.train.AdamOptimizer(0.1, 0.5),
+ discriminator_optimizer=tf.train.AdamOptimizer(0.1, 0.5))
# Train estimator.
gan_estimator.train(train_input_fn, steps)
diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl.py b/tensorflow/contrib/gan/python/losses/python/losses_impl.py
index 39588b7219..1ba3a64167 100644
--- a/tensorflow/contrib/gan/python/losses/python/losses_impl.py
+++ b/tensorflow/contrib/gan/python/losses/python/losses_impl.py
@@ -306,6 +306,7 @@ def wasserstein_gradient_penalty(
discriminator_scope,
epsilon=1e-10,
target=1.0,
+ one_sided=False,
weights=1.0,
scope=None,
loss_collection=ops.GraphKeys.LOSSES,
@@ -327,6 +328,8 @@ def wasserstein_gradient_penalty(
computing the gradient norm.
target: Optional Python number or `Tensor` indicating the target value of
gradient norm. Defaults to 1.0.
+ one_sided: If `True`, penalty proposed in https://arxiv.org/abs/1709.08894
+ is used. Defaults to `False`.
weights: Optional `Tensor` whose rank is either 0, or the same rank as
`real_data` and `generated_data`, and must be broadcastable to
them (i.e., all dimensions must be either `1`, or the same as the
@@ -377,10 +380,13 @@ def wasserstein_gradient_penalty(
# For numerical stability, add epsilon to the sum before taking the square
# root. Note tf.norm does not add epsilon.
slopes = math_ops.sqrt(gradient_squares + epsilon)
- penalties = math_ops.square(slopes / target - 1.0)
+ penalties = slopes / target - 1.0
+ if one_sided:
+ penalties = math_ops.maximum(0., penalties)
+ penalties_squared = math_ops.square(penalties)
penalty = losses.compute_weighted_loss(
- penalties, weights, scope=scope, loss_collection=loss_collection,
- reduction=reduction)
+ penalties_squared, weights, scope=scope,
+ loss_collection=loss_collection, reduction=reduction)
if add_summaries:
summary.scalar('gradient_penalty_loss', penalty)
@@ -665,7 +671,7 @@ def least_squares_discriminator_loss(
loss_collection=ops.GraphKeys.LOSSES,
reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
add_summaries=False):
- """Least squares generator loss.
+ """Least squares discriminator loss.
This loss comes from `Least Squares Generative Adversarial Networks`
(https://arxiv.org/abs/1611.04076).
diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py
index dbaa624ae9..2889e93743 100644
--- a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py
+++ b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py
@@ -481,6 +481,28 @@ class GradientPenaltyTest(test.TestCase, _PenaltyTest):
})
self.assertAlmostEqual(self._expected_loss, loss, 5)
+ def test_loss_using_one_sided_mode(self):
+ generated_data = array_ops.placeholder(dtypes.float32, shape=(None, None))
+ real_data = array_ops.placeholder(dtypes.float32, shape=(None, None))
+
+ loss = tfgan_losses.wasserstein_gradient_penalty(
+ generated_data,
+ real_data,
+ self._kwargs['generator_inputs'],
+ self._kwargs['discriminator_fn'],
+ self._kwargs['discriminator_scope'],
+ one_sided=True)
+ self.assertEqual(generated_data.dtype, loss.dtype)
+
+ with self.test_session() as sess:
+ variables.global_variables_initializer().run()
+ loss = sess.run(loss,
+ feed_dict={
+ generated_data: self._generated_data_np,
+ real_data: self._real_data_np,
+ })
+ self.assertAlmostEqual(self._expected_loss, loss, 5)
+
def test_loss_with_gradient_norm_target(self):
"""Test loss value with non default gradient norm target."""
generated_data = array_ops.placeholder(dtypes.float32, shape=(None, None))
diff --git a/tensorflow/contrib/gan/python/train.py b/tensorflow/contrib/gan/python/train.py
index 776eb11ecb..73acd05b60 100644
--- a/tensorflow/contrib/gan/python/train.py
+++ b/tensorflow/contrib/gan/python/train.py
@@ -461,6 +461,7 @@ def gan_loss(
gradient_penalty_weight=None,
gradient_penalty_epsilon=1e-10,
gradient_penalty_target=1.0,
+ gradient_penalty_one_sided=False,
mutual_information_penalty_weight=None,
aux_cond_generator_weight=None,
aux_cond_discriminator_weight=None,
@@ -485,6 +486,8 @@ def gan_loss(
gradient_penalty_target: If `gradient_penalty_weight` is not None, a Python
number or `Tensor` indicating the target value of gradient norm. See the
CIFAR10 section of https://arxiv.org/abs/1710.10196. Defaults to 1.0.
+ gradient_penalty_one_sided: If `True`, penalty proposed in
+ https://arxiv.org/abs/1709.08894 is used. Defaults to `False`.
mutual_information_penalty_weight: If not `None`, must be a non-negative
Python number or Tensor indicating how much to weight the mutual
information penalty. See https://arxiv.org/abs/1606.03657 for more
@@ -546,6 +549,7 @@ def gan_loss(
model,
epsilon=gradient_penalty_epsilon,
target=gradient_penalty_target,
+ one_sided=gradient_penalty_one_sided,
add_summaries=add_summaries)
dis_loss += gradient_penalty_weight * gp_loss
if _use_aux_loss(mutual_information_penalty_weight):
diff --git a/tensorflow/contrib/gan/python/train_test.py b/tensorflow/contrib/gan/python/train_test.py
index f9bdaa74c9..3ebbe55d05 100644
--- a/tensorflow/contrib/gan/python/train_test.py
+++ b/tensorflow/contrib/gan/python/train_test.py
@@ -359,10 +359,12 @@ class GANLossTest(test.TestCase):
self.assertGreater(len(ops.get_collection(ops.GraphKeys.SUMMARIES)), 0)
# Test gradient penalty option.
- def _test_grad_penalty_helper(self, create_gan_model_fn):
+ def _test_grad_penalty_helper(self, create_gan_model_fn, one_sided=False):
model = create_gan_model_fn()
loss = train.gan_loss(model)
- loss_gp = train.gan_loss(model, gradient_penalty_weight=1.0)
+ loss_gp = train.gan_loss(model,
+ gradient_penalty_weight=1.0,
+ gradient_penalty_one_sided=one_sided)
self.assertTrue(isinstance(loss_gp, namedtuples.GANLoss))
# Check values.
@@ -394,6 +396,25 @@ class GANLossTest(test.TestCase):
def test_grad_penalty_callable_acgan(self):
self._test_grad_penalty_helper(create_callable_acgan_model)
+ def test_grad_penalty_one_sided_gan(self):
+ self._test_grad_penalty_helper(create_gan_model, one_sided=True)
+
+ def test_grad_penalty_one_sided_callable_gan(self):
+ self._test_grad_penalty_helper(create_callable_gan_model, one_sided=True)
+
+ def test_grad_penalty_one_sided_infogan(self):
+ self._test_grad_penalty_helper(create_infogan_model, one_sided=True)
+
+ def test_grad_penalty_one_sided_callable_infogan(self):
+ self._test_grad_penalty_helper(
+ create_callable_infogan_model, one_sided=True)
+
+ def test_grad_penalty_one_sided_acgan(self):
+ self._test_grad_penalty_helper(create_acgan_model, one_sided=True)
+
+ def test_grad_penalty_one_sided_callable_acgan(self):
+ self._test_grad_penalty_helper(create_callable_acgan_model, one_sided=True)
+
# Test mutual information penalty option.
def _test_mutual_info_penalty_helper(self, create_gan_model_fn):
train.gan_loss(create_gan_model_fn(),