aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/gan/python/eval/python/summaries_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/gan/python/eval/python/summaries_test.py')
-rw-r--r--tensorflow/contrib/gan/python/eval/python/summaries_test.py40
1 files changed, 39 insertions, 1 deletions
diff --git a/tensorflow/contrib/gan/python/eval/python/summaries_test.py b/tensorflow/contrib/gan/python/eval/python/summaries_test.py
index 33d51bfc21..54a6f8d4d9 100644
--- a/tensorflow/contrib/gan/python/eval/python/summaries_test.py
+++ b/tensorflow/contrib/gan/python/eval/python/summaries_test.py
@@ -18,7 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-
from tensorflow.contrib.gan.python import namedtuples
from tensorflow.contrib.gan.python.eval.python import summaries_impl as summaries
from tensorflow.python.framework import ops
@@ -37,6 +36,10 @@ def discriminator_model(inputs, _):
return variable_scope.get_variable('dummy_d', initializer=2.0) * inputs
+def stargan_generator_model(inputs, _):
+ return generator_model(inputs)
+
+
def get_gan_model():
# TODO(joelshor): Find a better way of creating a variable scope.
with variable_scope.variable_scope('generator') as gen_scope:
@@ -57,6 +60,31 @@ def get_gan_model():
discriminator_fn=discriminator_model)
+def get_stargan_model():
+ """Similar to get_gan_model()."""
+ # TODO(joelshor): Find a better way of creating a variable scope.
+ with variable_scope.variable_scope('discriminator') as dis_scope:
+ pass
+ with variable_scope.variable_scope('generator') as gen_scope:
+ return namedtuples.StarGANModel(
+ input_data=array_ops.ones([1, 2, 2, 3]),
+ input_data_domain_label=array_ops.ones([1, 2]),
+ generated_data=stargan_generator_model(
+ array_ops.ones([1, 2, 2, 3]), None),
+ generated_data_domain_target=array_ops.ones([1, 2]),
+ reconstructed_data=array_ops.ones([1, 2, 2, 3]),
+ discriminator_input_data_source_predication=array_ops.ones([1]),
+ discriminator_generated_data_source_predication=array_ops.ones([1]),
+ discriminator_input_data_domain_predication=array_ops.ones([1, 2]),
+ discriminator_generated_data_domain_predication=array_ops.ones([1, 2]),
+ generator_variables=None,
+ generator_scope=gen_scope,
+ generator_fn=stargan_generator_model,
+ discriminator_variables=None,
+ discriminator_scope=dis_scope,
+ discriminator_fn=discriminator_model)
+
+
def get_cyclegan_model():
with variable_scope.variable_scope('x2y'):
model_x2y = get_gan_model()
@@ -143,6 +171,16 @@ class SummariesTest(test.TestCase):
with self.test_session(use_gpu=True):
summary.merge_all().eval()
+ def test_add_image_comparison_summaries_for_stargan(self):
+
+ summaries.add_stargan_image_summaries(get_stargan_model())
+
+ self.assertEquals(1, len(ops.get_collection(ops.GraphKeys.SUMMARIES)))
+
+ with self.test_session(use_gpu=True) as sess:
+ sess.run(variables.global_variables_initializer())
+ summary.merge_all().eval()
+
if __name__ == '__main__':
test.main()