aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/gan
diff options
context:
space:
mode:
authorGravatar Wesley Qian <wwq@google.com>2018-07-29 15:56:45 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-29 16:00:23 -0700
commite7475504094b3018973740df470d2c9ee73a4fd5 (patch)
tree9c561b0ae6c3e0b49978a480a1722b4d8571d4c9 /tensorflow/contrib/gan
parentfd5050232b62e8d9c7b7a205c60b4e19071dee55 (diff)
Add StarGAN Loss to TFGAN.
- stargan_loss is added to train.py as other xgan_loss. - StarGANModel wrappers for generator_loss, discrminator_loss, and gradient_penalty are defined in losses/tuple_losses_impl.py. PiperOrigin-RevId: 206509200
Diffstat (limited to 'tensorflow/contrib/gan')
-rw-r--r--tensorflow/contrib/gan/BUILD5
-rw-r--r--tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py86
-rw-r--r--tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py114
-rw-r--r--tensorflow/contrib/gan/python/train.py128
-rw-r--r--tensorflow/contrib/gan/python/train_test.py21
5 files changed, 354 insertions, 0 deletions
diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD
index 7e6cb72485..053d4e3e97 100644
--- a/tensorflow/contrib/gan/BUILD
+++ b/tensorflow/contrib/gan/BUILD
@@ -196,11 +196,16 @@ py_test(
srcs = ["python/losses/python/tuple_losses_test.py"],
srcs_version = "PY2AND3",
deps = [
+ ":losses_impl",
":namedtuples",
":tuple_losses",
+ "//tensorflow/contrib/layers:layers_py",
+ "//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
"//third_party/py/numpy",
],
diff --git a/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py b/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py
index dcc3f94c2d..221c70c38b 100644
--- a/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py
+++ b/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py
@@ -80,6 +80,9 @@ __all__ = [
'mutual_information_penalty',
'combine_adversarial_loss',
'cycle_consistency_loss',
+ 'stargan_generator_loss_wrapper',
+ 'stargan_discriminator_loss_wrapper',
+ 'stargan_gradient_penalty_wrapper'
]
@@ -277,3 +280,86 @@ def cycle_consistency_loss(cyclegan_model, scope=None, add_summaries=False):
cyclegan_model.model_x2y.generator_inputs, cyclegan_model.reconstructed_x,
cyclegan_model.model_y2x.generator_inputs, cyclegan_model.reconstructed_y,
scope, add_summaries)
+
+
+def stargan_generator_loss_wrapper(loss_fn):
+ """Convert a generator loss function to take a StarGANModel.
+
+ The new function has the same name as the original one.
+
+ Args:
+ loss_fn: A python function taking Discriminator's real/fake prediction for
+ generated data.
+
+ Returns:
+ A new function that takes a StarGANModel namedtuple and returns the same
+ loss.
+ """
+
+ def new_loss_fn(stargan_model, **kwargs):
+ return loss_fn(
+ stargan_model.discriminator_generated_data_source_predication, **kwargs)
+
+ new_docstring = """The stargan_model version of %s.""" % loss_fn.__name__
+ new_loss_fn.__docstring__ = new_docstring
+ new_loss_fn.__name__ = loss_fn.__name__
+ new_loss_fn.__module__ = loss_fn.__module__
+ return new_loss_fn
+
+
+def stargan_discriminator_loss_wrapper(loss_fn):
+ """Convert a discriminator loss function to take a StarGANModel.
+
+ The new function has the same name as the original one.
+
+ Args:
+ loss_fn: A python function taking Discriminator's real/fake prediction for
+ real data and generated data.
+
+ Returns:
+ A new function that takes a StarGANModel namedtuple and returns the same
+ loss.
+ """
+
+ def new_loss_fn(stargan_model, **kwargs):
+ return loss_fn(
+ stargan_model.discriminator_input_data_source_predication,
+ stargan_model.discriminator_generated_data_source_predication, **kwargs)
+
+ new_docstring = """The stargan_model version of %s.""" % loss_fn.__name__
+ new_loss_fn.__docstring__ = new_docstring
+ new_loss_fn.__name__ = loss_fn.__name__
+ new_loss_fn.__module__ = loss_fn.__module__
+ return new_loss_fn
+
+
+def stargan_gradient_penalty_wrapper(loss_fn):
+ """Convert a gradient penalty function to take a StarGANModel.
+
+ The new function has the same name as the original one.
+
+ Args:
+ loss_fn: A python function taking real_data, generated_data,
+ generator_inputs for Discriminator's condition (i.e. number of domains),
+ discriminator_fn, and discriminator_scope.
+
+ Returns:
+ A new function that takes a StarGANModel namedtuple and returns the same
+ loss.
+ """
+
+ def new_loss_fn(stargan_model, **kwargs):
+ num_domains = stargan_model.input_data_domain_label.shape.as_list()[-1]
+ return loss_fn(
+ real_data=stargan_model.input_data,
+ generated_data=stargan_model.generated_data,
+ generator_inputs=num_domains,
+ discriminator_fn=stargan_model.discriminator_fn,
+ discriminator_scope=stargan_model.discriminator_scope,
+ **kwargs)
+
+ new_docstring = """The stargan_model version of %s.""" % loss_fn.__name__
+ new_loss_fn.__docstring__ = new_docstring
+ new_loss_fn.__name__ = loss_fn.__name__
+ new_loss_fn.__module__ = loss_fn.__module__
+ return new_loss_fn
diff --git a/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py b/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py
index aa1ef11172..a559bbfa11 100644
--- a/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py
+++ b/tensorflow/contrib/gan/python/losses/python/tuple_losses_test.py
@@ -22,10 +22,15 @@ import collections
import numpy as np
+from tensorflow.contrib import layers
from tensorflow.contrib.gan.python import namedtuples
+from tensorflow.contrib.gan.python.losses.python import losses_impl as tfgan_losses_impl
from tensorflow.contrib.gan.python.losses.python import tuple_losses_impl as tfgan_losses
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -129,6 +134,9 @@ manual_tests = [
'mutual_information_penalty',
'wasserstein_gradient_penalty',
'cycle_consistency_loss',
+ 'stargan_generator_loss_wrapper',
+ 'stargan_discriminator_loss_wrapper',
+ 'stargan_gradient_penalty_wrapper'
]
discriminator_keyword_args = {
@@ -175,6 +183,112 @@ class CycleConsistencyLossTest(test.TestCase):
self.assertNear(5.0, loss.eval(), 1e-5)
+class StarGANLossWrapperTest(test.TestCase):
+
+ def setUp(self):
+
+ super(StarGANLossWrapperTest, self).setUp()
+
+ self.input_data = array_ops.ones([1, 2, 2, 3])
+ self.input_data_domain_label = constant_op.constant([[0, 1]])
+ self.generated_data = array_ops.ones([1, 2, 2, 3])
+ self.discriminator_input_data_source_predication = array_ops.ones([1])
+ self.discriminator_generated_data_source_predication = array_ops.ones([1])
+
+ def _discriminator_fn(inputs, num_domains):
+ """Differentiable dummy discriminator for StarGAN."""
+ hidden = layers.flatten(inputs)
+ output_src = math_ops.reduce_mean(hidden, axis=1)
+ output_cls = layers.fully_connected(
+ inputs=hidden,
+ num_outputs=num_domains,
+ activation_fn=None,
+ normalizer_fn=None,
+ biases_initializer=None)
+ return output_src, output_cls
+
+ with variable_scope.variable_scope('discriminator') as dis_scope:
+ pass
+
+ self.model = namedtuples.StarGANModel(
+ input_data=self.input_data,
+ input_data_domain_label=self.input_data_domain_label,
+ generated_data=self.generated_data,
+ generated_data_domain_target=None,
+ reconstructed_data=None,
+ discriminator_input_data_source_predication=self.
+ discriminator_input_data_source_predication,
+ discriminator_generated_data_source_predication=self.
+ discriminator_generated_data_source_predication,
+ discriminator_input_data_domain_predication=None,
+ discriminator_generated_data_domain_predication=None,
+ generator_variables=None,
+ generator_scope=None,
+ generator_fn=None,
+ discriminator_variables=None,
+ discriminator_scope=dis_scope,
+ discriminator_fn=_discriminator_fn)
+
+ self.discriminator_fn = _discriminator_fn
+ self.discriminator_scope = dis_scope
+
+ def test_stargan_generator_loss_wrapper(self):
+ """Test StarGAN generator loss wrapper."""
+ loss_fn = tfgan_losses_impl.wasserstein_generator_loss
+ wrapped_loss_fn = tfgan_losses.stargan_generator_loss_wrapper(loss_fn)
+
+ loss_result_tensor = loss_fn(
+ self.discriminator_generated_data_source_predication)
+ wrapped_loss_result_tensor = wrapped_loss_fn(self.model)
+
+ with self.test_session() as sess:
+ sess.run(variables.global_variables_initializer())
+ loss_result, wrapped_loss_result = sess.run(
+ [loss_result_tensor, wrapped_loss_result_tensor])
+ self.assertAlmostEqual(loss_result, wrapped_loss_result)
+
+ def test_stargan_discriminator_loss_wrapper(self):
+ """Test StarGAN discriminator loss wrapper."""
+ loss_fn = tfgan_losses_impl.wasserstein_discriminator_loss
+ wrapped_loss_fn = tfgan_losses.stargan_discriminator_loss_wrapper(loss_fn)
+
+ loss_result_tensor = loss_fn(
+ self.discriminator_generated_data_source_predication,
+ self.discriminator_generated_data_source_predication)
+ wrapped_loss_result_tensor = wrapped_loss_fn(self.model)
+
+ with self.test_session() as sess:
+ sess.run(variables.global_variables_initializer())
+ loss_result, wrapped_loss_result = sess.run(
+ [loss_result_tensor, wrapped_loss_result_tensor])
+ self.assertAlmostEqual(loss_result, wrapped_loss_result)
+
+ def test_stargan_gradient_penalty_wrapper(self):
+ """Test StaGAN gradient penalty wrapper.
+
+ Notes:
+ The random interpolates are handled by given setting the reconstruction to
+ be the same as the input.
+
+ """
+ loss_fn = tfgan_losses_impl.wasserstein_gradient_penalty
+ wrapped_loss_fn = tfgan_losses.stargan_gradient_penalty_wrapper(loss_fn)
+
+ loss_result_tensor = loss_fn(
+ real_data=self.input_data,
+ generated_data=self.generated_data,
+ generator_inputs=self.input_data_domain_label.shape.as_list()[-1],
+ discriminator_fn=self.discriminator_fn,
+ discriminator_scope=self.discriminator_scope)
+ wrapped_loss_result_tensor = wrapped_loss_fn(self.model)
+
+ with self.test_session() as sess:
+ sess.run(variables.global_variables_initializer())
+ loss_result, wrapped_loss_result = sess.run(
+ [loss_result_tensor, wrapped_loss_result_tensor])
+ self.assertAlmostEqual(loss_result, wrapped_loss_result)
+
+
if __name__ == '__main__':
for loss_name in tfgan_losses.__all__:
if loss_name in manual_tests: continue
diff --git a/tensorflow/contrib/gan/python/train.py b/tensorflow/contrib/gan/python/train.py
index df603d1f18..03f52d214b 100644
--- a/tensorflow/contrib/gan/python/train.py
+++ b/tensorflow/contrib/gan/python/train.py
@@ -34,6 +34,7 @@ from __future__ import print_function
from tensorflow.contrib.framework.python.ops import variables as variables_lib
from tensorflow.contrib.gan.python import losses as tfgan_losses
from tensorflow.contrib.gan.python import namedtuples
+from tensorflow.contrib.gan.python.losses.python import losses_impl as tfgan_losses_impl
from tensorflow.contrib.slim.python.slim import learning as slim_learning
from tensorflow.contrib.training.python.training import training
from tensorflow.python.framework import dtypes
@@ -41,14 +42,17 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.distributions import distribution as ds
from tensorflow.python.ops.losses import losses
+from tensorflow.python.summary import summary
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import sync_replicas_optimizer
from tensorflow.python.training import training_util
+
__all__ = [
'gan_model',
'infogan_model',
@@ -751,6 +755,130 @@ def cyclegan_loss(
return namedtuples.CycleGANLoss(loss_x2y, loss_y2x)
+def stargan_loss(
+ model,
+ generator_loss_fn=tfgan_losses.stargan_generator_loss_wrapper(
+ tfgan_losses_impl.wasserstein_generator_loss),
+ discriminator_loss_fn=tfgan_losses.stargan_discriminator_loss_wrapper(
+ tfgan_losses_impl.wasserstein_discriminator_loss),
+ gradient_penalty_weight=10.0,
+ gradient_penalty_epsilon=1e-10,
+ gradient_penalty_target=1.0,
+ gradient_penalty_one_sided=False,
+ reconstruction_loss_fn=losses.absolute_difference,
+ reconstruction_loss_weight=10.0,
+ classification_loss_fn=losses.softmax_cross_entropy,
+ classification_loss_weight=1.0,
+ classification_one_hot=True,
+ add_summaries=True):
+ """StarGAN Loss.
+
+ The four major part can be found here: http://screen/tMRMBAohDYG.
+
+ Args:
+ model: (StarGAN) Model output of the stargan_model() function call.
+ generator_loss_fn: The loss function on the generator. Takes a
+ `StarGANModel` named tuple.
+ discriminator_loss_fn: The loss function on the discriminator. Takes a
+ `StarGANModel` namedtuple.
+ gradient_penalty_weight: (float) Gradient penalty weight. Default to 10 per
+ the original paper https://arxiv.org/abs/1711.09020. Set to 0 or None to
+ turn off gradient penalty.
+ gradient_penalty_epsilon: (float) A small positive number added for
+ numerical stability when computing the gradient norm.
+ gradient_penalty_target: (float, or tf.float `Tensor`) The target value of
+ gradient norm. Defaults to 1.0.
+ gradient_penalty_one_sided: (bool) If `True`, penalty proposed in
+ https://arxiv.org/abs/1709.08894 is used. Defaults to `False`.
+ reconstruction_loss_fn: The reconstruction loss function. Default to L1-norm
+ and the function must conform to the `tf.losses` API.
+ reconstruction_loss_weight: Reconstruction loss weight. Default to 10.0.
+ classification_loss_fn: The loss function on the discriminator's ability to
+ classify domain of the input. Default to one-hot softmax cross entropy
+ loss, and the function must conform to the `tf.losses` API.
+ classification_loss_weight: (float) Classification loss weight. Default to
+ 1.0.
+ classification_one_hot: (bool) If the label is one hot representation.
+ Default to True. If False, classification classification_loss_fn need to
+ be sigmoid cross entropy loss instead.
+ add_summaries: (bool) Add the loss to the summary
+
+ Returns:
+ GANLoss namedtuple where we have generator loss and discriminator loss.
+
+ Raises:
+ ValueError: If input StarGANModel.input_data_domain_label does not have rank
+ 2, or dimension 2 is not defined.
+ """
+
+ def _classification_loss_helper(true_labels, predict_logits, scope_name):
+ """Classification Loss Function Helper.
+
+ Args:
+ true_labels: Tensor of shape [batch_size, num_domains] representing the
+ label where each row is an one-hot vector.
+ predict_logits: Tensor of shape [batch_size, num_domains] representing the
+ predicted label logit, which is UNSCALED output from the NN.
+ scope_name: (string) Name scope of the loss component.
+
+ Returns:
+ Single scalar tensor representing the classification loss.
+ """
+
+ with ops.name_scope(scope_name, values=(true_labels, predict_logits)):
+
+ loss = classification_loss_fn(
+ onehot_labels=true_labels, logits=predict_logits)
+
+ if not classification_one_hot:
+ loss = math_ops.reduce_sum(loss, axis=1)
+ loss = math_ops.reduce_mean(loss)
+
+ if add_summaries:
+ summary.scalar(scope_name, loss)
+
+ return loss
+
+ # Check input shape.
+ model.input_data_domain_label.shape.assert_has_rank(2)
+ model.input_data_domain_label.shape[1:].assert_is_fully_defined()
+
+ # Adversarial Loss.
+ generator_loss = generator_loss_fn(model, add_summaries=add_summaries)
+ discriminator_loss = discriminator_loss_fn(model, add_summaries=add_summaries)
+
+ # Gradient Penalty.
+ if _use_aux_loss(gradient_penalty_weight):
+ gradient_penalty_fn = tfgan_losses.stargan_gradient_penalty_wrapper(
+ tfgan_losses_impl.wasserstein_gradient_penalty)
+ discriminator_loss += gradient_penalty_fn(
+ model,
+ epsilon=gradient_penalty_epsilon,
+ target=gradient_penalty_target,
+ one_sided=gradient_penalty_one_sided,
+ add_summaries=add_summaries) * gradient_penalty_weight
+
+ # Reconstruction Loss.
+ reconstruction_loss = reconstruction_loss_fn(model.input_data,
+ model.reconstructed_data)
+ generator_loss += reconstruction_loss * reconstruction_loss_weight
+ if add_summaries:
+ summary.scalar('reconstruction_loss', reconstruction_loss)
+
+ # Classification Loss.
+ generator_loss += _classification_loss_helper(
+ true_labels=model.generated_data_domain_target,
+ predict_logits=model.discriminator_generated_data_domain_predication,
+ scope_name='generator_classification_loss') * classification_loss_weight
+ discriminator_loss += _classification_loss_helper(
+ true_labels=model.input_data_domain_label,
+ predict_logits=model.discriminator_input_data_domain_predication,
+ scope_name='discriminator_classification_loss'
+ ) * classification_loss_weight
+
+ return namedtuples.GANLoss(generator_loss, discriminator_loss)
+
+
def _get_update_ops(kwargs, gen_scope, dis_scope, check_for_unused_ops=True):
"""Gets generator and discriminator update ops.
diff --git a/tensorflow/contrib/gan/python/train_test.py b/tensorflow/contrib/gan/python/train_test.py
index df8e0041a9..58f348034f 100644
--- a/tensorflow/contrib/gan/python/train_test.py
+++ b/tensorflow/contrib/gan/python/train_test.py
@@ -666,6 +666,27 @@ class GANLossTest(test.TestCase, parameterized.TestCase):
self.assertTrue(np.isscalar(loss_y2x_dis_np))
@parameterized.named_parameters(
+ ('notcallable', create_stargan_model),
+ ('callable', create_callable_stargan_model),
+ )
+ def test_stargan(self, create_gan_model_fn):
+
+ model = create_gan_model_fn()
+ model_loss = train.stargan_loss(model)
+
+ self.assertIsInstance(model_loss, namedtuples.GANLoss)
+
+ with self.test_session() as sess:
+
+ sess.run(variables.global_variables_initializer())
+
+ gen_loss, disc_loss = sess.run(
+ [model_loss.generator_loss, model_loss.discriminator_loss])
+
+ self.assertTrue(np.isscalar(gen_loss))
+ self.assertTrue(np.isscalar(disc_loss))
+
+ @parameterized.named_parameters(
('gan', create_gan_model),
('callable_gan', create_callable_gan_model),
('infogan', create_infogan_model),