aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/gan
diff options
context:
space:
mode:
authorGravatar Surya Bhupatiraju <sbhupatiraju@google.com>2018-03-15 13:02:23 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-15 13:11:58 -0700
commit19e4a6ab8e400c90ccd0c95c6396a0d2cc925324 (patch)
tree4cd5b78fb16b28d5532674326ef6b12ccac02383 /tensorflow/contrib/gan
parent310c32e4f793b591983ec6cf9635d1d0a86602fa (diff)
Add mean-only FID and diagonal-covariance-only FID variants to TFGAN.
PiperOrigin-RevId: 189232299
Diffstat (limited to 'tensorflow/contrib/gan')
-rw-r--r--tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py190
-rw-r--r--tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py60
2 files changed, 226 insertions, 24 deletions
diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py
index fdfabd07c1..323cbe6e76 100644
--- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py
+++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py
@@ -44,11 +44,11 @@ from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import image_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_impl
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import gfile
from tensorflow.python.platform import resource_loader
-
__all__ = [
'get_graph_def_from_disk',
'get_graph_def_from_resource',
@@ -62,10 +62,11 @@ __all__ = [
'frechet_inception_distance',
'frechet_classifier_distance',
'frechet_classifier_distance_from_activations',
+ 'mean_only_frechet_classifier_distance_from_activations',
+ 'diagonal_only_frechet_classifier_distance_from_activations',
'INCEPTION_DEFAULT_IMAGE_SIZE',
]
-
INCEPTION_URL = 'http://download.tensorflow.org/models/frozen_inception_v1_2015_12_05.tar.gz'
INCEPTION_FROZEN_GRAPH = 'inceptionv1_for_inception_score.pb'
INCEPTION_INPUT = 'Mul:0'
@@ -77,8 +78,7 @@ INCEPTION_DEFAULT_IMAGE_SIZE = 299
def _validate_images(images, image_size):
images = ops.convert_to_tensor(images)
images.shape.with_rank(4)
- images.shape.assert_is_compatible_with(
- [None, image_size, image_size, None])
+ images.shape.assert_is_compatible_with([None, image_size, image_size, None])
return images
@@ -109,9 +109,10 @@ def _symmetric_matrix_square_root(mat, eps=1e-10):
math_ops.matmul(u, array_ops.diag(si)), v, transpose_b=True)
-def preprocess_image(
- images, height=INCEPTION_DEFAULT_IMAGE_SIZE,
- width=INCEPTION_DEFAULT_IMAGE_SIZE, scope=None):
+def preprocess_image(images,
+ height=INCEPTION_DEFAULT_IMAGE_SIZE,
+ width=INCEPTION_DEFAULT_IMAGE_SIZE,
+ scope=None):
"""Prepare a batch of images for evaluation.
This is the preprocessing portion of the graph from
@@ -272,8 +273,11 @@ def run_inception(images,
return activations
-def run_image_classifier(tensor, graph_def, input_tensor,
- output_tensor, scope='RunClassifier'):
+def run_image_classifier(tensor,
+ graph_def,
+ input_tensor,
+ output_tensor,
+ scope='RunClassifier'):
"""Runs a network from a frozen graph.
Args:
@@ -433,8 +437,8 @@ def trace_sqrt_product(sigma, sigma_v):
sqrt_sigma = _symmetric_matrix_square_root(sigma)
# This is sqrt(A sigma_v A) above
- sqrt_a_sigmav_a = math_ops.matmul(
- sqrt_sigma, math_ops.matmul(sigma_v, sqrt_sigma))
+ sqrt_a_sigmav_a = math_ops.matmul(sqrt_sigma,
+ math_ops.matmul(sigma_v, sqrt_sigma))
return math_ops.trace(_symmetric_matrix_square_root(sqrt_a_sigmav_a))
@@ -452,7 +456,7 @@ def frechet_classifier_distance(real_images,
Given two Gaussian distribution with means m and m_w and covariance matrices
C and C_w, this function calcuates
- |m - m_w|^2 + Tr(C + C_w - 2(C * C_w)^(1/2))
+ |m - m_w|^2 + Tr(C + C_w - 2(C * C_w)^(1/2))
which captures how different the distributions of real images and generated
images (or more accurately, their visual features) are. Note that unlike the
@@ -511,10 +515,140 @@ def frechet_classifier_distance(real_images,
return frechet_classifier_distance_from_activations(real_a, gen_a)
-def frechet_classifier_distance_from_activations(
+def mean_only_frechet_classifier_distance_from_activations(
real_activations, generated_activations):
"""Classifier distance for evaluating a generative model from activations.
+ Given two Gaussian distribution with means m and m_w and covariance matrices
+ C and C_w, this function calcuates
+
+ |m - m_w|^2
+
+ which captures how different the distributions of real images and generated
+ images (or more accurately, their visual features) are. Note that unlike the
+ Inception score, this is a true distance and utilizes information about real
+ world images.
+
+ Note that when computed using sample means and sample covariance matrices,
+ Frechet distance is biased. It is more biased for small sample sizes. (e.g.
+ even if the two distributions are the same, for a small sample size, the
+ expected Frechet distance is large). It is important to use the same
+ sample size to compute frechet classifier distance when comparing two
+ generative models.
+
+ In this variant, we only compute the difference between the means of the
+ fitted Gaussians. The computation leads to O(n) vs. O(n^2) memory usage, yet
+ still retains much of the same information as FID.
+
+ Args:
+ real_activations: 2D array of activations of real images of size
+ [num_images, num_dims] to use to compute Frechet Inception distance.
+ generated_activations: 2D array of activations of generated images of size
+ [num_images, num_dims] to use to compute Frechet Inception distance.
+
+ Returns:
+ The mean-only Frechet Inception distance. A floating-point scalar of the
+ same type as the output of the activations.
+ """
+ real_activations.shape.assert_has_rank(2)
+ generated_activations.shape.assert_has_rank(2)
+
+ activations_dtype = real_activations.dtype
+ if activations_dtype != dtypes.float64:
+ real_activations = math_ops.to_double(real_activations)
+ generated_activations = math_ops.to_double(generated_activations)
+
+ # Compute means of activations.
+ m = math_ops.reduce_mean(real_activations, 0)
+ m_w = math_ops.reduce_mean(generated_activations, 0)
+
+ # Next the distance between means.
+ mean = math_ops.square(linalg_ops.norm(m - m_w)) # This uses the L2 norm.
+ mofid = mean
+ if activations_dtype != dtypes.float64:
+ mofid = math_ops.cast(mofid, activations_dtype)
+
+ return mofid
+
+
+def diagonal_only_frechet_classifier_distance_from_activations(
+ real_activations, generated_activations):
+ """Classifier distance for evaluating a generative model.
+
+ This is based on the Frechet Inception distance, but for an arbitrary
+ classifier.
+
+ This technique is described in detail in https://arxiv.org/abs/1706.08500.
+ Given two Gaussian distribution with means m and m_w and covariance matrices
+ C and C_w, this function calcuates
+
+ |m - m_w|^2 + (sigma + sigma_w - 2(sigma x sigma_w)^(1/2))
+
+ which captures how different the distributions of real images and generated
+ images (or more accurately, their visual features) are. Note that unlike the
+ Inception score, this is a true distance and utilizes information about real
+ world images. In this variant, we compute diagonal-only covariance matrices.
+ As a result, instead of computing an expensive matrix square root, we can do
+ something much simpler, and has O(n) vs O(n^2) space complexity.
+
+ Note that when computed using sample means and sample covariance matrices,
+ Frechet distance is biased. It is more biased for small sample sizes. (e.g.
+ even if the two distributions are the same, for a small sample size, the
+ expected Frechet distance is large). It is important to use the same
+ sample size to compute frechet classifier distance when comparing two
+ generative models.
+
+ Args:
+ real_activations: Real images to use to compute Frechet Inception distance.
+ generated_activations: Generated images to use to compute Frechet Inception
+ distance.
+
+ Returns:
+ The diagonal-only Frechet Inception distance. A floating-point scalar of
+ the same type as the output of the activations.
+
+ Raises:
+ ValueError: If the shape of the variance and mean vectors are not equal.
+ """
+ real_activations.shape.assert_has_rank(2)
+ generated_activations.shape.assert_has_rank(2)
+
+ activations_dtype = real_activations.dtype
+ if activations_dtype != dtypes.float64:
+ real_activations = math_ops.to_double(real_activations)
+ generated_activations = math_ops.to_double(generated_activations)
+
+ # Compute mean and covariance matrices of activations.
+ m, var = nn_impl.moments(real_activations, axes=[0])
+ m_w, var_w = nn_impl.moments(generated_activations, axes=[0])
+
+ actual_shape = var.get_shape()
+ expected_shape = m.get_shape()
+
+ if actual_shape != expected_shape:
+ raise ValueError('shape: {} must match expected shape: {}'.format(
+ actual_shape, expected_shape))
+
+ # Compute the two components of FID.
+
+ # First the covariance component.
+ # Here, note that trace(A + B) = trace(A) + trace(B)
+ trace = math_ops.reduce_sum(
+ (var + var_w) - 2.0 * math_ops.sqrt(math_ops.multiply(var, var_w)))
+
+ # Next the distance between means.
+ mean = math_ops.square(linalg_ops.norm(m - m_w)) # This uses the L2 norm.
+ dofid = trace + mean
+ if activations_dtype != dtypes.float64:
+ dofid = math_ops.cast(dofid, activations_dtype)
+
+ return dofid
+
+
+def frechet_classifier_distance_from_activations(real_activations,
+ generated_activations):
+ """Classifier distance for evaluating a generative model.
+
This methods computes the Frechet classifier distance from activations of
real images and generated images. This can be used independently of the
frechet_classifier_distance() method, especially in the case of using large
@@ -525,13 +659,20 @@ def frechet_classifier_distance_from_activations(
Given two Gaussian distribution with means m and m_w and covariance matrices
C and C_w, this function calcuates
- |m - m_w|^2 + Tr(C + C_w - 2(C * C_w)^(1/2))
+ |m - m_w|^2 + Tr(C + C_w - 2(C * C_w)^(1/2))
which captures how different the distributions of real images and generated
images (or more accurately, their visual features) are. Note that unlike the
Inception score, this is a true distance and utilizes information about real
world images.
+ Note that when computed using sample means and sample covariance matrices,
+ Frechet distance is biased. It is more biased for small sample sizes. (e.g.
+ even if the two distributions are the same, for a small sample size, the
+ expected Frechet distance is large). It is important to use the same
+ sample size to compute frechet classifier distance when comparing two
+ generative models.
+
Args:
real_activations: 2D Tensor containing activations of real data. Shape is
[batch_size, activation_size].
@@ -553,36 +694,37 @@ def frechet_classifier_distance_from_activations(
# Compute mean and covariance matrices of activations.
m = math_ops.reduce_mean(real_activations, 0)
- m_v = math_ops.reduce_mean(generated_activations, 0)
+ m_w = math_ops.reduce_mean(generated_activations, 0)
num_examples = math_ops.to_double(array_ops.shape(real_activations)[0])
# sigma = (1 / (n - 1)) * (X - mu) (X - mu)^T
real_centered = real_activations - m
sigma = math_ops.matmul(
- real_centered, real_centered, transpose_a=True) / (num_examples - 1)
+ real_centered, real_centered, transpose_a=True) / (
+ num_examples - 1)
- gen_centered = generated_activations - m_v
- sigma_v = math_ops.matmul(
- gen_centered, gen_centered, transpose_a=True) / (num_examples - 1)
+ gen_centered = generated_activations - m_w
+ sigma_w = math_ops.matmul(
+ gen_centered, gen_centered, transpose_a=True) / (
+ num_examples - 1)
- # Find the Tr(sqrt(sigma sigma_v)) component of FID
- sqrt_trace_component = trace_sqrt_product(sigma, sigma_v)
+ # Find the Tr(sqrt(sigma sigma_w)) component of FID
+ sqrt_trace_component = trace_sqrt_product(sigma, sigma_w)
# Compute the two components of FID.
# First the covariance component.
# Here, note that trace(A + B) = trace(A) + trace(B)
- trace = math_ops.trace(sigma + sigma_v) - 2.0 * sqrt_trace_component
+ trace = math_ops.trace(sigma + sigma_w) - 2.0 * sqrt_trace_component
# Next the distance between means.
- mean = math_ops.square(linalg_ops.norm(m - m_v)) # This uses the L2 norm.
+ mean = math_ops.square(linalg_ops.norm(m - m_w)) # This uses the L2 norm.
fid = trace + mean
if activations_dtype != dtypes.float64:
fid = math_ops.cast(fid, activations_dtype)
return fid
-
frechet_inception_distance = functools.partial(
frechet_classifier_distance,
classifier_fn=functools.partial(
diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py
index 61dc8646dd..663e49bdca 100644
--- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py
+++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py
@@ -50,6 +50,26 @@ def _expected_inception_score(logits):
return np.exp(np.mean(per_example_logincscore))
+def _expected_mean_only_fid(real_imgs, gen_imgs):
+ m = np.mean(real_imgs, axis=0)
+ m_v = np.mean(gen_imgs, axis=0)
+ mean = np.square(m - m_v).sum()
+ mofid = mean
+ return mofid
+
+
+def _expected_diagonal_only_fid(real_imgs, gen_imgs):
+ m = np.mean(real_imgs, axis=0)
+ m_v = np.mean(gen_imgs, axis=0)
+ var = np.var(real_imgs, axis=0)
+ var_v = np.var(gen_imgs, axis=0)
+ sqcc = np.sqrt(var * var_v)
+ mean = (np.square(m - m_v)).sum()
+ trace = (var + var_v - 2 * sqcc).sum()
+ dofid = mean + trace
+ return dofid
+
+
def _expected_fid(real_imgs, gen_imgs):
m = np.mean(real_imgs, axis=0)
m_v = np.mean(gen_imgs, axis=0)
@@ -285,6 +305,46 @@ class ClassifierMetricsTest(test.TestCase):
self.assertAllClose(_expected_inception_score(logits), incscore_np)
+ def test_mean_only_frechet_classifier_distance_value(self):
+ """Test that `frechet_classifier_distance` gives the correct value."""
+ np.random.seed(0)
+
+ pool_real_a = np.float32(np.random.randn(256, 2048))
+ pool_gen_a = np.float32(np.random.randn(256, 2048))
+
+ tf_pool_real_a = array_ops.constant(pool_real_a)
+ tf_pool_gen_a = array_ops.constant(pool_gen_a)
+
+ mofid_op = classifier_metrics.mean_only_frechet_classifier_distance_from_activations( # pylint: disable=line-too-long
+ tf_pool_real_a, tf_pool_gen_a)
+
+ with self.test_session() as sess:
+ actual_mofid = sess.run(mofid_op)
+
+ expected_mofid = _expected_mean_only_fid(pool_real_a, pool_gen_a)
+
+ self.assertAllClose(expected_mofid, actual_mofid, 0.0001)
+
+ def test_diagonal_only_frechet_classifier_distance_value(self):
+ """Test that `frechet_classifier_distance` gives the correct value."""
+ np.random.seed(0)
+
+ pool_real_a = np.float32(np.random.randn(256, 2048))
+ pool_gen_a = np.float32(np.random.randn(256, 2048))
+
+ tf_pool_real_a = array_ops.constant(pool_real_a)
+ tf_pool_gen_a = array_ops.constant(pool_gen_a)
+
+ dofid_op = classifier_metrics.diagonal_only_frechet_classifier_distance_from_activations( # pylint: disable=line-too-long
+ tf_pool_real_a, tf_pool_gen_a)
+
+ with self.test_session() as sess:
+ actual_dofid = sess.run(dofid_op)
+
+ expected_dofid = _expected_diagonal_only_fid(pool_real_a, pool_gen_a)
+
+ self.assertAllClose(expected_dofid, actual_dofid, 0.0001)
+
def test_frechet_classifier_distance_value(self):
"""Test that `frechet_classifier_distance` gives the correct value."""
np.random.seed(0)