diff options
-rw-r--r-- | tensorflow/contrib/gan/python/losses/python/losses_impl.py | 10 |
1 files changed, 5 insertions, 5 deletions
diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl.py b/tensorflow/contrib/gan/python/losses/python/losses_impl.py index 87fdb7cae4..29bd72d4db 100644 --- a/tensorflow/contrib/gan/python/losses/python/losses_impl.py +++ b/tensorflow/contrib/gan/python/losses/python/losses_impl.py @@ -170,8 +170,8 @@ def wasserstein_discriminator_loss( # ACGAN losses from `Conditional Image Synthesis With Auxiliary Classifier GANs` # (https://arxiv.org/abs/1610.09585). def acgan_discriminator_loss( - discriminator_gen_classification_logits, discriminator_real_classification_logits, + discriminator_gen_classification_logits, one_hot_labels, label_smoothing=0.0, real_weights=1.0, @@ -192,10 +192,10 @@ def acgan_discriminator_loss( ACGAN: https://arxiv.org/abs/1610.09585 Args: - discriminator_gen_classification_logits: Classification logits for generated - data. discriminator_real_classification_logits: Classification logits for real data. + discriminator_gen_classification_logits: Classification logits for generated + data. one_hot_labels: A Tensor holding one-hot labels for the batch. label_smoothing: A float in [0, 1]. If greater than 0, smooth the labels for "discriminator on real data" as suggested in @@ -291,8 +291,8 @@ def acgan_generator_loss( # TODO(joelshor): Figure out why this function can't be inside a name scope. def wasserstein_gradient_penalty( - generated_data, real_data, + generated_data, generator_inputs, discriminator_fn, discriminator_scope, @@ -308,8 +308,8 @@ def wasserstein_gradient_penalty( (https://arxiv.org/abs/1704.00028) for more details. Args: - generated_data: Output of the generator. real_data: Real data. + generated_data: Output of the generator. generator_inputs: Exact argument to pass to the generator, which is used as optional conditioning to the discriminator. discriminator_fn: A discriminator function that conforms to TFGAN API. |