aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Ian Langmore <langmore@google.com>2017-12-12 16:30:09 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-12 16:33:19 -0800
commit618d5c5fad4f70456856625322db104b851a399d (patch)
treea2f2f2d28d4ea2011b35db4438de4475d7fb826a
parentc373a16f61bff835181163dc07417e3cba6f47bc (diff)
BUGFIX: MVN Full Covariance: Use dtype dependent tolerance to verify symmetric.
PiperOrigin-RevId: 178833453
-rw-r--r--tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py12
1 files changed, 9 insertions, 3 deletions
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py
index 8e69dadfb4..00a18569fc 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py
@@ -18,12 +18,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import numpy as np
+
from tensorflow.contrib.distributions.python.ops import mvn_tril
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import math_ops
__all__ = [
@@ -167,9 +170,12 @@ class MultivariateNormalFullCovariance(mvn_tril.MultivariateNormalTriL):
covariance_matrix = ops.convert_to_tensor(
covariance_matrix, name="covariance_matrix")
if validate_args:
- assert_symmetric = check_ops.assert_equal(
- covariance_matrix,
- array_ops.matrix_transpose(covariance_matrix),
+ tol = np.finfo(covariance_matrix.dtype.as_numpy_dtype).eps * 10
+ diff = math_ops.abs(
+ covariance_matrix
+ - array_ops.matrix_transpose(covariance_matrix))
+ assert_symmetric = check_ops.assert_less(
+ diff, tol + tol * math_ops.abs(covariance_matrix),
message="Matrix was not symmetric.")
covariance_matrix = control_flow_ops.with_dependencies(
[assert_symmetric], covariance_matrix)