aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/gan/python/namedtuples.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/gan/python/namedtuples.py')
-rw-r--r--tensorflow/contrib/gan/python/namedtuples.py50
1 files changed, 49 insertions, 1 deletions
diff --git a/tensorflow/contrib/gan/python/namedtuples.py b/tensorflow/contrib/gan/python/namedtuples.py
index 25cfeafeec..a462b68e28 100644
--- a/tensorflow/contrib/gan/python/namedtuples.py
+++ b/tensorflow/contrib/gan/python/namedtuples.py
@@ -25,12 +25,12 @@ from __future__ import print_function
import collections
-
__all__ = [
'GANModel',
'InfoGANModel',
'ACGANModel',
'CycleGANModel',
+ 'StarGANModel',
'GANLoss',
'CycleGANLoss',
'GANTrainOps',
@@ -136,6 +136,54 @@ class CycleGANModel(
"""
+class StarGANModel(
+ collections.namedtuple('StarGANModel', (
+ 'input_data',
+ 'input_data_domain_label',
+ 'generated_data',
+ 'generated_data_domain_target',
+ 'reconstructed_data',
+ 'discriminator_input_data_source_predication',
+ 'discriminator_generated_data_source_predication',
+ 'discriminator_input_data_domain_predication',
+ 'discriminator_generated_data_domain_predication',
+ 'generator_variables',
+ 'generator_scope',
+ 'generator_fn',
+ 'discriminator_variables',
+ 'discriminator_scope',
+ 'discriminator_fn',
+ ))):
+ """A StarGANModel contains all the pieces needed for StarGAN training.
+
+ Args:
+ input_data: The real images that need to be transferred by the generator.
+ input_data_domain_label: The real domain labels associated with the real
+ images.
+ generated_data: The generated images produced by the generator. It has the
+ same shape as the input_data.
+ generated_data_domain_target: The target domain that the generated images
+ belong to. It has the same shape as the input_data_domain_label.
+ reconstructed_data: The reconstructed images produced by the G(enerator).
+ reconstructed_data = G(G(input_data, generated_data_domain_target),
+ input_data_domain_label).
+ discriminator_input_data_source: The discriminator's output for predicting
+ the source (real/generated) of input_data.
+ discriminator_generated_data_source: The discriminator's output for
+ predicting the source (real/generated) of generated_data.
+ discriminator_input_data_domain_predication: The discriminator's output for
+ predicting the domain_label for the input_data.
+ discriminator_generated_data_domain_predication: The discriminatorr's output
+ for predicting the domain_target for the generated_data.
+ generator_variables: A list of all generator variables.
+ generator_scope: Variable scope all generator variables live in.
+ generator_fn: The generator function.
+ discriminator_variables: A list of all discriminator variables.
+ discriminator_scope: Variable scope all discriminator variables live in.
+ discriminator_fn: The discriminator function.
+ """
+
+
class GANLoss(
collections.namedtuple('GANLoss', (
'generator_loss',