aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py12
1 files changed, 9 insertions, 3 deletions
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py
index 8e69dadfb4..00a18569fc 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py
@@ -18,12 +18,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import numpy as np
+
from tensorflow.contrib.distributions.python.ops import mvn_tril
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import math_ops
__all__ = [
@@ -167,9 +170,12 @@ class MultivariateNormalFullCovariance(mvn_tril.MultivariateNormalTriL):
covariance_matrix = ops.convert_to_tensor(
covariance_matrix, name="covariance_matrix")
if validate_args:
- assert_symmetric = check_ops.assert_equal(
- covariance_matrix,
- array_ops.matrix_transpose(covariance_matrix),
+ tol = np.finfo(covariance_matrix.dtype.as_numpy_dtype).eps * 10
+ diff = math_ops.abs(
+ covariance_matrix
+ - array_ops.matrix_transpose(covariance_matrix))
+ assert_symmetric = check_ops.assert_less(
+ diff, tol + tol * math_ops.abs(covariance_matrix),
message="Matrix was not symmetric.")
covariance_matrix = control_flow_ops.with_dependencies(
[assert_symmetric], covariance_matrix)