aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/gan
diff options
context:
space:
mode:
authorGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-04 12:32:26 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-04 12:32:26 -0700
commit4f770c34a3ba9638c458862d44f5c3e5cd3e4074 (patch)
treeb10d31b78cc0d74684d0071a1577ed8554d2d407 /tensorflow/contrib/gan
parentf9b58d46499c79e01f55d9e16867a8aace667db8 (diff)
parentc4858c15110286b1afd091c70ab4d99549b2e856 (diff)
Merge pull request #21061 from lgeiger:tfgan-fix-summary
PiperOrigin-RevId: 211502883
Diffstat (limited to 'tensorflow/contrib/gan')
-rw-r--r--tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py10
-rw-r--r--tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py2
2 files changed, 7 insertions, 5 deletions
diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py
index ab9886580d..7243f150ce 100644
--- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py
+++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py
@@ -184,7 +184,7 @@ class GANEstimator(estimator.Estimator):
return _get_estimator_spec(
mode, gan_model, generator_loss_fn, discriminator_loss_fn,
get_eval_metric_ops_fn, generator_optimizer, discriminator_optimizer,
- get_hooks_fn)
+ get_hooks_fn, use_loss_summaries)
super(GANEstimator, self).__init__(
model_fn=_model_fn, model_dir=model_dir, config=config)
@@ -211,15 +211,17 @@ def _get_gan_model(
def _get_estimator_spec(
mode, gan_model, generator_loss_fn, discriminator_loss_fn,
get_eval_metric_ops_fn, generator_optimizer, discriminator_optimizer,
- get_hooks_fn=None):
+ get_hooks_fn=None, use_loss_summaries=True):
"""Get the EstimatorSpec for the current mode."""
if mode == model_fn_lib.ModeKeys.PREDICT:
estimator_spec = model_fn_lib.EstimatorSpec(
mode=mode, predictions=gan_model.generated_data)
else:
gan_loss = tfgan_tuples.GANLoss(
- generator_loss=generator_loss_fn(gan_model),
- discriminator_loss=discriminator_loss_fn(gan_model))
+ generator_loss=generator_loss_fn(
+ gan_model, add_summaries=use_loss_summaries),
+ discriminator_loss=discriminator_loss_fn(
+ gan_model, add_summaries=use_loss_summaries))
if mode == model_fn_lib.ModeKeys.EVAL:
estimator_spec = _get_eval_estimator_spec(
gan_model, gan_loss, get_eval_metric_ops_fn)
diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py
index 9ac9c6ca9c..83f8dd641f 100644
--- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py
+++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py
@@ -116,7 +116,7 @@ def get_dummy_gan_model():
discriminator_fn=None)
-def dummy_loss_fn(gan_model):
+def dummy_loss_fn(gan_model, add_summaries=True):
return math_ops.reduce_sum(gan_model.discriminator_real_outputs -
gan_model.discriminator_gen_outputs)