aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/gan
diff options
context:
space:
mode:
authorGravatar Surya Bhupatiraju <sbhupatiraju@google.com>2018-03-19 18:51:06 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-03-19 18:55:52 -0700
commit48adc7ba73177f2a9331918b160bc3d0775985b8 (patch)
tree903dbf829148c8d4da04c5d460a7ed8c72ba74db /tensorflow/contrib/gan
parent28a6a8b235dafd6610e95dc05676d5b64fa5a404 (diff)
Make L2 norm computation more stable.
Avoids the potentially numerically instable square root in the linalg_ops.norm() function because we 'undo' that operation with a math_ops.square() operation anyway. PiperOrigin-RevId: 189677716
Diffstat (limited to 'tensorflow/contrib/gan')
-rw-r--r--tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py9
1 files changed, 6 insertions, 3 deletions
diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py
index 323cbe6e76..7e86d10b64 100644
--- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py
+++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py
@@ -563,7 +563,8 @@ def mean_only_frechet_classifier_distance_from_activations(
m_w = math_ops.reduce_mean(generated_activations, 0)
# Next the distance between means.
- mean = math_ops.square(linalg_ops.norm(m - m_w)) # This uses the L2 norm.
+ mean = math_ops.reduce_sum(
+ math_ops.squared_difference(m, m_w)) # Equivalent to L2 but more stable.
mofid = mean
if activations_dtype != dtypes.float64:
mofid = math_ops.cast(mofid, activations_dtype)
@@ -637,7 +638,8 @@ def diagonal_only_frechet_classifier_distance_from_activations(
(var + var_w) - 2.0 * math_ops.sqrt(math_ops.multiply(var, var_w)))
# Next the distance between means.
- mean = math_ops.square(linalg_ops.norm(m - m_w)) # This uses the L2 norm.
+ mean = math_ops.reduce_sum(
+ math_ops.squared_difference(m, m_w)) # Equivalent to L2 but more stable.
dofid = trace + mean
if activations_dtype != dtypes.float64:
dofid = math_ops.cast(dofid, activations_dtype)
@@ -718,7 +720,8 @@ def frechet_classifier_distance_from_activations(real_activations,
trace = math_ops.trace(sigma + sigma_w) - 2.0 * sqrt_trace_component
# Next the distance between means.
- mean = math_ops.square(linalg_ops.norm(m - m_w)) # This uses the L2 norm.
+ mean = math_ops.reduce_sum(
+ math_ops.squared_difference(m, m_w)) # Equivalent to L2 but more stable.
fid = trace + mean
if activations_dtype != dtypes.float64:
fid = math_ops.cast(fid, activations_dtype)