aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/gan/python/losses/python/losses_impl.py10
1 files changed, 5 insertions, 5 deletions
diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl.py b/tensorflow/contrib/gan/python/losses/python/losses_impl.py
index 87fdb7cae4..29bd72d4db 100644
--- a/tensorflow/contrib/gan/python/losses/python/losses_impl.py
+++ b/tensorflow/contrib/gan/python/losses/python/losses_impl.py
@@ -170,8 +170,8 @@ def wasserstein_discriminator_loss(
# ACGAN losses from `Conditional Image Synthesis With Auxiliary Classifier GANs`
# (https://arxiv.org/abs/1610.09585).
def acgan_discriminator_loss(
- discriminator_gen_classification_logits,
discriminator_real_classification_logits,
+ discriminator_gen_classification_logits,
one_hot_labels,
label_smoothing=0.0,
real_weights=1.0,
@@ -192,10 +192,10 @@ def acgan_discriminator_loss(
ACGAN: https://arxiv.org/abs/1610.09585
Args:
- discriminator_gen_classification_logits: Classification logits for generated
- data.
discriminator_real_classification_logits: Classification logits for real
data.
+ discriminator_gen_classification_logits: Classification logits for generated
+ data.
one_hot_labels: A Tensor holding one-hot labels for the batch.
label_smoothing: A float in [0, 1]. If greater than 0, smooth the labels for
"discriminator on real data" as suggested in
@@ -291,8 +291,8 @@ def acgan_generator_loss(
# TODO(joelshor): Figure out why this function can't be inside a name scope.
def wasserstein_gradient_penalty(
- generated_data,
real_data,
+ generated_data,
generator_inputs,
discriminator_fn,
discriminator_scope,
@@ -308,8 +308,8 @@ def wasserstein_gradient_penalty(
(https://arxiv.org/abs/1704.00028) for more details.
Args:
- generated_data: Output of the generator.
real_data: Real data.
+ generated_data: Output of the generator.
generator_inputs: Exact argument to pass to the generator, which is used
as optional conditioning to the discriminator.
discriminator_fn: A discriminator function that conforms to TFGAN API.