diff options
author | Surya Bhupatiraju <sbhupatiraju@google.com> | 2018-03-19 18:51:06 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-03-19 18:55:52 -0700 |
commit | 48adc7ba73177f2a9331918b160bc3d0775985b8 (patch) | |
tree | 903dbf829148c8d4da04c5d460a7ed8c72ba74db /tensorflow/contrib/gan | |
parent | 28a6a8b235dafd6610e95dc05676d5b64fa5a404 (diff) |
Make L2 norm computation more stable.
Avoids the potentially numerically instable square root in the linalg_ops.norm() function because we 'undo' that operation with a math_ops.square() operation anyway.
PiperOrigin-RevId: 189677716
Diffstat (limited to 'tensorflow/contrib/gan')
-rw-r--r-- | tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py | 9 |
1 files changed, 6 insertions, 3 deletions
diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py index 323cbe6e76..7e86d10b64 100644 --- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py +++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py @@ -563,7 +563,8 @@ def mean_only_frechet_classifier_distance_from_activations( m_w = math_ops.reduce_mean(generated_activations, 0) # Next the distance between means. - mean = math_ops.square(linalg_ops.norm(m - m_w)) # This uses the L2 norm. + mean = math_ops.reduce_sum( + math_ops.squared_difference(m, m_w)) # Equivalent to L2 but more stable. mofid = mean if activations_dtype != dtypes.float64: mofid = math_ops.cast(mofid, activations_dtype) @@ -637,7 +638,8 @@ def diagonal_only_frechet_classifier_distance_from_activations( (var + var_w) - 2.0 * math_ops.sqrt(math_ops.multiply(var, var_w))) # Next the distance between means. - mean = math_ops.square(linalg_ops.norm(m - m_w)) # This uses the L2 norm. + mean = math_ops.reduce_sum( + math_ops.squared_difference(m, m_w)) # Equivalent to L2 but more stable. dofid = trace + mean if activations_dtype != dtypes.float64: dofid = math_ops.cast(dofid, activations_dtype) @@ -718,7 +720,8 @@ def frechet_classifier_distance_from_activations(real_activations, trace = math_ops.trace(sigma + sigma_w) - 2.0 * sqrt_trace_component # Next the distance between means. - mean = math_ops.square(linalg_ops.norm(m - m_w)) # This uses the L2 norm. + mean = math_ops.reduce_sum( + math_ops.squared_difference(m, m_w)) # Equivalent to L2 but more stable. fid = trace + mean if activations_dtype != dtypes.float64: fid = math_ops.cast(fid, activations_dtype) |