aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/gan/python/eval/python/summaries_impl.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/gan/python/eval/python/summaries_impl.py')
-rw-r--r--tensorflow/contrib/gan/python/eval/python/summaries_impl.py91
1 files changed, 91 insertions, 0 deletions
diff --git a/tensorflow/contrib/gan/python/eval/python/summaries_impl.py b/tensorflow/contrib/gan/python/eval/python/summaries_impl.py
index 508f487722..f9995bb19d 100644
--- a/tensorflow/contrib/gan/python/eval/python/summaries_impl.py
+++ b/tensorflow/contrib/gan/python/eval/python/summaries_impl.py
@@ -22,7 +22,9 @@ from tensorflow.contrib.gan.python import namedtuples
from tensorflow.contrib.gan.python.eval.python import eval_utils
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.losses import util as loss_util
from tensorflow.python.summary import summary
@@ -32,6 +34,7 @@ __all__ = [
'add_gan_model_summaries',
'add_regularization_loss_summaries',
'add_cyclegan_image_summaries',
+ 'add_stargan_image_summaries'
]
@@ -179,6 +182,94 @@ def add_image_comparison_summaries(gan_model, num_comparisons=2,
max_outputs=1)
+def add_stargan_image_summaries(stargan_model,
+ num_images=2,
+ display_diffs=False):
+ """Adds image summaries to see StarGAN image results.
+
+ If display_diffs is True, each image result has `2` rows and `num_domains + 1`
+ columns.
+ The first row looks like:
+ [original_image, transformed_to_domain_0, transformed_to_domain_1, ...]
+ The second row looks like:
+ [no_modification_baseline, transformed_to_domain_0-original_image, ...]
+ If display_diffs is False, only the first row is shown.
+
+ IMPORTANT:
+ Since the model originally does not transformed the image to every domains,
+ we will transform them on-the-fly within this function in parallel.
+
+ Args:
+ stargan_model: A StarGANModel tuple.
+ num_images: The number of examples/images to be transformed and shown.
+ display_diffs: Also display the difference between generated and target.
+
+ Raises:
+ ValueError: If input_data is not images.
+ ValueError: If input_data_domain_label is not rank 2.
+ ValueError: If dimension 2 of input_data_domain_label is not fully defined.
+ """
+
+ _assert_is_image(stargan_model.input_data)
+ stargan_model.input_data_domain_label.shape.assert_has_rank(2)
+ stargan_model.input_data_domain_label.shape[1:].assert_is_fully_defined()
+
+ num_domains = stargan_model.input_data_domain_label.get_shape().as_list()[-1]
+
+ def _build_image(image):
+ """Helper function to create a result for each image on the fly."""
+
+ # Expand the first dimension as batch_size = 1.
+ images = array_ops.expand_dims(image, axis=0)
+
+ # Tile the image num_domains times, so we can get all transformed together.
+ images = array_ops.tile(images, [num_domains, 1, 1, 1])
+
+ # Create the targets to 0, 1, 2, ..., num_domains-1.
+ targets = array_ops.one_hot(list(range(num_domains)), num_domains)
+
+ with variable_scope.variable_scope(
+ stargan_model.generator_scope, reuse=True):
+
+ # Add the original image.
+ output_images_list = [image]
+
+ # Generate the image and add to the list.
+ gen_images = stargan_model.generator_fn(images, targets)
+ gen_images_list = array_ops.split(gen_images, num_domains)
+ gen_images_list = [
+ array_ops.squeeze(img, axis=0) for img in gen_images_list
+ ]
+ output_images_list.extend(gen_images_list)
+
+ # Display diffs.
+ if display_diffs:
+ diff_images = gen_images - images
+ diff_images_list = array_ops.split(diff_images, num_domains)
+ diff_images_list = [
+ array_ops.squeeze(img, axis=0) for img in diff_images_list
+ ]
+ output_images_list.append(array_ops.zeros_like(image))
+ output_images_list.extend(diff_images_list)
+
+ # Create the final image.
+ final_image = eval_utils.image_reshaper(
+ output_images_list, num_cols=num_domains + 1)
+
+ # Reduce the first rank.
+ return array_ops.squeeze(final_image, axis=0)
+
+ summary.image(
+ 'stargan_image_generation',
+ functional_ops.map_fn(
+ _build_image,
+ stargan_model.input_data[:num_images],
+ parallel_iterations=num_images,
+ back_prop=False,
+ swap_memory=True),
+ max_outputs=num_images)
+
+
def add_gan_model_summaries(gan_model):
"""Adds typical GANModel summaries.