diff options
author | Surya Bhupatiraju <sbhupatiraju@google.com> | 2018-03-15 13:02:23 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-15 13:11:58 -0700 |
commit | 19e4a6ab8e400c90ccd0c95c6396a0d2cc925324 (patch) | |
tree | 4cd5b78fb16b28d5532674326ef6b12ccac02383 /tensorflow/contrib/gan | |
parent | 310c32e4f793b591983ec6cf9635d1d0a86602fa (diff) |
Add mean-only FID and diagonal-covariance-only FID variants to TFGAN.
PiperOrigin-RevId: 189232299
Diffstat (limited to 'tensorflow/contrib/gan')
-rw-r--r-- | tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py | 190 | ||||
-rw-r--r-- | tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py | 60 |
2 files changed, 226 insertions, 24 deletions
diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py index fdfabd07c1..323cbe6e76 100644 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py +++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py @@ -44,11 +44,11 @@ from tensorflow.python.ops import functional_ops from tensorflow.python.ops import image_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_impl from tensorflow.python.ops import nn_ops from tensorflow.python.platform import gfile from tensorflow.python.platform import resource_loader - __all__ = [ 'get_graph_def_from_disk', 'get_graph_def_from_resource', @@ -62,10 +62,11 @@ __all__ = [ 'frechet_inception_distance', 'frechet_classifier_distance', 'frechet_classifier_distance_from_activations', + 'mean_only_frechet_classifier_distance_from_activations', + 'diagonal_only_frechet_classifier_distance_from_activations', 'INCEPTION_DEFAULT_IMAGE_SIZE', ] - INCEPTION_URL = 'http://download.tensorflow.org/models/frozen_inception_v1_2015_12_05.tar.gz' INCEPTION_FROZEN_GRAPH = 'inceptionv1_for_inception_score.pb' INCEPTION_INPUT = 'Mul:0' @@ -77,8 +78,7 @@ INCEPTION_DEFAULT_IMAGE_SIZE = 299 def _validate_images(images, image_size): images = ops.convert_to_tensor(images) images.shape.with_rank(4) - images.shape.assert_is_compatible_with( - [None, image_size, image_size, None]) + images.shape.assert_is_compatible_with([None, image_size, image_size, None]) return images @@ -109,9 +109,10 @@ def _symmetric_matrix_square_root(mat, eps=1e-10): math_ops.matmul(u, array_ops.diag(si)), v, transpose_b=True) -def preprocess_image( - images, height=INCEPTION_DEFAULT_IMAGE_SIZE, - width=INCEPTION_DEFAULT_IMAGE_SIZE, scope=None): +def preprocess_image(images, + height=INCEPTION_DEFAULT_IMAGE_SIZE, + width=INCEPTION_DEFAULT_IMAGE_SIZE, + scope=None): """Prepare a batch of images for evaluation. This is the preprocessing portion of the graph from @@ -272,8 +273,11 @@ def run_inception(images, return activations -def run_image_classifier(tensor, graph_def, input_tensor, - output_tensor, scope='RunClassifier'): +def run_image_classifier(tensor, + graph_def, + input_tensor, + output_tensor, + scope='RunClassifier'): """Runs a network from a frozen graph. Args: @@ -433,8 +437,8 @@ def trace_sqrt_product(sigma, sigma_v): sqrt_sigma = _symmetric_matrix_square_root(sigma) # This is sqrt(A sigma_v A) above - sqrt_a_sigmav_a = math_ops.matmul( - sqrt_sigma, math_ops.matmul(sigma_v, sqrt_sigma)) + sqrt_a_sigmav_a = math_ops.matmul(sqrt_sigma, + math_ops.matmul(sigma_v, sqrt_sigma)) return math_ops.trace(_symmetric_matrix_square_root(sqrt_a_sigmav_a)) @@ -452,7 +456,7 @@ def frechet_classifier_distance(real_images, Given two Gaussian distribution with means m and m_w and covariance matrices C and C_w, this function calcuates - |m - m_w|^2 + Tr(C + C_w - 2(C * C_w)^(1/2)) + |m - m_w|^2 + Tr(C + C_w - 2(C * C_w)^(1/2)) which captures how different the distributions of real images and generated images (or more accurately, their visual features) are. Note that unlike the @@ -511,10 +515,140 @@ def frechet_classifier_distance(real_images, return frechet_classifier_distance_from_activations(real_a, gen_a) -def frechet_classifier_distance_from_activations( +def mean_only_frechet_classifier_distance_from_activations( real_activations, generated_activations): """Classifier distance for evaluating a generative model from activations. + Given two Gaussian distribution with means m and m_w and covariance matrices + C and C_w, this function calcuates + + |m - m_w|^2 + + which captures how different the distributions of real images and generated + images (or more accurately, their visual features) are. Note that unlike the + Inception score, this is a true distance and utilizes information about real + world images. + + Note that when computed using sample means and sample covariance matrices, + Frechet distance is biased. It is more biased for small sample sizes. (e.g. + even if the two distributions are the same, for a small sample size, the + expected Frechet distance is large). It is important to use the same + sample size to compute frechet classifier distance when comparing two + generative models. + + In this variant, we only compute the difference between the means of the + fitted Gaussians. The computation leads to O(n) vs. O(n^2) memory usage, yet + still retains much of the same information as FID. + + Args: + real_activations: 2D array of activations of real images of size + [num_images, num_dims] to use to compute Frechet Inception distance. + generated_activations: 2D array of activations of generated images of size + [num_images, num_dims] to use to compute Frechet Inception distance. + + Returns: + The mean-only Frechet Inception distance. A floating-point scalar of the + same type as the output of the activations. + """ + real_activations.shape.assert_has_rank(2) + generated_activations.shape.assert_has_rank(2) + + activations_dtype = real_activations.dtype + if activations_dtype != dtypes.float64: + real_activations = math_ops.to_double(real_activations) + generated_activations = math_ops.to_double(generated_activations) + + # Compute means of activations. + m = math_ops.reduce_mean(real_activations, 0) + m_w = math_ops.reduce_mean(generated_activations, 0) + + # Next the distance between means. + mean = math_ops.square(linalg_ops.norm(m - m_w)) # This uses the L2 norm. + mofid = mean + if activations_dtype != dtypes.float64: + mofid = math_ops.cast(mofid, activations_dtype) + + return mofid + + +def diagonal_only_frechet_classifier_distance_from_activations( + real_activations, generated_activations): + """Classifier distance for evaluating a generative model. + + This is based on the Frechet Inception distance, but for an arbitrary + classifier. + + This technique is described in detail in https://arxiv.org/abs/1706.08500. + Given two Gaussian distribution with means m and m_w and covariance matrices + C and C_w, this function calcuates + + |m - m_w|^2 + (sigma + sigma_w - 2(sigma x sigma_w)^(1/2)) + + which captures how different the distributions of real images and generated + images (or more accurately, their visual features) are. Note that unlike the + Inception score, this is a true distance and utilizes information about real + world images. In this variant, we compute diagonal-only covariance matrices. + As a result, instead of computing an expensive matrix square root, we can do + something much simpler, and has O(n) vs O(n^2) space complexity. + + Note that when computed using sample means and sample covariance matrices, + Frechet distance is biased. It is more biased for small sample sizes. (e.g. + even if the two distributions are the same, for a small sample size, the + expected Frechet distance is large). It is important to use the same + sample size to compute frechet classifier distance when comparing two + generative models. + + Args: + real_activations: Real images to use to compute Frechet Inception distance. + generated_activations: Generated images to use to compute Frechet Inception + distance. + + Returns: + The diagonal-only Frechet Inception distance. A floating-point scalar of + the same type as the output of the activations. + + Raises: + ValueError: If the shape of the variance and mean vectors are not equal. + """ + real_activations.shape.assert_has_rank(2) + generated_activations.shape.assert_has_rank(2) + + activations_dtype = real_activations.dtype + if activations_dtype != dtypes.float64: + real_activations = math_ops.to_double(real_activations) + generated_activations = math_ops.to_double(generated_activations) + + # Compute mean and covariance matrices of activations. + m, var = nn_impl.moments(real_activations, axes=[0]) + m_w, var_w = nn_impl.moments(generated_activations, axes=[0]) + + actual_shape = var.get_shape() + expected_shape = m.get_shape() + + if actual_shape != expected_shape: + raise ValueError('shape: {} must match expected shape: {}'.format( + actual_shape, expected_shape)) + + # Compute the two components of FID. + + # First the covariance component. + # Here, note that trace(A + B) = trace(A) + trace(B) + trace = math_ops.reduce_sum( + (var + var_w) - 2.0 * math_ops.sqrt(math_ops.multiply(var, var_w))) + + # Next the distance between means. + mean = math_ops.square(linalg_ops.norm(m - m_w)) # This uses the L2 norm. + dofid = trace + mean + if activations_dtype != dtypes.float64: + dofid = math_ops.cast(dofid, activations_dtype) + + return dofid + + +def frechet_classifier_distance_from_activations(real_activations, + generated_activations): + """Classifier distance for evaluating a generative model. + This methods computes the Frechet classifier distance from activations of real images and generated images. This can be used independently of the frechet_classifier_distance() method, especially in the case of using large @@ -525,13 +659,20 @@ def frechet_classifier_distance_from_activations( Given two Gaussian distribution with means m and m_w and covariance matrices C and C_w, this function calcuates - |m - m_w|^2 + Tr(C + C_w - 2(C * C_w)^(1/2)) + |m - m_w|^2 + Tr(C + C_w - 2(C * C_w)^(1/2)) which captures how different the distributions of real images and generated images (or more accurately, their visual features) are. Note that unlike the Inception score, this is a true distance and utilizes information about real world images. + Note that when computed using sample means and sample covariance matrices, + Frechet distance is biased. It is more biased for small sample sizes. (e.g. + even if the two distributions are the same, for a small sample size, the + expected Frechet distance is large). It is important to use the same + sample size to compute frechet classifier distance when comparing two + generative models. + Args: real_activations: 2D Tensor containing activations of real data. Shape is [batch_size, activation_size]. @@ -553,36 +694,37 @@ def frechet_classifier_distance_from_activations( # Compute mean and covariance matrices of activations. m = math_ops.reduce_mean(real_activations, 0) - m_v = math_ops.reduce_mean(generated_activations, 0) + m_w = math_ops.reduce_mean(generated_activations, 0) num_examples = math_ops.to_double(array_ops.shape(real_activations)[0]) # sigma = (1 / (n - 1)) * (X - mu) (X - mu)^T real_centered = real_activations - m sigma = math_ops.matmul( - real_centered, real_centered, transpose_a=True) / (num_examples - 1) + real_centered, real_centered, transpose_a=True) / ( + num_examples - 1) - gen_centered = generated_activations - m_v - sigma_v = math_ops.matmul( - gen_centered, gen_centered, transpose_a=True) / (num_examples - 1) + gen_centered = generated_activations - m_w + sigma_w = math_ops.matmul( + gen_centered, gen_centered, transpose_a=True) / ( + num_examples - 1) - # Find the Tr(sqrt(sigma sigma_v)) component of FID - sqrt_trace_component = trace_sqrt_product(sigma, sigma_v) + # Find the Tr(sqrt(sigma sigma_w)) component of FID + sqrt_trace_component = trace_sqrt_product(sigma, sigma_w) # Compute the two components of FID. # First the covariance component. # Here, note that trace(A + B) = trace(A) + trace(B) - trace = math_ops.trace(sigma + sigma_v) - 2.0 * sqrt_trace_component + trace = math_ops.trace(sigma + sigma_w) - 2.0 * sqrt_trace_component # Next the distance between means. - mean = math_ops.square(linalg_ops.norm(m - m_v)) # This uses the L2 norm. + mean = math_ops.square(linalg_ops.norm(m - m_w)) # This uses the L2 norm. fid = trace + mean if activations_dtype != dtypes.float64: fid = math_ops.cast(fid, activations_dtype) return fid - frechet_inception_distance = functools.partial( frechet_classifier_distance, classifier_fn=functools.partial( diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py index 61dc8646dd..663e49bdca 100644 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py +++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_test.py @@ -50,6 +50,26 @@ def _expected_inception_score(logits): return np.exp(np.mean(per_example_logincscore)) +def _expected_mean_only_fid(real_imgs, gen_imgs): + m = np.mean(real_imgs, axis=0) + m_v = np.mean(gen_imgs, axis=0) + mean = np.square(m - m_v).sum() + mofid = mean + return mofid + + +def _expected_diagonal_only_fid(real_imgs, gen_imgs): + m = np.mean(real_imgs, axis=0) + m_v = np.mean(gen_imgs, axis=0) + var = np.var(real_imgs, axis=0) + var_v = np.var(gen_imgs, axis=0) + sqcc = np.sqrt(var * var_v) + mean = (np.square(m - m_v)).sum() + trace = (var + var_v - 2 * sqcc).sum() + dofid = mean + trace + return dofid + + def _expected_fid(real_imgs, gen_imgs): m = np.mean(real_imgs, axis=0) m_v = np.mean(gen_imgs, axis=0) @@ -285,6 +305,46 @@ class ClassifierMetricsTest(test.TestCase): self.assertAllClose(_expected_inception_score(logits), incscore_np) + def test_mean_only_frechet_classifier_distance_value(self): + """Test that `frechet_classifier_distance` gives the correct value.""" + np.random.seed(0) + + pool_real_a = np.float32(np.random.randn(256, 2048)) + pool_gen_a = np.float32(np.random.randn(256, 2048)) + + tf_pool_real_a = array_ops.constant(pool_real_a) + tf_pool_gen_a = array_ops.constant(pool_gen_a) + + mofid_op = classifier_metrics.mean_only_frechet_classifier_distance_from_activations( # pylint: disable=line-too-long + tf_pool_real_a, tf_pool_gen_a) + + with self.test_session() as sess: + actual_mofid = sess.run(mofid_op) + + expected_mofid = _expected_mean_only_fid(pool_real_a, pool_gen_a) + + self.assertAllClose(expected_mofid, actual_mofid, 0.0001) + + def test_diagonal_only_frechet_classifier_distance_value(self): + """Test that `frechet_classifier_distance` gives the correct value.""" + np.random.seed(0) + + pool_real_a = np.float32(np.random.randn(256, 2048)) + pool_gen_a = np.float32(np.random.randn(256, 2048)) + + tf_pool_real_a = array_ops.constant(pool_real_a) + tf_pool_gen_a = array_ops.constant(pool_gen_a) + + dofid_op = classifier_metrics.diagonal_only_frechet_classifier_distance_from_activations( # pylint: disable=line-too-long + tf_pool_real_a, tf_pool_gen_a) + + with self.test_session() as sess: + actual_dofid = sess.run(dofid_op) + + expected_dofid = _expected_diagonal_only_fid(pool_real_a, pool_gen_a) + + self.assertAllClose(expected_dofid, actual_dofid, 0.0001) + def test_frechet_classifier_distance_value(self): """Test that `frechet_classifier_distance` gives the correct value.""" np.random.seed(0) |