diff options
author | Wesley Qian <wwq@google.com> | 2018-07-25 11:34:39 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-25 11:40:47 -0700 |
commit | 8c6782ec4ed12dcdda1fdf8cb45ba8afbf62a61f (patch) | |
tree | 8ec77c959db757b32edb6f17ee42fc070c11a8ae /tensorflow/contrib/gan | |
parent | 96c76b296768852dac94aaf006beab2e637cbbb6 (diff) |
Update test for StarGANModel to conform the original GANModel test.
PiperOrigin-RevId: 206027004
Diffstat (limited to 'tensorflow/contrib/gan')
-rw-r--r-- | tensorflow/contrib/gan/python/train_test.py | 57 |
1 files changed, 57 insertions, 0 deletions
diff --git a/tensorflow/contrib/gan/python/train_test.py b/tensorflow/contrib/gan/python/train_test.py index fa52e9cca1..df8e0041a9 100644 --- a/tensorflow/contrib/gan/python/train_test.py +++ b/tensorflow/contrib/gan/python/train_test.py @@ -114,6 +114,12 @@ def stargan_generator_model(inputs, _): return variable_scope.get_variable('dummy_g', initializer=0.5) * inputs +class StarGANGenerator(object): + + def __call__(self, inputs, _): + return stargan_generator_model(inputs, _) + + def stargan_discriminator_model(inputs, num_domains): """Differentiable dummy discriminator for StarGAN.""" @@ -130,6 +136,12 @@ def stargan_discriminator_model(inputs, num_domains): return output_src, output_cls +class StarGANDiscriminator(object): + + def __call__(self, inputs, num_domains): + return stargan_discriminator_model(inputs, num_domains) + + def get_gan_model(): # TODO(joelshor): Find a better way of creating a variable scope. with variable_scope.variable_scope('generator') as gen_scope: @@ -272,6 +284,49 @@ def create_callable_cyclegan_model(): data_y=array_ops.ones([1, 2])) +def get_stargan_model(): + """Similar to get_gan_model().""" + # TODO(joelshor): Find a better way of creating a variable scope. + with variable_scope.variable_scope('generator') as gen_scope: + pass + with variable_scope.variable_scope('discriminator') as dis_scope: + pass + return namedtuples.StarGANModel( + input_data=array_ops.ones([1, 2, 2, 3]), + input_data_domain_label=array_ops.ones([1, 2]), + generated_data=array_ops.ones([1, 2, 2, 3]), + generated_data_domain_target=array_ops.ones([1, 2]), + reconstructed_data=array_ops.ones([1, 2, 2, 3]), + discriminator_input_data_source_predication=array_ops.ones([1]), + discriminator_generated_data_source_predication=array_ops.ones([1]), + discriminator_input_data_domain_predication=array_ops.ones([1, 2]), + discriminator_generated_data_domain_predication=array_ops.ones([1, 2]), + generator_variables=None, + generator_scope=gen_scope, + generator_fn=stargan_generator_model, + discriminator_variables=None, + discriminator_scope=dis_scope, + discriminator_fn=stargan_discriminator_model) + + +def get_callable_stargan_model(): + model = get_stargan_model() + return model._replace( + generator_fn=StarGANGenerator(), discriminator_fn=StarGANDiscriminator()) + + +def create_stargan_model(): + return train.stargan_model( + stargan_generator_model, stargan_discriminator_model, + array_ops.ones([1, 2, 2, 3]), array_ops.ones([1, 2])) + + +def create_callable_stargan_model(): + return train.stargan_model(StarGANGenerator(), StarGANDiscriminator(), + array_ops.ones([1, 2, 2, 3]), + array_ops.ones([1, 2])) + + def get_sync_optimizer(): return sync_replicas_optimizer.SyncReplicasOptimizer( gradient_descent.GradientDescentOptimizer(learning_rate=1.0), @@ -292,6 +347,8 @@ class GANModelTest(test.TestCase, parameterized.TestCase): ('cyclegan', get_cyclegan_model, namedtuples.CycleGANModel), ('callable_cyclegan', get_callable_cyclegan_model, namedtuples.CycleGANModel), + ('stargan', get_stargan_model, namedtuples.StarGANModel), + ('callabel_stargan', get_callable_stargan_model, namedtuples.StarGANModel) ) def test_output_type(self, create_fn, expected_tuple_type): """Test that output type is as expected.""" |