aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/gan
diff options
context:
space:
mode:
authorGravatar Wesley Qian <wwq@google.com>2018-07-25 11:34:39 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-25 11:40:47 -0700
commit8c6782ec4ed12dcdda1fdf8cb45ba8afbf62a61f (patch)
tree8ec77c959db757b32edb6f17ee42fc070c11a8ae /tensorflow/contrib/gan
parent96c76b296768852dac94aaf006beab2e637cbbb6 (diff)
Update test for StarGANModel to conform the original GANModel test.
PiperOrigin-RevId: 206027004
Diffstat (limited to 'tensorflow/contrib/gan')
-rw-r--r--tensorflow/contrib/gan/python/train_test.py57
1 files changed, 57 insertions, 0 deletions
diff --git a/tensorflow/contrib/gan/python/train_test.py b/tensorflow/contrib/gan/python/train_test.py
index fa52e9cca1..df8e0041a9 100644
--- a/tensorflow/contrib/gan/python/train_test.py
+++ b/tensorflow/contrib/gan/python/train_test.py
@@ -114,6 +114,12 @@ def stargan_generator_model(inputs, _):
return variable_scope.get_variable('dummy_g', initializer=0.5) * inputs
+class StarGANGenerator(object):
+
+ def __call__(self, inputs, _):
+ return stargan_generator_model(inputs, _)
+
+
def stargan_discriminator_model(inputs, num_domains):
"""Differentiable dummy discriminator for StarGAN."""
@@ -130,6 +136,12 @@ def stargan_discriminator_model(inputs, num_domains):
return output_src, output_cls
+class StarGANDiscriminator(object):
+
+ def __call__(self, inputs, num_domains):
+ return stargan_discriminator_model(inputs, num_domains)
+
+
def get_gan_model():
# TODO(joelshor): Find a better way of creating a variable scope.
with variable_scope.variable_scope('generator') as gen_scope:
@@ -272,6 +284,49 @@ def create_callable_cyclegan_model():
data_y=array_ops.ones([1, 2]))
+def get_stargan_model():
+ """Similar to get_gan_model()."""
+ # TODO(joelshor): Find a better way of creating a variable scope.
+ with variable_scope.variable_scope('generator') as gen_scope:
+ pass
+ with variable_scope.variable_scope('discriminator') as dis_scope:
+ pass
+ return namedtuples.StarGANModel(
+ input_data=array_ops.ones([1, 2, 2, 3]),
+ input_data_domain_label=array_ops.ones([1, 2]),
+ generated_data=array_ops.ones([1, 2, 2, 3]),
+ generated_data_domain_target=array_ops.ones([1, 2]),
+ reconstructed_data=array_ops.ones([1, 2, 2, 3]),
+ discriminator_input_data_source_predication=array_ops.ones([1]),
+ discriminator_generated_data_source_predication=array_ops.ones([1]),
+ discriminator_input_data_domain_predication=array_ops.ones([1, 2]),
+ discriminator_generated_data_domain_predication=array_ops.ones([1, 2]),
+ generator_variables=None,
+ generator_scope=gen_scope,
+ generator_fn=stargan_generator_model,
+ discriminator_variables=None,
+ discriminator_scope=dis_scope,
+ discriminator_fn=stargan_discriminator_model)
+
+
+def get_callable_stargan_model():
+ model = get_stargan_model()
+ return model._replace(
+ generator_fn=StarGANGenerator(), discriminator_fn=StarGANDiscriminator())
+
+
+def create_stargan_model():
+ return train.stargan_model(
+ stargan_generator_model, stargan_discriminator_model,
+ array_ops.ones([1, 2, 2, 3]), array_ops.ones([1, 2]))
+
+
+def create_callable_stargan_model():
+ return train.stargan_model(StarGANGenerator(), StarGANDiscriminator(),
+ array_ops.ones([1, 2, 2, 3]),
+ array_ops.ones([1, 2]))
+
+
def get_sync_optimizer():
return sync_replicas_optimizer.SyncReplicasOptimizer(
gradient_descent.GradientDescentOptimizer(learning_rate=1.0),
@@ -292,6 +347,8 @@ class GANModelTest(test.TestCase, parameterized.TestCase):
('cyclegan', get_cyclegan_model, namedtuples.CycleGANModel),
('callable_cyclegan', get_callable_cyclegan_model,
namedtuples.CycleGANModel),
+ ('stargan', get_stargan_model, namedtuples.StarGANModel),
+ ('callabel_stargan', get_callable_stargan_model, namedtuples.StarGANModel)
)
def test_output_type(self, create_fn, expected_tuple_type):
"""Test that output type is as expected."""