diff options
author | Lukas Geiger <lgeiger@users.noreply.github.com> | 2018-05-08 00:19:13 +0200 |
---|---|---|
committer | Rasmus Munk Larsen <rmlarsen@google.com> | 2018-05-07 15:19:13 -0700 |
commit | 714f3c4f2f901e865bfcbf830485adafb92dca48 (patch) | |
tree | 4531b4fe02ff8a2170d54c0b4b118e91a39881b3 /tensorflow/contrib/gan | |
parent | d8bda536c5080e761bcb24ab6984c26da875f52c (diff) |
[tfgan] Add discriminator and generator losses to eval_metrics (#19117)
Diffstat (limited to 'tensorflow/contrib/gan')
-rw-r--r-- | tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py | 2 | ||||
-rw-r--r-- | tensorflow/contrib/gan/python/estimator/python/head_impl.py | 18 |
2 files changed, 17 insertions, 3 deletions
diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py index 387a62bd74..6bbd173f86 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py @@ -213,6 +213,8 @@ class GANEstimatorIntegrationTest(test.TestCase): scores = est.evaluate(eval_input_fn) self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP]) self.assertIn('loss', six.iterkeys(scores)) + self.assertEqual(scores['discriminator_loss'] + scores['generator_loss'], + scores['loss']) # PREDICT predictions = np.array([x for x in est.predict(predict_input_fn)]) diff --git a/tensorflow/contrib/gan/python/estimator/python/head_impl.py b/tensorflow/contrib/gan/python/estimator/python/head_impl.py index a21358c50b..d174cb3bb2 100644 --- a/tensorflow/contrib/gan/python/estimator/python/head_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/head_impl.py @@ -25,12 +25,16 @@ from tensorflow.contrib.gan.python import train as tfgan_train from tensorflow.python.estimator import model_fn as model_fn_lib from tensorflow.python.estimator.canned import head from tensorflow.python.framework import ops +from tensorflow.python.ops import metrics as metrics_lib __all__ = [ 'GANHead', 'gan_head', ] +def _summary_key(head_name, val): + return '%s/%s' % (val, head_name) if head_name else val + def gan_head(generator_loss_fn, discriminator_loss_fn, generator_optimizer, discriminator_optimizer, use_loss_summaries=True, @@ -104,6 +108,7 @@ class GANHead(head._Head): # pylint: disable=protected-access self._generator_optimizer = generator_optimizer self._discriminator_optimizer = discriminator_optimizer self._get_hooks_fn = get_hooks_fn + self._name = name @property def name(self): @@ -173,13 +178,20 @@ class GANHead(head._Head): # pylint: disable=protected-access gan_loss = self.create_loss( features=None, mode=mode, logits=gan_model, labels=None) scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss + with ops.name_scope(None, 'metrics', + [gan_loss.generator_loss, + gan_loss.discriminator_loss]): + eval_metric_ops = { + _summary_key(self._name, 'generator_loss'): + metrics_lib.mean(gan_loss.generator_loss), + _summary_key(self._name, 'discriminator_loss'): + metrics_lib.mean(gan_loss.discriminator_loss) + } return model_fn_lib.EstimatorSpec( mode=model_fn_lib.ModeKeys.EVAL, predictions=gan_model.generated_data, loss=scalar_loss, - # TODO(joelshor): Add metrics. If head name provided, append it to - # metric keys. - eval_metric_ops={}) + eval_metric_ops=eval_metric_ops) elif mode == model_fn_lib.ModeKeys.TRAIN: if train_op_fn is None: raise ValueError('train_op_fn can not be None.') |