diff options
author | 2017-12-12 16:30:09 -0800 | |
---|---|---|
committer | 2017-12-12 16:33:19 -0800 | |
commit | 618d5c5fad4f70456856625322db104b851a399d (patch) | |
tree | a2f2f2d28d4ea2011b35db4438de4475d7fb826a | |
parent | c373a16f61bff835181163dc07417e3cba6f47bc (diff) |
BUGFIX: MVN Full Covariance: Use dtype dependent tolerance to verify symmetric.
PiperOrigin-RevId: 178833453
-rw-r--r-- | tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py | 12 |
1 files changed, 9 insertions, 3 deletions
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py index 8e69dadfb4..00a18569fc 100644 --- a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py +++ b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py @@ -18,12 +18,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +import numpy as np + from tensorflow.contrib.distributions.python.ops import mvn_tril from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops __all__ = [ @@ -167,9 +170,12 @@ class MultivariateNormalFullCovariance(mvn_tril.MultivariateNormalTriL): covariance_matrix = ops.convert_to_tensor( covariance_matrix, name="covariance_matrix") if validate_args: - assert_symmetric = check_ops.assert_equal( - covariance_matrix, - array_ops.matrix_transpose(covariance_matrix), + tol = np.finfo(covariance_matrix.dtype.as_numpy_dtype).eps * 10 + diff = math_ops.abs( + covariance_matrix + - array_ops.matrix_transpose(covariance_matrix)) + assert_symmetric = check_ops.assert_less( + diff, tol + tol * math_ops.abs(covariance_matrix), message="Matrix was not symmetric.") covariance_matrix = control_flow_ops.with_dependencies( [assert_symmetric], covariance_matrix) |