diff options
author | 2018-08-18 10:01:17 +0200 | |
---|---|---|
committer | 2018-08-18 10:01:17 +0200 | |
commit | c4858c15110286b1afd091c70ab4d99549b2e856 (patch) | |
tree | 4c5004622e37b63d3387fcada34ff6c16c1b0486 /tensorflow/contrib/gan | |
parent | ac805e5c29d9d3e4afcc6be11ad8888953c02159 (diff) |
[tfgan] Respect use_loss_summaries in GANEstimator
Since the refactor done in 47dea684efa41981e10299c2737317c504ce41af the `use_loss_summaries` argument of GANEstimator isn't respected anymore. This PR restores the original behavior and passes `use_loss_summaries` down to the loss functions.
Diffstat (limited to 'tensorflow/contrib/gan')
-rw-r--r-- | tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py | 10 | ||||
-rw-r--r-- | tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py | 2 |
2 files changed, 7 insertions, 5 deletions
diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py index 8e4affb9b4..3dd066a406 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py @@ -187,7 +187,7 @@ class GANEstimator(estimator.Estimator): return _get_estimator_spec( mode, gan_model, generator_loss_fn, discriminator_loss_fn, get_eval_metric_ops_fn, generator_optimizer, discriminator_optimizer, - get_hooks_fn) + get_hooks_fn, use_loss_summaries) super(GANEstimator, self).__init__( model_fn=_model_fn, model_dir=model_dir, config=config) @@ -214,15 +214,17 @@ def _get_gan_model( def _get_estimator_spec( mode, gan_model, generator_loss_fn, discriminator_loss_fn, get_eval_metric_ops_fn, generator_optimizer, discriminator_optimizer, - get_hooks_fn=None): + get_hooks_fn=None, use_loss_summaries=True): """Get the EstimatorSpec for the current mode.""" if mode == model_fn_lib.ModeKeys.PREDICT: estimator_spec = model_fn_lib.EstimatorSpec( mode=mode, predictions=gan_model.generated_data) else: gan_loss = tfgan_tuples.GANLoss( - generator_loss=generator_loss_fn(gan_model), - discriminator_loss=discriminator_loss_fn(gan_model)) + generator_loss=generator_loss_fn( + gan_model, add_summaries=use_loss_summaries), + discriminator_loss=discriminator_loss_fn( + gan_model, add_summaries=use_loss_summaries)) if mode == model_fn_lib.ModeKeys.EVAL: estimator_spec = _get_eval_estimator_spec( gan_model, gan_loss, get_eval_metric_ops_fn) diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py index 9ac9c6ca9c..83f8dd641f 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py @@ -116,7 +116,7 @@ def get_dummy_gan_model(): discriminator_fn=None) -def dummy_loss_fn(gan_model): +def dummy_loss_fn(gan_model, add_summaries=True): return math_ops.reduce_sum(gan_model.discriminator_real_outputs - gan_model.discriminator_gen_outputs) |