aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/gan
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-23 06:55:23 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-23 06:58:03 -0700
commita821ea02afd05a96dd0e118e6ee745d472c61b3e (patch)
treed906740338266711f6a016adaef3e6ab71e62c65 /tensorflow/contrib/gan
parent6d57bca02b3278e812658fe5514a2bcb17670dbe (diff)
Support non-equal set sizes for FID computation.
PiperOrigin-RevId: 193917167
Diffstat (limited to 'tensorflow/contrib/gan')
-rw-r--r--tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py30
1 files changed, 16 insertions, 14 deletions
diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py
index 47e51415fd..d914f54945 100644
--- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py
+++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py
@@ -488,25 +488,25 @@ def frechet_classifier_distance(real_images,
The Frechet Inception distance. A floating-point scalar of the same type
as the output of `classifier_fn`.
"""
-
real_images_list = array_ops.split(
real_images, num_or_size_splits=num_batches)
generated_images_list = array_ops.split(
generated_images, num_or_size_splits=num_batches)
- imgs = array_ops.stack(real_images_list + generated_images_list)
+ real_imgs = array_ops.stack(real_images_list)
+ generated_imgs = array_ops.stack(generated_images_list)
# Compute the activations using the memory-efficient `map_fn`.
- activations = functional_ops.map_fn(
- fn=classifier_fn,
- elems=imgs,
- parallel_iterations=1,
- back_prop=False,
- swap_memory=True,
- name='RunClassifier')
+ def compute_activations(elems):
+ return functional_ops.map_fn(fn=classifier_fn,
+ elems=elems,
+ parallel_iterations=1,
+ back_prop=False,
+ swap_memory=True,
+ name='RunClassifier')
- # Split the activations by the real and generated images.
- real_a, gen_a = array_ops.split(activations, [num_batches, num_batches], 0)
+ real_a = compute_activations(real_imgs)
+ gen_a = compute_activations(generated_imgs)
# Ensure the activations have the right shapes.
real_a = array_ops.concat(array_ops.unstack(real_a), 0)
@@ -697,18 +697,20 @@ def frechet_classifier_distance_from_activations(real_activations,
# Compute mean and covariance matrices of activations.
m = math_ops.reduce_mean(real_activations, 0)
m_w = math_ops.reduce_mean(generated_activations, 0)
- num_examples = math_ops.to_double(array_ops.shape(real_activations)[0])
+ num_examples_real = math_ops.to_double(array_ops.shape(real_activations)[0])
+ num_examples_generated = math_ops.to_double(
+ array_ops.shape(generated_activations)[0])
# sigma = (1 / (n - 1)) * (X - mu) (X - mu)^T
real_centered = real_activations - m
sigma = math_ops.matmul(
real_centered, real_centered, transpose_a=True) / (
- num_examples - 1)
+ num_examples_real - 1)
gen_centered = generated_activations - m_w
sigma_w = math_ops.matmul(
gen_centered, gen_centered, transpose_a=True) / (
- num_examples - 1)
+ num_examples_generated - 1)
# Find the Tr(sqrt(sigma sigma_w)) component of FID
sqrt_trace_component = trace_sqrt_product(sigma, sigma_w)