diff options
-rw-r--r-- | tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py | 12 |
1 files changed, 9 insertions, 3 deletions
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py index 8e69dadfb4..00a18569fc 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py @@ -18,12 +18,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow.contrib.distributions.python.ops import mvn_tril from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops __all__ = [ @@ -167,9 +170,12 @@ class MultivariateNormalFullCovariance(mvn_tril.MultivariateNormalTriL): covariance_matrix = ops.convert_to_tensor( covariance_matrix, name="covariance_matrix") if validate_args: - assert_symmetric = check_ops.assert_equal( - covariance_matrix, - array_ops.matrix_transpose(covariance_matrix), + tol = np.finfo(covariance_matrix.dtype.as_numpy_dtype).eps * 10 + diff = math_ops.abs( + covariance_matrix + - array_ops.matrix_transpose(covariance_matrix)) + assert_symmetric = check_ops.assert_less( + diff, tol + tol * math_ops.abs(covariance_matrix), message="Matrix was not symmetric.") covariance_matrix = control_flow_ops.with_dependencies( [assert_symmetric], covariance_matrix) |