about summary refs log tree commit diff homepage
path: root/tensorflow/contrib/gan
diff options
context:
space:
mode:
authorGravatar apantykhin <apantykhin@gmail.com>2018-06-06 20:08:01 +0400
committerGravatar apantykhin <apantykhin@gmail.com>2018-06-06 20:08:01 +0400
commitef98fc4fb98f7df05b636d022297e2a708a7986b (patch)
treea2eec28e425e1667c1c36c7743d62131d1dd7983 /tensorflow/contrib/gan
parent558cbd9fc89055f532a9558a276a9e6b438371cf (diff)
parentb98adf4f09632d5f46d82d622d6627aed310541b (diff)
Merge branch 'ganhead_constructor_validate' of https://github.com/alexpantyukhin/tensorflow into ganhead_constructor_validate
# Conflicts:
#	tensorflow/contrib/gan/python/estimator/python/head_impl.py
Diffstat (limited to 'tensorflow/contrib/gan')
-rw-r--r--tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py7
-rw-r--r--tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py11
-rw-r--r--tensorflow/contrib/gan/python/estimator/python/head_impl.py45
-rw-r--r--tensorflow/contrib/gan/python/estimator/python/head_test.py7
-rw-r--r--tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py30
-rw-r--r--tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py2
-rw-r--r--tensorflow/contrib/gan/python/features/python/conditioning_utils.py2
-rw-r--r--tensorflow/contrib/gan/python/features/python/virtual_batchnorm_impl.py6
-rw-r--r--tensorflow/contrib/gan/python/losses/python/losses_impl_test.py2
-rw-r--r--tensorflow/contrib/gan/python/train.py5
10 files changed, 85 insertions, 32 deletions
diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py
index e3fc6bf0f0..4092b32004 100644
--- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py
+++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py
@@ -112,6 +112,7 @@ class GANEstimator(estimator.Estimator):
generator_optimizer=None,
discriminator_optimizer=None,
get_hooks_fn=None,
+ get_eval_metric_ops_fn=None,
add_summaries=None,
use_loss_summaries=True,
config=None):
@@ -146,6 +147,9 @@ class GANEstimator(estimator.Estimator):
list of hooks. These hooks are run on the generator and discriminator
train ops, and can be used to implement the GAN training scheme.
Defaults to `train.get_sequential_train_hooks()`.
+ get_eval_metric_ops_fn: A function that takes a `GANModel`, and returns a
+ dict of metric results keyed by name. The output of this function is
+ passed into `tf.estimator.EstimatorSpec` during evaluation.
add_summaries: `None`, a single `SummaryType`, or a list of `SummaryType`.
use_loss_summaries: If `True`, add loss summaries. If `False`, does not.
If `None`, uses defaults.
@@ -160,7 +164,8 @@ class GANEstimator(estimator.Estimator):
else discriminator_optimizer)
gan_head = head_lib.gan_head(
generator_loss_fn, discriminator_loss_fn, gopt, dopt,
- use_loss_summaries, get_hooks_fn=get_hooks_fn)
+ use_loss_summaries, get_hooks_fn=get_hooks_fn,
+ get_eval_metric_ops_fn=get_eval_metric_ops_fn)
return _gan_model_fn(
features, labels, mode, generator_fn, discriminator_fn, gan_head,
add_summaries)
diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py
index 387a62bd74..955482599b 100644
--- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py
+++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py
@@ -38,6 +38,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import metrics as metrics_lib
from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import test
from tensorflow.python.summary.writer import writer_cache
@@ -194,6 +195,12 @@ class GANEstimatorIntegrationTest(test.TestCase):
lr = learning_rate_decay.exponential_decay(1.0, gstep, 10, 0.9)
return training.GradientDescentOptimizer(lr)
+ def get_metrics(gan_model):
+ return {
+ 'mse_custom_metric': metrics_lib.mean_squared_error(
+ gan_model.real_data, gan_model.generated_data)
+ }
+
gopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0)
dopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0)
est = estimator.GANEstimator(
@@ -203,6 +210,7 @@ class GANEstimatorIntegrationTest(test.TestCase):
discriminator_loss_fn=losses.wasserstein_discriminator_loss,
generator_optimizer=gopt,
discriminator_optimizer=dopt,
+ get_eval_metric_ops_fn=get_metrics,
model_dir=self._model_dir)
# TRAIN
@@ -213,6 +221,9 @@ class GANEstimatorIntegrationTest(test.TestCase):
scores = est.evaluate(eval_input_fn)
self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP])
self.assertIn('loss', six.iterkeys(scores))
+ self.assertEqual(scores['discriminator_loss'] + scores['generator_loss'],
+ scores['loss'])
+ self.assertIn('mse_custom_metric', six.iterkeys(scores))
# PREDICT
predictions = np.array([x for x in est.predict(predict_input_fn)])
diff --git a/tensorflow/contrib/gan/python/estimator/python/head_impl.py b/tensorflow/contrib/gan/python/estimator/python/head_impl.py
index 4750f94d9a..513e451e7b 100644
--- a/tensorflow/contrib/gan/python/estimator/python/head_impl.py
+++ b/tensorflow/contrib/gan/python/estimator/python/head_impl.py
@@ -25,17 +25,21 @@ from tensorflow.contrib.gan.python import train as tfgan_train
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator.canned import head
from tensorflow.python.framework import ops
+from tensorflow.python.ops import metrics as metrics_lib
__all__ = [
'GANHead',
'gan_head',
]
+def _summary_key(head_name, val):
+ return '%s/%s' % (val, head_name) if head_name else val
+
def gan_head(generator_loss_fn, discriminator_loss_fn, generator_optimizer,
discriminator_optimizer, use_loss_summaries=True,
get_hooks_fn=tfgan_train.get_sequential_train_hooks(),
- name=None):
+ get_eval_metric_ops_fn=None, name=None):
"""Creates a `GANHead`.
Args:
@@ -47,9 +51,12 @@ def gan_head(generator_loss_fn, discriminator_loss_fn, generator_optimizer,
discriminator_optimizer: Same as `generator_optimizer`, but for the
discriminator updates.
use_loss_summaries: If `True`, add loss summaries. If `False`, does not.
- If `None`, uses defaults.
- get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list
- of hooks.
+ If `None`, uses defaults.
+ get_hooks_fn: A function that takes a `GANTrainOps` tuple and returns a
+ list of hooks.
+ get_eval_metric_ops_fn: A function that takes a `GANModel`, and returns a
+ dict of metric results keyed by name. The output of this function is
+ passed into `tf.estimator.EstimatorSpec` during evaluation.
name: name of the head. If provided, summary and metrics keys will be
suffixed by `"/" + name`.
@@ -62,6 +69,7 @@ def gan_head(generator_loss_fn, discriminator_loss_fn, generator_optimizer,
discriminator_optimizer=discriminator_optimizer,
use_loss_summaries=use_loss_summaries,
get_hooks_fn=get_hooks_fn,
+ get_eval_metric_ops_fn=get_eval_metric_ops_fn,
name=name)
@@ -72,6 +80,7 @@ class GANHead(head._Head): # pylint: disable=protected-access
generator_optimizer, discriminator_optimizer,
use_loss_summaries=True,
get_hooks_fn=None,
+ get_eval_metric_ops_fn=None,
name=None):
"""`Head` for GAN training.
@@ -85,8 +94,11 @@ class GANHead(head._Head): # pylint: disable=protected-access
discriminator updates.
use_loss_summaries: If `True`, add loss summaries. If `False`, does not.
If `None`, uses defaults.
- get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list
- of hooks. Defaults to `train.get_sequential_train_hooks()`
+ get_hooks_fn: A function that takes a `GANTrainOps` tuple and returns a
+ list of hooks. Defaults to `train.get_sequential_train_hooks()`
+ get_eval_metric_ops_fn: A function that takes a `GANModel`, and returns a
+ dict of metric results keyed by name. The output of this function is
+ passed into `tf.estimator.EstimatorSpec` during evaluation.
name: name of the head. If provided, summary and metrics keys will be
suffixed by `"/" + name`.
"""
@@ -115,6 +127,8 @@ class GANHead(head._Head): # pylint: disable=protected-access
self._generator_optimizer = generator_optimizer
self._discriminator_optimizer = discriminator_optimizer
self._get_hooks_fn = get_hooks_fn
+ self._get_eval_metric_ops_fn = get_eval_metric_ops_fn
+ self._name = name
@property
def name(self):
@@ -184,13 +198,26 @@ class GANHead(head._Head): # pylint: disable=protected-access
gan_loss = self.create_loss(
features=None, mode=mode, logits=gan_model, labels=None)
scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss
+ with ops.name_scope(None, 'metrics',
+ [gan_loss.generator_loss,
+ gan_loss.discriminator_loss]):
+ eval_metric_ops = {
+ _summary_key(self._name, 'generator_loss'):
+ metrics_lib.mean(gan_loss.generator_loss),
+ _summary_key(self._name, 'discriminator_loss'):
+ metrics_lib.mean(gan_loss.discriminator_loss)
+ }
+ if self._get_eval_metric_ops_fn is not None:
+ custom_eval_metric_ops = self._get_eval_metric_ops_fn(gan_model)
+ if not isinstance(custom_eval_metric_ops, dict):
+ raise TypeError('get_eval_metric_ops_fn must return a dict, '
+ 'received: {}'.format(custom_eval_metric_ops))
+ eval_metric_ops.update(custom_eval_metric_ops)
return model_fn_lib.EstimatorSpec(
mode=model_fn_lib.ModeKeys.EVAL,
predictions=gan_model.generated_data,
loss=scalar_loss,
- # TODO(joelshor): Add metrics. If head name provided, append it to
- # metric keys.
- eval_metric_ops={})
+ eval_metric_ops=eval_metric_ops)
elif mode == model_fn_lib.ModeKeys.TRAIN:
if train_op_fn is None:
raise ValueError('train_op_fn can not be None.')
diff --git a/tensorflow/contrib/gan/python/estimator/python/head_test.py b/tensorflow/contrib/gan/python/estimator/python/head_test.py
index 8168f005cd..6587f1fc60 100644
--- a/tensorflow/contrib/gan/python/estimator/python/head_test.py
+++ b/tensorflow/contrib/gan/python/estimator/python/head_test.py
@@ -62,9 +62,14 @@ class GANHeadTest(test.TestCase):
generator_loss_fn=dummy_loss,
discriminator_loss_fn=dummy_loss,
generator_optimizer=training.GradientDescentOptimizer(1.0),
- discriminator_optimizer=training.GradientDescentOptimizer(1.0))
+ discriminator_optimizer=training.GradientDescentOptimizer(1.0),
+ get_eval_metric_ops_fn=self.get_metrics)
self.assertTrue(isinstance(self.gan_head, head.GANHead))
+ def get_metrics(self, gan_model):
+ self.assertTrue(isinstance(gan_model, tfgan_tuples.GANModel))
+ return {}
+
def _test_modes_helper(self, mode):
self.gan_head.create_estimator_spec(
features=None,
diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py
index 47e51415fd..d914f54945 100644
--- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py
+++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py
@@ -488,25 +488,25 @@ def frechet_classifier_distance(real_images,
The Frechet Inception distance. A floating-point scalar of the same type
as the output of `classifier_fn`.
"""
-
real_images_list = array_ops.split(
real_images, num_or_size_splits=num_batches)
generated_images_list = array_ops.split(
generated_images, num_or_size_splits=num_batches)
- imgs = array_ops.stack(real_images_list + generated_images_list)
+ real_imgs = array_ops.stack(real_images_list)
+ generated_imgs = array_ops.stack(generated_images_list)
# Compute the activations using the memory-efficient `map_fn`.
- activations = functional_ops.map_fn(
- fn=classifier_fn,
- elems=imgs,
- parallel_iterations=1,
- back_prop=False,
- swap_memory=True,
- name='RunClassifier')
+ def compute_activations(elems):
+ return functional_ops.map_fn(fn=classifier_fn,
+ elems=elems,
+ parallel_iterations=1,
+ back_prop=False,
+ swap_memory=True,
+ name='RunClassifier')
- # Split the activations by the real and generated images.
- real_a, gen_a = array_ops.split(activations, [num_batches, num_batches], 0)
+ real_a = compute_activations(real_imgs)
+ gen_a = compute_activations(generated_imgs)
# Ensure the activations have the right shapes.
real_a = array_ops.concat(array_ops.unstack(real_a), 0)
@@ -697,18 +697,20 @@ def frechet_classifier_distance_from_activations(real_activations,
# Compute mean and covariance matrices of activations.
m = math_ops.reduce_mean(real_activations, 0)
m_w = math_ops.reduce_mean(generated_activations, 0)
- num_examples = math_ops.to_double(array_ops.shape(real_activations)[0])
+ num_examples_real = math_ops.to_double(array_ops.shape(real_activations)[0])
+ num_examples_generated = math_ops.to_double(
+ array_ops.shape(generated_activations)[0])
# sigma = (1 / (n - 1)) * (X - mu) (X - mu)^T
real_centered = real_activations - m
sigma = math_ops.matmul(
real_centered, real_centered, transpose_a=True) / (
- num_examples - 1)
+ num_examples_real - 1)
gen_centered = generated_activations - m_w
sigma_w = math_ops.matmul(
gen_centered, gen_centered, transpose_a=True) / (
- num_examples - 1)
+ num_examples_generated - 1)
# Find the Tr(sqrt(sigma sigma_w)) component of FID
sqrt_trace_component = trace_sqrt_product(sigma, sigma_w)
diff --git a/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py
index 4b10bc0f8e..4b1105f6bd 100644
--- a/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py
+++ b/tensorflow/contrib/gan/python/eval/python/sliced_wasserstein_impl.py
@@ -161,7 +161,7 @@ def _sliced_wasserstein(a, b, random_sampling_count, random_projection_dim):
proj = random_ops.random_normal(
[array_ops.shape(a)[1], random_projection_dim])
proj *= math_ops.rsqrt(
- math_ops.reduce_sum(math_ops.square(proj), 0, keep_dims=True))
+ math_ops.reduce_sum(math_ops.square(proj), 0, keepdims=True))
# Project both distributions and sort them.
proj_a = math_ops.matmul(a, proj)
proj_b = math_ops.matmul(b, proj)
diff --git a/tensorflow/contrib/gan/python/features/python/conditioning_utils.py b/tensorflow/contrib/gan/python/features/python/conditioning_utils.py
index df71187fbd..a9b8faa712 100644
--- a/tensorflow/contrib/gan/python/features/python/conditioning_utils.py
+++ b/tensorflow/contrib/gan/python/features/python/conditioning_utils.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Miscellanous utilities for TFGAN code and examples."""
+"""Miscellaneous utilities for TFGAN code and examples."""
from __future__ import absolute_import
from __future__ import division
diff --git a/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_impl.py b/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_impl.py
index f8b372546b..650eab97a3 100644
--- a/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_impl.py
+++ b/tensorflow/contrib/gan/python/features/python/virtual_batchnorm_impl.py
@@ -64,11 +64,11 @@ def _statistics(x, axes):
y = math_ops.cast(x, dtypes.float32) if x.dtype == dtypes.float16 else x
# Compute true mean while keeping the dims for proper broadcasting.
- shift = array_ops.stop_gradient(math_ops.reduce_mean(y, axes, keep_dims=True))
+ shift = array_ops.stop_gradient(math_ops.reduce_mean(y, axes, keepdims=True))
- shifted_mean = math_ops.reduce_mean(y - shift, axes, keep_dims=True)
+ shifted_mean = math_ops.reduce_mean(y - shift, axes, keepdims=True)
mean = shifted_mean + shift
- mean_squared = math_ops.reduce_mean(math_ops.square(y), axes, keep_dims=True)
+ mean_squared = math_ops.reduce_mean(math_ops.square(y), axes, keepdims=True)
mean = array_ops.squeeze(mean, axes)
mean_squared = array_ops.squeeze(mean_squared, axes)
diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py
index 2889e93743..9f5fee4542 100644
--- a/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py
+++ b/tensorflow/contrib/gan/python/losses/python/losses_impl_test.py
@@ -570,7 +570,7 @@ class MutualInformationPenaltyTest(test.TestCase, _PenaltyTest):
'predicted_distributions': self._predicted_distributions,
}
self._expected_loss = 1.61610
- self._expected_op_name = 'mutual_information_loss/mul'
+ self._expected_op_name = 'mutual_information_loss/mul_1'
self._batch_size = 2
diff --git a/tensorflow/contrib/gan/python/train.py b/tensorflow/contrib/gan/python/train.py
index 73acd05b60..6fa43059f3 100644
--- a/tensorflow/contrib/gan/python/train.py
+++ b/tensorflow/contrib/gan/python/train.py
@@ -710,7 +710,10 @@ def gan_train_ops(
be used to train a generator/discriminator pair.
"""
if isinstance(model, namedtuples.CycleGANModel):
- saved_params = locals()
+ # Get and store all arguments other than model and loss from locals.
+ # Contents of locals should not be modified, may not affect values. So make
+ # a copy. https://docs.python.org/2/library/functions.html#locals.
+ saved_params = dict(locals())
saved_params.pop('model', None)
saved_params.pop('loss', None)
kwargs = saved_params.pop('kwargs', {})