aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/gan
diff options
context:
space:
mode:
authorGravatar Wesley Qian <wwq@google.com>2018-08-13 11:26:08 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-13 11:30:02 -0700
commit28f5f7b58c15c33f90c639c46115b1d6581a7408 (patch)
tree68f9717fd3b71c2ca738dfb6d16adf446c909400 /tensorflow/contrib/gan
parentb4af892e0e037925c4376c7d40623d3635969f06 (diff)
Add stargan image summaries to show result of transforming image to each
domain. PiperOrigin-RevId: 208513708
Diffstat (limited to 'tensorflow/contrib/gan')
-rw-r--r--tensorflow/contrib/gan/BUILD2
-rw-r--r--tensorflow/contrib/gan/python/eval/python/summaries_impl.py91
-rw-r--r--tensorflow/contrib/gan/python/eval/python/summaries_test.py40
3 files changed, 132 insertions, 1 deletions
diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD
index 82e3bbe3c0..9866fccfba 100644
--- a/tensorflow/contrib/gan/BUILD
+++ b/tensorflow/contrib/gan/BUILD
@@ -424,9 +424,11 @@ py_library(
":namedtuples",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python:functional_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:summary",
"//tensorflow/python:util",
+ "//tensorflow/python:variable_scope",
"//tensorflow/python/ops/losses",
],
)
diff --git a/tensorflow/contrib/gan/python/eval/python/summaries_impl.py b/tensorflow/contrib/gan/python/eval/python/summaries_impl.py
index 508f487722..f9995bb19d 100644
--- a/tensorflow/contrib/gan/python/eval/python/summaries_impl.py
+++ b/tensorflow/contrib/gan/python/eval/python/summaries_impl.py
@@ -22,7 +22,9 @@ from tensorflow.contrib.gan.python import namedtuples
from tensorflow.contrib.gan.python.eval.python import eval_utils
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.losses import util as loss_util
from tensorflow.python.summary import summary
@@ -32,6 +34,7 @@ __all__ = [
'add_gan_model_summaries',
'add_regularization_loss_summaries',
'add_cyclegan_image_summaries',
+ 'add_stargan_image_summaries'
]
@@ -179,6 +182,94 @@ def add_image_comparison_summaries(gan_model, num_comparisons=2,
max_outputs=1)
+def add_stargan_image_summaries(stargan_model,
+ num_images=2,
+ display_diffs=False):
+ """Adds image summaries to see StarGAN image results.
+
+ If display_diffs is True, each image result has `2` rows and `num_domains + 1`
+ columns.
+ The first row looks like:
+ [original_image, transformed_to_domain_0, transformed_to_domain_1, ...]
+ The second row looks like:
+ [no_modification_baseline, transformed_to_domain_0-original_image, ...]
+ If display_diffs is False, only the first row is shown.
+
+ IMPORTANT:
+ Since the model originally does not transformed the image to every domains,
+ we will transform them on-the-fly within this function in parallel.
+
+ Args:
+ stargan_model: A StarGANModel tuple.
+ num_images: The number of examples/images to be transformed and shown.
+ display_diffs: Also display the difference between generated and target.
+
+ Raises:
+ ValueError: If input_data is not images.
+ ValueError: If input_data_domain_label is not rank 2.
+ ValueError: If dimension 2 of input_data_domain_label is not fully defined.
+ """
+
+ _assert_is_image(stargan_model.input_data)
+ stargan_model.input_data_domain_label.shape.assert_has_rank(2)
+ stargan_model.input_data_domain_label.shape[1:].assert_is_fully_defined()
+
+ num_domains = stargan_model.input_data_domain_label.get_shape().as_list()[-1]
+
+ def _build_image(image):
+ """Helper function to create a result for each image on the fly."""
+
+ # Expand the first dimension as batch_size = 1.
+ images = array_ops.expand_dims(image, axis=0)
+
+ # Tile the image num_domains times, so we can get all transformed together.
+ images = array_ops.tile(images, [num_domains, 1, 1, 1])
+
+ # Create the targets to 0, 1, 2, ..., num_domains-1.
+ targets = array_ops.one_hot(list(range(num_domains)), num_domains)
+
+ with variable_scope.variable_scope(
+ stargan_model.generator_scope, reuse=True):
+
+ # Add the original image.
+ output_images_list = [image]
+
+ # Generate the image and add to the list.
+ gen_images = stargan_model.generator_fn(images, targets)
+ gen_images_list = array_ops.split(gen_images, num_domains)
+ gen_images_list = [
+ array_ops.squeeze(img, axis=0) for img in gen_images_list
+ ]
+ output_images_list.extend(gen_images_list)
+
+ # Display diffs.
+ if display_diffs:
+ diff_images = gen_images - images
+ diff_images_list = array_ops.split(diff_images, num_domains)
+ diff_images_list = [
+ array_ops.squeeze(img, axis=0) for img in diff_images_list
+ ]
+ output_images_list.append(array_ops.zeros_like(image))
+ output_images_list.extend(diff_images_list)
+
+ # Create the final image.
+ final_image = eval_utils.image_reshaper(
+ output_images_list, num_cols=num_domains + 1)
+
+ # Reduce the first rank.
+ return array_ops.squeeze(final_image, axis=0)
+
+ summary.image(
+ 'stargan_image_generation',
+ functional_ops.map_fn(
+ _build_image,
+ stargan_model.input_data[:num_images],
+ parallel_iterations=num_images,
+ back_prop=False,
+ swap_memory=True),
+ max_outputs=num_images)
+
+
def add_gan_model_summaries(gan_model):
"""Adds typical GANModel summaries.
diff --git a/tensorflow/contrib/gan/python/eval/python/summaries_test.py b/tensorflow/contrib/gan/python/eval/python/summaries_test.py
index 33d51bfc21..54a6f8d4d9 100644
--- a/tensorflow/contrib/gan/python/eval/python/summaries_test.py
+++ b/tensorflow/contrib/gan/python/eval/python/summaries_test.py
@@ -18,7 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-
from tensorflow.contrib.gan.python import namedtuples
from tensorflow.contrib.gan.python.eval.python import summaries_impl as summaries
from tensorflow.python.framework import ops
@@ -37,6 +36,10 @@ def discriminator_model(inputs, _):
return variable_scope.get_variable('dummy_d', initializer=2.0) * inputs
+def stargan_generator_model(inputs, _):
+ return generator_model(inputs)
+
+
def get_gan_model():
# TODO(joelshor): Find a better way of creating a variable scope.
with variable_scope.variable_scope('generator') as gen_scope:
@@ -57,6 +60,31 @@ def get_gan_model():
discriminator_fn=discriminator_model)
+def get_stargan_model():
+ """Similar to get_gan_model()."""
+ # TODO(joelshor): Find a better way of creating a variable scope.
+ with variable_scope.variable_scope('discriminator') as dis_scope:
+ pass
+ with variable_scope.variable_scope('generator') as gen_scope:
+ return namedtuples.StarGANModel(
+ input_data=array_ops.ones([1, 2, 2, 3]),
+ input_data_domain_label=array_ops.ones([1, 2]),
+ generated_data=stargan_generator_model(
+ array_ops.ones([1, 2, 2, 3]), None),
+ generated_data_domain_target=array_ops.ones([1, 2]),
+ reconstructed_data=array_ops.ones([1, 2, 2, 3]),
+ discriminator_input_data_source_predication=array_ops.ones([1]),
+ discriminator_generated_data_source_predication=array_ops.ones([1]),
+ discriminator_input_data_domain_predication=array_ops.ones([1, 2]),
+ discriminator_generated_data_domain_predication=array_ops.ones([1, 2]),
+ generator_variables=None,
+ generator_scope=gen_scope,
+ generator_fn=stargan_generator_model,
+ discriminator_variables=None,
+ discriminator_scope=dis_scope,
+ discriminator_fn=discriminator_model)
+
+
def get_cyclegan_model():
with variable_scope.variable_scope('x2y'):
model_x2y = get_gan_model()
@@ -143,6 +171,16 @@ class SummariesTest(test.TestCase):
with self.test_session(use_gpu=True):
summary.merge_all().eval()
+ def test_add_image_comparison_summaries_for_stargan(self):
+
+ summaries.add_stargan_image_summaries(get_stargan_model())
+
+ self.assertEquals(1, len(ops.get_collection(ops.GraphKeys.SUMMARIES)))
+
+ with self.test_session(use_gpu=True) as sess:
+ sess.run(variables.global_variables_initializer())
+ summary.merge_all().eval()
+
if __name__ == '__main__':
test.main()