Diffstat (limited to 'tensorflow/contrib/gan/python/train.py')
-rw-r--r--  tensorflow/contrib/gan/python/train.py  128
1 file changed, 128 insertions(+), 0 deletions(-)
diff --git a/tensorflow/contrib/gan/python/train.py b/tensorflow/contrib/gan/python/train.py
index df603d1f18..03f52d214b 100644
--- a/tensorflow/contrib/gan/python/train.py
+++ b/tensorflow/contrib/gan/python/train.py
@@ -34,6 +34,7 @@ from __future__ import print_function
from tensorflow.contrib.framework.python.ops import variables as variables_lib
from tensorflow.contrib.gan.python import losses as tfgan_losses
from tensorflow.contrib.gan.python import namedtuples
+from tensorflow.contrib.gan.python.losses.python import losses_impl as tfgan_losses_impl
from tensorflow.contrib.slim.python.slim import learning as slim_learning
from tensorflow.contrib.training.python.training import training
from tensorflow.python.framework import dtypes
@@ -41,14 +42,17 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.distributions import distribution as ds
from tensorflow.python.ops.losses import losses
+from tensorflow.python.summary import summary
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import sync_replicas_optimizer
from tensorflow.python.training import training_util
+
__all__ = [
'gan_model',
'infogan_model',
@@ -751,6 +755,130 @@ def cyclegan_loss(
return namedtuples.CycleGANLoss(loss_x2y, loss_y2x)
+def stargan_loss(
+ model,
+ generator_loss_fn=tfgan_losses.stargan_generator_loss_wrapper(
+ tfgan_losses_impl.wasserstein_generator_loss),
+ discriminator_loss_fn=tfgan_losses.stargan_discriminator_loss_wrapper(
+ tfgan_losses_impl.wasserstein_discriminator_loss),
+ gradient_penalty_weight=10.0,
+ gradient_penalty_epsilon=1e-10,
+ gradient_penalty_target=1.0,
+ gradient_penalty_one_sided=False,
+ reconstruction_loss_fn=losses.absolute_difference,
+ reconstruction_loss_weight=10.0,
+ classification_loss_fn=losses.softmax_cross_entropy,
+ classification_loss_weight=1.0,
+ classification_one_hot=True,
+ add_summaries=True):
+ """StarGAN Loss.
+
+  The loss consists of four major parts: an adversarial loss, a gradient
+  penalty, a reconstruction loss, and a classification loss. An overview can
+  be found here: http://screen/tMRMBAohDYG.
+
+ Args:
+    model: (StarGANModel) Model output of the stargan_model() function call.
+    generator_loss_fn: The loss function on the generator. Takes a
+      `StarGANModel` namedtuple.
+    discriminator_loss_fn: The loss function on the discriminator. Takes a
+      `StarGANModel` namedtuple.
+    gradient_penalty_weight: (float) Gradient penalty weight. Defaults to 10
+      per the original paper, https://arxiv.org/abs/1711.09020. Set to 0 or
+      None to turn off gradient penalty.
+ gradient_penalty_epsilon: (float) A small positive number added for
+ numerical stability when computing the gradient norm.
+    gradient_penalty_target: (float, or tf.float `Tensor`) The target value of
+      the gradient norm. Defaults to 1.0.
+    gradient_penalty_one_sided: (bool) If `True`, the one-sided penalty
+      proposed in https://arxiv.org/abs/1709.08894 is used. Defaults to
+      `False`.
+    reconstruction_loss_fn: The reconstruction loss function. Defaults to an
+      L1 norm; the function must conform to the `tf.losses` API.
+    reconstruction_loss_weight: Reconstruction loss weight. Defaults to 10.0.
+    classification_loss_fn: The loss function on the discriminator's ability
+      to classify the domain of the input. Defaults to one-hot softmax cross
+      entropy loss; the function must conform to the `tf.losses` API.
+    classification_loss_weight: (float) Classification loss weight. Defaults
+      to 1.0.
+    classification_one_hot: (bool) If `True`, the labels are one-hot vectors.
+      Defaults to `True`. If `False`, `classification_loss_fn` needs to be a
+      sigmoid cross entropy loss instead.
+    add_summaries: (bool) Whether to add summaries for the losses.
+
+ Returns:
+    A `GANLoss` namedtuple containing the generator loss and the
+    discriminator loss.
+
+ Raises:
+    ValueError: If `StarGANModel.input_data_domain_label` does not have rank
+      2, or if its second dimension is not statically defined.
+ """
+
+ def _classification_loss_helper(true_labels, predict_logits, scope_name):
+ """Classification Loss Function Helper.
+
+ Args:
+      true_labels: Tensor of shape [batch_size, num_domains] representing the
+        labels, where each row is a one-hot vector.
+      predict_logits: Tensor of shape [batch_size, num_domains] representing
+        the predicted label logits, i.e. the unscaled output of the network.
+ scope_name: (string) Name scope of the loss component.
+
+ Returns:
+ Single scalar tensor representing the classification loss.
+ """
+
+    with ops.name_scope(scope_name, values=(true_labels, predict_logits)):
+
+      if classification_one_hot:
+        loss = classification_loss_fn(
+            onehot_labels=true_labels, logits=predict_logits)
+      else:
+        # Sigmoid-style losses in the `tf.losses` API take their labels via
+        # `multi_class_labels`. The per-element loss is summed over domains
+        # and averaged over the batch.
+        loss = classification_loss_fn(
+            multi_class_labels=true_labels, logits=predict_logits)
+        loss = math_ops.reduce_sum(loss, axis=1)
+        loss = math_ops.reduce_mean(loss)
+
+ if add_summaries:
+ summary.scalar(scope_name, loss)
+
+ return loss
+
+ # Check input shape.
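+  # `input_data_domain_label` must be [batch_size, num_domains], with the
+  # domain dimension statically known.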
+ model.input_data_domain_label.shape.assert_has_rank(2)
+ model.input_data_domain_label.shape[1:].assert_is_fully_defined()
+
+ # Adversarial Loss.
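+  # The generator and discriminator losses default to the Wasserstein
+  # formulation via the StarGAN loss wrappers above.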
+ generator_loss = generator_loss_fn(model, add_summaries=add_summaries)
+ discriminator_loss = discriminator_loss_fn(model, add_summaries=add_summaries)
+
+ # Gradient Penalty.
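+  # Wasserstein gradient penalty on the discriminator, scaled by
+  # `gradient_penalty_weight`; skipped entirely when the weight is 0 or None.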
+ if _use_aux_loss(gradient_penalty_weight):
+ gradient_penalty_fn = tfgan_losses.stargan_gradient_penalty_wrapper(
+ tfgan_losses_impl.wasserstein_gradient_penalty)
+ discriminator_loss += gradient_penalty_fn(
+ model,
+ epsilon=gradient_penalty_epsilon,
+ target=gradient_penalty_target,
+ one_sided=gradient_penalty_one_sided,
+ add_summaries=add_summaries) * gradient_penalty_weight
+
+ # Reconstruction Loss.
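+  # Cycle-consistency term: translating an input to the target domain and
+  # back should reproduce the original image.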
+ reconstruction_loss = reconstruction_loss_fn(model.input_data,
+ model.reconstructed_data)
+ generator_loss += reconstruction_loss * reconstruction_loss_weight
+ if add_summaries:
+ summary.scalar('reconstruction_loss', reconstruction_loss)
+
+ # Classification Loss.
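+  # The generator is trained so the discriminator classifies its output as
+  # the target domain; the discriminator is trained to classify real inputs
+  # into their true domain.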
+ generator_loss += _classification_loss_helper(
+ true_labels=model.generated_data_domain_target,
+ predict_logits=model.discriminator_generated_data_domain_predication,
+ scope_name='generator_classification_loss') * classification_loss_weight
+ discriminator_loss += _classification_loss_helper(
+ true_labels=model.input_data_domain_label,
+ predict_logits=model.discriminator_input_data_domain_predication,
+ scope_name='discriminator_classification_loss'
+ ) * classification_loss_weight
+
+ return namedtuples.GANLoss(generator_loss, discriminator_loss)
+
+
def _get_update_ops(kwargs, gen_scope, dis_scope, check_for_unused_ops=True):
"""Gets generator and discriminator update ops.