aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/gan
diff options
context:
space:
mode:
authorGravatar Lukas Geiger <lgeiger@users.noreply.github.com>2018-05-08 00:19:13 +0200
committerGravatar Rasmus Munk Larsen <rmlarsen@google.com>2018-05-07 15:19:13 -0700
commit714f3c4f2f901e865bfcbf830485adafb92dca48 (patch)
tree4531b4fe02ff8a2170d54c0b4b118e91a39881b3 /tensorflow/contrib/gan
parentd8bda536c5080e761bcb24ab6984c26da875f52c (diff)
[tfgan] Add discriminator and generator losses to eval_metrics (#19117)
Diffstat (limited to 'tensorflow/contrib/gan')
-rw-r--r--tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py2
-rw-r--r--tensorflow/contrib/gan/python/estimator/python/head_impl.py18
2 files changed, 17 insertions, 3 deletions
diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py
index 387a62bd74..6bbd173f86 100644
--- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py
+++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py
@@ -213,6 +213,8 @@ class GANEstimatorIntegrationTest(test.TestCase):
scores = est.evaluate(eval_input_fn)
self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP])
self.assertIn('loss', six.iterkeys(scores))
+ self.assertEqual(scores['discriminator_loss'] + scores['generator_loss'],
+ scores['loss'])
# PREDICT
predictions = np.array([x for x in est.predict(predict_input_fn)])
diff --git a/tensorflow/contrib/gan/python/estimator/python/head_impl.py b/tensorflow/contrib/gan/python/estimator/python/head_impl.py
index a21358c50b..d174cb3bb2 100644
--- a/tensorflow/contrib/gan/python/estimator/python/head_impl.py
+++ b/tensorflow/contrib/gan/python/estimator/python/head_impl.py
@@ -25,12 +25,16 @@ from tensorflow.contrib.gan.python import train as tfgan_train
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator.canned import head
from tensorflow.python.framework import ops
+from tensorflow.python.ops import metrics as metrics_lib
__all__ = [
'GANHead',
'gan_head',
]
+def _summary_key(head_name, val):
+ return '%s/%s' % (val, head_name) if head_name else val
+
def gan_head(generator_loss_fn, discriminator_loss_fn, generator_optimizer,
discriminator_optimizer, use_loss_summaries=True,
@@ -104,6 +108,7 @@ class GANHead(head._Head): # pylint: disable=protected-access
self._generator_optimizer = generator_optimizer
self._discriminator_optimizer = discriminator_optimizer
self._get_hooks_fn = get_hooks_fn
+ self._name = name
@property
def name(self):
@@ -173,13 +178,20 @@ class GANHead(head._Head): # pylint: disable=protected-access
gan_loss = self.create_loss(
features=None, mode=mode, logits=gan_model, labels=None)
scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss
+ with ops.name_scope(None, 'metrics',
+ [gan_loss.generator_loss,
+ gan_loss.discriminator_loss]):
+ eval_metric_ops = {
+ _summary_key(self._name, 'generator_loss'):
+ metrics_lib.mean(gan_loss.generator_loss),
+ _summary_key(self._name, 'discriminator_loss'):
+ metrics_lib.mean(gan_loss.discriminator_loss)
+ }
return model_fn_lib.EstimatorSpec(
mode=model_fn_lib.ModeKeys.EVAL,
predictions=gan_model.generated_data,
loss=scalar_loss,
- # TODO(joelshor): Add metrics. If head name provided, append it to
- # metric keys.
- eval_metric_ops={})
+ eval_metric_ops=eval_metric_ops)
elif mode == model_fn_lib.ModeKeys.TRAIN:
if train_op_fn is None:
raise ValueError('train_op_fn can not be None.')