aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/gan
diff options
context:
space:
mode:
authorGravatar Ankur Taly <ataly@google.com>2018-02-16 18:22:55 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-16 18:27:19 -0800
commit0e6f39d1bd7fe8daa86944f6ab0dd94fbeb4962a (patch)
treeee0dabaff4147ecc9bc92acd2a50dadbfd694f39 /tensorflow/contrib/gan
parent128572c316e6f2eb6346f920314ef98e88e75069 (diff)
Merge changes from github.
PiperOrigin-RevId: 186073337
Diffstat (limited to 'tensorflow/contrib/gan')
-rw-r--r--tensorflow/contrib/gan/python/eval/python/summaries_impl.py7
-rw-r--r--tensorflow/contrib/gan/python/eval/python/summaries_test.py14
2 files changed, 14 insertions, 7 deletions
diff --git a/tensorflow/contrib/gan/python/eval/python/summaries_impl.py b/tensorflow/contrib/gan/python/eval/python/summaries_impl.py
index 74811ff409..0d1afad72d 100644
--- a/tensorflow/contrib/gan/python/eval/python/summaries_impl.py
+++ b/tensorflow/contrib/gan/python/eval/python/summaries_impl.py
@@ -39,12 +39,13 @@ def _assert_is_image(data):
data.shape[1:].assert_is_fully_defined()
-def add_gan_model_image_summaries(gan_model, grid_size=4):
+def add_gan_model_image_summaries(gan_model, grid_size=4, model_summaries=True):
"""Adds image summaries for real and fake images.
Args:
gan_model: A GANModel tuple.
grid_size: The size of an image grid.
+ model_summaries: Also add summaries of the model.
Raises:
ValueError: If real and generated data aren't images.
@@ -83,7 +84,9 @@ def add_gan_model_image_summaries(gan_model, grid_size=4):
image_shape=generated_image_shape,
num_channels=generated_channels),
max_outputs=1)
- add_gan_model_summaries(gan_model)
+
+ if model_summaries:
+ add_gan_model_summaries(gan_model)
def add_image_comparison_summaries(gan_model, num_comparisons=2,
diff --git a/tensorflow/contrib/gan/python/eval/python/summaries_test.py b/tensorflow/contrib/gan/python/eval/python/summaries_test.py
index a02d8772e1..5549df971d 100644
--- a/tensorflow/contrib/gan/python/eval/python/summaries_test.py
+++ b/tensorflow/contrib/gan/python/eval/python/summaries_test.py
@@ -71,9 +71,10 @@ def get_cyclegan_model():
class SummariesTest(test.TestCase):
- def _test_add_gan_model_image_summaries_impl(self, get_model_fn,
- expected_num_summary_ops):
- summaries.add_gan_model_image_summaries(get_model_fn(), grid_size=2)
+ def _test_add_gan_model_image_summaries_impl(
+ self, get_model_fn, expected_num_summary_ops, model_summaries):
+ summaries.add_gan_model_image_summaries(
+ get_model_fn(), grid_size=2, model_summaries=model_summaries)
self.assertEquals(expected_num_summary_ops,
len(ops.get_collection(ops.GraphKeys.SUMMARIES)))
@@ -82,10 +83,13 @@ class SummariesTest(test.TestCase):
summary.merge_all().eval()
def test_add_gan_model_image_summaries(self):
- self._test_add_gan_model_image_summaries_impl(get_gan_model, 5)
+ self._test_add_gan_model_image_summaries_impl(get_gan_model, 5, True)
+
+ def test_add_gan_model_image_summaries_no_model(self):
+ self._test_add_gan_model_image_summaries_impl(get_gan_model, 2, False)
def test_add_gan_model_image_summaries_for_cyclegan(self):
- self._test_add_gan_model_image_summaries_impl(get_cyclegan_model, 10)
+ self._test_add_gan_model_image_summaries_impl(get_cyclegan_model, 10, True)
def _test_add_gan_model_summaries_impl(self, get_model_fn,
expected_num_summary_ops):