aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/gan
diff options
context:
space:
mode:
authorGravatar Lukas Geiger <lgeiger@users.noreply.github.com>2018-05-09 01:31:39 +0200
committerGravatar Rasmus Munk Larsen <rmlarsen@google.com>2018-05-08 16:31:39 -0700
commitc0fb9413914d983cad2ea6bb4997033a1f0dd722 (patch)
treef57ce588d7b8a5288d0a1dbcab73bc8b8dd8d163 /tensorflow/contrib/gan
parent24d9492f07e8cba89ae94cf01a1bcae22fcf438b (diff)
[tfgan] Allow to add custom eval metrics to GANEstimator (#19133)
Diffstat (limited to 'tensorflow/contrib/gan')
-rw-r--r--tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py7
-rw-r--r--tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py9
-rw-r--r--tensorflow/contrib/gan/python/estimator/python/head_impl.py27
-rw-r--r--tensorflow/contrib/gan/python/estimator/python/head_test.py7
4 files changed, 42 insertions, 8 deletions
diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py
index e3fc6bf0f0..4092b32004 100644
--- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py
+++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py
@@ -112,6 +112,7 @@ class GANEstimator(estimator.Estimator):
generator_optimizer=None,
discriminator_optimizer=None,
get_hooks_fn=None,
+ get_eval_metric_ops_fn=None,
add_summaries=None,
use_loss_summaries=True,
config=None):
@@ -146,6 +147,9 @@ class GANEstimator(estimator.Estimator):
list of hooks. These hooks are run on the generator and discriminator
train ops, and can be used to implement the GAN training scheme.
Defaults to `train.get_sequential_train_hooks()`.
+ get_eval_metric_ops_fn: A function that takes a `GANModel`, and returns a
+ dict of metric results keyed by name. The output of this function is
+ passed into `tf.estimator.EstimatorSpec` during evaluation.
add_summaries: `None`, a single `SummaryType`, or a list of `SummaryType`.
use_loss_summaries: If `True`, add loss summaries. If `False`, does not.
If `None`, uses defaults.
@@ -160,7 +164,8 @@ class GANEstimator(estimator.Estimator):
else discriminator_optimizer)
gan_head = head_lib.gan_head(
generator_loss_fn, discriminator_loss_fn, gopt, dopt,
- use_loss_summaries, get_hooks_fn=get_hooks_fn)
+ use_loss_summaries, get_hooks_fn=get_hooks_fn,
+ get_eval_metric_ops_fn=get_eval_metric_ops_fn)
return _gan_model_fn(
features, labels, mode, generator_fn, discriminator_fn, gan_head,
add_summaries)
diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py
index 6bbd173f86..955482599b 100644
--- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py
+++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py
@@ -38,6 +38,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import metrics as metrics_lib
from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import test
from tensorflow.python.summary.writer import writer_cache
@@ -194,6 +195,12 @@ class GANEstimatorIntegrationTest(test.TestCase):
lr = learning_rate_decay.exponential_decay(1.0, gstep, 10, 0.9)
return training.GradientDescentOptimizer(lr)
+ def get_metrics(gan_model):
+ return {
+ 'mse_custom_metric': metrics_lib.mean_squared_error(
+ gan_model.real_data, gan_model.generated_data)
+ }
+
gopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0)
dopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0)
est = estimator.GANEstimator(
@@ -203,6 +210,7 @@ class GANEstimatorIntegrationTest(test.TestCase):
discriminator_loss_fn=losses.wasserstein_discriminator_loss,
generator_optimizer=gopt,
discriminator_optimizer=dopt,
+ get_eval_metric_ops_fn=get_metrics,
model_dir=self._model_dir)
# TRAIN
@@ -215,6 +223,7 @@ class GANEstimatorIntegrationTest(test.TestCase):
self.assertIn('loss', six.iterkeys(scores))
self.assertEqual(scores['discriminator_loss'] + scores['generator_loss'],
scores['loss'])
+ self.assertIn('mse_custom_metric', six.iterkeys(scores))
# PREDICT
predictions = np.array([x for x in est.predict(predict_input_fn)])
diff --git a/tensorflow/contrib/gan/python/estimator/python/head_impl.py b/tensorflow/contrib/gan/python/estimator/python/head_impl.py
index d174cb3bb2..ff903a78cc 100644
--- a/tensorflow/contrib/gan/python/estimator/python/head_impl.py
+++ b/tensorflow/contrib/gan/python/estimator/python/head_impl.py
@@ -39,7 +39,7 @@ def _summary_key(head_name, val):
def gan_head(generator_loss_fn, discriminator_loss_fn, generator_optimizer,
discriminator_optimizer, use_loss_summaries=True,
get_hooks_fn=tfgan_train.get_sequential_train_hooks(),
- name=None):
+ get_eval_metric_ops_fn=None, name=None):
"""Creates a `GANHead`.
Args:
@@ -51,9 +51,12 @@ def gan_head(generator_loss_fn, discriminator_loss_fn, generator_optimizer,
discriminator_optimizer: Same as `generator_optimizer`, but for the
discriminator updates.
use_loss_summaries: If `True`, add loss summaries. If `False`, does not.
- If `None`, uses defaults.
- get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list
- of hooks.
+ If `None`, uses defaults.
+ get_hooks_fn: A function that takes a `GANTrainOps` tuple and returns a
+ list of hooks.
+ get_eval_metric_ops_fn: A function that takes a `GANModel`, and returns a
+ dict of metric results keyed by name. The output of this function is
+ passed into `tf.estimator.EstimatorSpec` during evaluation.
name: name of the head. If provided, summary and metrics keys will be
suffixed by `"/" + name`.
@@ -66,6 +69,7 @@ def gan_head(generator_loss_fn, discriminator_loss_fn, generator_optimizer,
discriminator_optimizer=discriminator_optimizer,
use_loss_summaries=use_loss_summaries,
get_hooks_fn=get_hooks_fn,
+ get_eval_metric_ops_fn=get_eval_metric_ops_fn,
name=name)
@@ -76,6 +80,7 @@ class GANHead(head._Head): # pylint: disable=protected-access
generator_optimizer, discriminator_optimizer,
use_loss_summaries=True,
get_hooks_fn=None,
+ get_eval_metric_ops_fn=None,
name=None):
"""`Head` for GAN training.
@@ -89,8 +94,11 @@ class GANHead(head._Head): # pylint: disable=protected-access
discriminator updates.
use_loss_summaries: If `True`, add loss summaries. If `False`, does not.
If `None`, uses defaults.
- get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list
- of hooks. Defaults to `train.get_sequential_train_hooks()`
+ get_hooks_fn: A function that takes a `GANTrainOps` tuple and returns a
+ list of hooks. Defaults to `train.get_sequential_train_hooks()`
+ get_eval_metric_ops_fn: A function that takes a `GANModel`, and returns a
+ dict of metric results keyed by name. The output of this function is
+ passed into `tf.estimator.EstimatorSpec` during evaluation.
name: name of the head. If provided, summary and metrics keys will be
suffixed by `"/" + name`.
"""
@@ -108,6 +116,7 @@ class GANHead(head._Head): # pylint: disable=protected-access
self._generator_optimizer = generator_optimizer
self._discriminator_optimizer = discriminator_optimizer
self._get_hooks_fn = get_hooks_fn
+ self._get_eval_metric_ops_fn = get_eval_metric_ops_fn
self._name = name
@property
@@ -187,6 +196,12 @@ class GANHead(head._Head): # pylint: disable=protected-access
_summary_key(self._name, 'discriminator_loss'):
metrics_lib.mean(gan_loss.discriminator_loss)
}
+ if self._get_eval_metric_ops_fn is not None:
+ custom_eval_metric_ops = self._get_eval_metric_ops_fn(gan_model)
+ if not isinstance(custom_eval_metric_ops, dict):
+ raise TypeError('get_eval_metric_ops_fn must return a dict, '
+ 'received: {}'.format(custom_eval_metric_ops))
+ eval_metric_ops.update(custom_eval_metric_ops)
return model_fn_lib.EstimatorSpec(
mode=model_fn_lib.ModeKeys.EVAL,
predictions=gan_model.generated_data,
diff --git a/tensorflow/contrib/gan/python/estimator/python/head_test.py b/tensorflow/contrib/gan/python/estimator/python/head_test.py
index 8168f005cd..6587f1fc60 100644
--- a/tensorflow/contrib/gan/python/estimator/python/head_test.py
+++ b/tensorflow/contrib/gan/python/estimator/python/head_test.py
@@ -62,9 +62,14 @@ class GANHeadTest(test.TestCase):
generator_loss_fn=dummy_loss,
discriminator_loss_fn=dummy_loss,
generator_optimizer=training.GradientDescentOptimizer(1.0),
- discriminator_optimizer=training.GradientDescentOptimizer(1.0))
+ discriminator_optimizer=training.GradientDescentOptimizer(1.0),
+ get_eval_metric_ops_fn=self.get_metrics)
self.assertTrue(isinstance(self.gan_head, head.GANHead))
+ def get_metrics(self, gan_model):
+ self.assertTrue(isinstance(gan_model, tfgan_tuples.GANModel))
+ return {}
+
def _test_modes_helper(self, mode):
self.gan_head.create_estimator_spec(
features=None,