aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/gan
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-02 09:25:33 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-02 09:28:10 -0700
commit43c6dd98f1a69c0515f0769b997cfac576a195e5 (patch)
tree64cf035f6241a0898ace04c8ceb1d11a7f705e48 /tensorflow/contrib/gan
parentf4672ca59b259436dd1cb60b9e12ba9c523e17f6 (diff)
Add CycleGAN specific summary.
PiperOrigin-RevId: 191302480
Diffstat (limited to 'tensorflow/contrib/gan')
-rw-r--r--tensorflow/contrib/gan/python/eval/python/summaries_impl.py64
-rw-r--r--tensorflow/contrib/gan/python/eval/python/summaries_test.py20
2 files changed, 59 insertions, 25 deletions
diff --git a/tensorflow/contrib/gan/python/eval/python/summaries_impl.py b/tensorflow/contrib/gan/python/eval/python/summaries_impl.py
index 0d1afad72d..508f487722 100644
--- a/tensorflow/contrib/gan/python/eval/python/summaries_impl.py
+++ b/tensorflow/contrib/gan/python/eval/python/summaries_impl.py
@@ -31,6 +31,7 @@ __all__ = [
'add_image_comparison_summaries',
'add_gan_model_summaries',
'add_regularization_loss_summaries',
+ 'add_cyclegan_image_summaries',
]
@@ -51,14 +52,9 @@ def add_gan_model_image_summaries(gan_model, grid_size=4, model_summaries=True):
ValueError: If real and generated data aren't images.
"""
if isinstance(gan_model, namedtuples.CycleGANModel):
- saved_params = locals()
- saved_params.pop('gan_model', None)
- with ops.name_scope('cyclegan_x2y_image_summaries'):
- add_gan_model_image_summaries(gan_model.model_x2y, **saved_params)
- with ops.name_scope('cyclegan_y2x_image_summaries'):
- add_gan_model_image_summaries(gan_model.model_y2x, **saved_params)
- return
-
+ raise ValueError(
+ '`add_gan_model_image_summaries` does not take CycleGANModels. Please '
+ 'use `add_cyclegan_image_summaries` instead.')
_assert_is_image(gan_model.real_data)
_assert_is_image(gan_model.generated_data)
@@ -89,6 +85,49 @@ def add_gan_model_image_summaries(gan_model, grid_size=4, model_summaries=True):
add_gan_model_summaries(gan_model)
+def add_cyclegan_image_summaries(cyclegan_model):
+ """Adds image summaries for CycleGAN.
+
+ There are two summaries, one for each generator. The first image is the
+ generator input, the second is the generator output, and the third is G(F(x)).
+
+ Args:
+ cyclegan_model: A CycleGANModel tuple.
+
+ Raises:
+ ValueError: If `cyclegan_model` isn't a CycleGANModel.
+ ValueError: If generated data, generator inputs, and reconstructions aren't
+ images.
+ ValueError: If the generator input, generated data, and reconstructions
+ aren't all the same size.
+ """
+ if not isinstance(cyclegan_model, namedtuples.CycleGANModel):
+ raise ValueError('`cyclegan_model` was not a CycleGANModel. Instead, was '
+ '%s' % type(cyclegan_model))
+
+ _assert_is_image(cyclegan_model.model_x2y.generator_inputs)
+ _assert_is_image(cyclegan_model.model_x2y.generated_data)
+ _assert_is_image(cyclegan_model.reconstructed_x)
+ _assert_is_image(cyclegan_model.model_y2x.generator_inputs)
+ _assert_is_image(cyclegan_model.model_y2x.generated_data)
+ _assert_is_image(cyclegan_model.reconstructed_y)
+
+ def _add_comparison_summary(gan_model, reconstructions):
+ image_list = (array_ops.unstack(gan_model.generator_inputs[:1]) +
+ array_ops.unstack(gan_model.generated_data[:1]) +
+ array_ops.unstack(reconstructions[:1]))
+ summary.image(
+ 'image_comparison', eval_utils.image_reshaper(
+ image_list, num_cols=len(image_list)), max_outputs=1)
+
+ with ops.name_scope('x2y_image_comparison_summaries'):
+ _add_comparison_summary(
+ cyclegan_model.model_x2y, cyclegan_model.reconstructed_x)
+ with ops.name_scope('y2x_image_comparison_summaries'):
+ _add_comparison_summary(
+ cyclegan_model.model_y2x, cyclegan_model.reconstructed_y)
+
+
def add_image_comparison_summaries(gan_model, num_comparisons=2,
display_diffs=False):
"""Adds image summaries to compare triplets of images.
@@ -109,15 +148,6 @@ def add_image_comparison_summaries(gan_model, num_comparisons=2,
ValueError: If the generator input, real, and generated data aren't all the
same size.
"""
- if isinstance(gan_model, namedtuples.CycleGANModel):
- saved_params = locals()
- saved_params.pop('gan_model', None)
- with ops.name_scope('cyclegan_x2y_image_comparison_summaries'):
- add_image_comparison_summaries(gan_model.model_x2y, **saved_params)
- with ops.name_scope('cyclegan_y2x_image_comparison_summaries'):
- add_image_comparison_summaries(gan_model.model_y2x, **saved_params)
- return
-
_assert_is_image(gan_model.generator_inputs)
_assert_is_image(gan_model.generated_data)
_assert_is_image(gan_model.real_data)
diff --git a/tensorflow/contrib/gan/python/eval/python/summaries_test.py b/tensorflow/contrib/gan/python/eval/python/summaries_test.py
index 45eb108586..33d51bfc21 100644
--- a/tensorflow/contrib/gan/python/eval/python/summaries_test.py
+++ b/tensorflow/contrib/gan/python/eval/python/summaries_test.py
@@ -65,15 +65,14 @@ def get_cyclegan_model():
return namedtuples.CycleGANModel(
model_x2y=model_x2y,
model_y2x=model_y2x,
- reconstructed_x=array_ops.zeros([3, 30, 35, 6]),
- reconstructed_y=array_ops.zeros([3, 30, 35, 6]))
+ reconstructed_x=array_ops.zeros([4, 32, 32, 3]),
+ reconstructed_y=array_ops.zeros([4, 32, 32, 3]))
class SummariesTest(test.TestCase):
- def _test_add_gan_model_image_summaries_impl(self, get_model_fn,
- expected_num_summary_ops,
- model_summaries):
+ def _test_add_gan_model_image_summaries_impl(
+ self, get_model_fn, expected_num_summary_ops, model_summaries):
summaries.add_gan_model_image_summaries(get_model_fn(), grid_size=2,
model_summaries=model_summaries)
@@ -89,8 +88,9 @@ class SummariesTest(test.TestCase):
def test_add_gan_model_image_summaries_no_model(self):
self._test_add_gan_model_image_summaries_impl(get_gan_model, 2, False)
- def test_add_gan_model_image_summaries_for_cyclegan(self):
- self._test_add_gan_model_image_summaries_impl(get_cyclegan_model, 10, True)
+ def test_cyclegan_image_summaries_dont_work(self):
+ with self.assertRaises(ValueError):
+ summaries.add_gan_model_image_summaries(get_cyclegan_model())
def _test_add_gan_model_summaries_impl(self, get_model_fn,
expected_num_summary_ops):
@@ -137,7 +137,11 @@ class SummariesTest(test.TestCase):
self._test_add_image_comparison_summaries_impl(get_gan_model, 1)
def test_add_image_comparison_summaries_for_cyclegan(self):
- self._test_add_image_comparison_summaries_impl(get_cyclegan_model, 2)
+ summaries.add_cyclegan_image_summaries(get_cyclegan_model())
+
+ self.assertEquals(2, len(ops.get_collection(ops.GraphKeys.SUMMARIES)))
+ with self.test_session(use_gpu=True):
+ summary.merge_all().eval()
if __name__ == '__main__':