diff options
author | Ankur Taly <ataly@google.com> | 2018-02-16 18:22:55 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-02-16 18:27:19 -0800 |
commit | 0e6f39d1bd7fe8daa86944f6ab0dd94fbeb4962a (patch) | |
tree | ee0dabaff4147ecc9bc92acd2a50dadbfd694f39 /tensorflow/contrib/gan | |
parent | 128572c316e6f2eb6346f920314ef98e88e75069 (diff) |
Merge changes from github.
PiperOrigin-RevId: 186073337
Diffstat (limited to 'tensorflow/contrib/gan')
-rw-r--r-- | tensorflow/contrib/gan/python/eval/python/summaries_impl.py | 7 | ||||
-rw-r--r-- | tensorflow/contrib/gan/python/eval/python/summaries_test.py | 14 |
2 files changed, 14 insertions, 7 deletions
diff --git a/tensorflow/contrib/gan/python/eval/python/summaries_impl.py b/tensorflow/contrib/gan/python/eval/python/summaries_impl.py index 74811ff409..0d1afad72d 100644 --- a/tensorflow/contrib/gan/python/eval/python/summaries_impl.py +++ b/tensorflow/contrib/gan/python/eval/python/summaries_impl.py @@ -39,12 +39,13 @@ def _assert_is_image(data): data.shape[1:].assert_is_fully_defined() -def add_gan_model_image_summaries(gan_model, grid_size=4): +def add_gan_model_image_summaries(gan_model, grid_size=4, model_summaries=True): """Adds image summaries for real and fake images. Args: gan_model: A GANModel tuple. grid_size: The size of an image grid. + model_summaries: Also add summaries of the model. Raises: ValueError: If real and generated data aren't images. @@ -83,7 +84,9 @@ def add_gan_model_image_summaries(gan_model, grid_size=4): image_shape=generated_image_shape, num_channels=generated_channels), max_outputs=1) - add_gan_model_summaries(gan_model) + + if model_summaries: + add_gan_model_summaries(gan_model) def add_image_comparison_summaries(gan_model, num_comparisons=2, diff --git a/tensorflow/contrib/gan/python/eval/python/summaries_test.py b/tensorflow/contrib/gan/python/eval/python/summaries_test.py index a02d8772e1..5549df971d 100644 --- a/tensorflow/contrib/gan/python/eval/python/summaries_test.py +++ b/tensorflow/contrib/gan/python/eval/python/summaries_test.py @@ -71,9 +71,10 @@ def get_cyclegan_model(): class SummariesTest(test.TestCase): - def _test_add_gan_model_image_summaries_impl(self, get_model_fn, - expected_num_summary_ops): - summaries.add_gan_model_image_summaries(get_model_fn(), grid_size=2) + def _test_add_gan_model_image_summaries_impl( + self, get_model_fn, expected_num_summary_ops, model_summaries): + summaries.add_gan_model_image_summaries( + get_model_fn(), grid_size=2, model_summaries=model_summaries) self.assertEquals(expected_num_summary_ops, len(ops.get_collection(ops.GraphKeys.SUMMARIES))) @@ -82,10 +83,13 @@ class SummariesTest(test.TestCase): summary.merge_all().eval() def test_add_gan_model_image_summaries(self): - self._test_add_gan_model_image_summaries_impl(get_gan_model, 5) + self._test_add_gan_model_image_summaries_impl(get_gan_model, 5, True) + + def test_add_gan_model_image_summaries_no_model(self): + self._test_add_gan_model_image_summaries_impl(get_gan_model, 2, False) def test_add_gan_model_image_summaries_for_cyclegan(self): - self._test_add_gan_model_image_summaries_impl(get_cyclegan_model, 10) + self._test_add_gan_model_image_summaries_impl(get_cyclegan_model, 10, True) def _test_add_gan_model_summaries_impl(self, get_model_fn, expected_num_summary_ops): |