author     Ian Langmore <langmore@google.com>               2017-05-15 09:43:42 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2017-05-15 09:46:56 -0700
commit     18727ef581297437e20d6df4a08b60e8b021f284 (patch)
tree       da0933d097e915aa3f0ecc803c03d39e0ff186cd
parent     ec3035c70da74b16a1268a28a11e808840e99588 (diff)
.assert_positive_definite and .assert_non_singular default implementations
added to the LinearOperator base class. Previously I avoided these because they
are inefficient; however, the desire for a consistent API overrides that concern.

PiperOrigin-RevId: 156064641
-rw-r--r--  tensorflow/contrib/linalg/python/kernel_tests/linear_operator_full_matrix_test.py |  69
-rw-r--r--  tensorflow/contrib/linalg/python/ops/linear_operator.py                           | 105
2 files changed, 166 insertions(+), 8 deletions(-)
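Before the diff itself, an orientation sketch (not part of the commit): with these defaults in place, every `LinearOperator` subclass exposes the same assertion methods. The snippet below assumes a TF 1.x environment in which `tf.contrib.linalg` is importable, and simply mirrors the patterns used in the tests.

```
import numpy as np
import tensorflow as tf  # assumes TF 1.x, where tf.contrib.linalg exists

linalg = tf.contrib.linalg

# A symmetric positive definite matrix, so all three assertions should pass.
matrix = np.array([[2., 1.], [1., 2.]], dtype=np.float32)
operator = linalg.LinearOperatorFullMatrix(
    matrix, is_self_adjoint=True, is_positive_definite=True)

with tf.Session() as sess:
  # Each method returns an Op that raises InvalidArgumentError at run time
  # if the asserted property does not hold.
  sess.run(operator.assert_self_adjoint())
  sess.run(operator.assert_positive_definite())
  sess.run(operator.assert_non_singular())
```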
diff --git a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_full_matrix_test.py b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_full_matrix_test.py
index 12c299683a..528bc3ed12 100644
--- a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_full_matrix_test.py
+++ b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_full_matrix_test.py
@@ -17,12 +17,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import numpy as np
+
from tensorflow.contrib import linalg as linalg_lib
from tensorflow.contrib.linalg.python.ops import linear_operator_test_util
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
linalg = linalg_lib
@@ -72,6 +75,44 @@ class SquareLinearOperatorFullMatrixTest(
# Auto-detected.
self.assertTrue(operator.is_square)
+ def test_assert_non_singular_raises_if_cond_too_big_but_finite(self):
+ with self.test_session():
+ tril = linear_operator_test_util.random_tril_matrix(
+ shape=(50, 50), dtype=np.float32)
+ diag = np.logspace(-2, 2, 50).astype(np.float32)
+ tril = array_ops.matrix_set_diag(tril, diag)
+ matrix = math_ops.matmul(tril, tril, transpose_b=True).eval()
+ operator = linalg.LinearOperatorFullMatrix(matrix)
+ with self.assertRaisesOpError("Singular matrix"):
+ # Ensure that the condition number is finite, just very large.
+ cond = np.linalg.cond(matrix)
+ self.assertTrue(np.isfinite(cond))
+ self.assertGreater(cond, 1e12)
+ operator.assert_non_singular().run()
+
+ def test_assert_non_singular_raises_if_cond_infinite(self):
+ with self.test_session():
+ matrix = [[1., 1.], [1., 1.]]
+ # We don't pass the is_self_adjoint hint here, which means we take the
+ # generic code path.
+ operator = linalg.LinearOperatorFullMatrix(matrix)
+ with self.assertRaisesOpError("Singular matrix"):
+ operator.assert_non_singular().run()
+
+ def test_assert_self_adjoint(self):
+ matrix = [[0., 1.], [0., 1.]]
+ operator = linalg.LinearOperatorFullMatrix(matrix)
+ with self.test_session():
+ with self.assertRaisesOpError("not equal to its adjoint"):
+ operator.assert_self_adjoint().run()
+
+ def test_assert_positive_definite(self):
+ matrix = [[1., 1.], [1., 1.]]
+ operator = linalg.LinearOperatorFullMatrix(matrix, is_self_adjoint=True)
+ with self.test_session():
+ with self.assertRaisesOpError("Cholesky decomposition was not success"):
+ operator.assert_positive_definite().run()
+
class SquareLinearOperatorFullMatrixSymmetricPositiveDefiniteTest(
linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
@@ -136,6 +177,34 @@ class SquareLinearOperatorFullMatrixSymmetricPositiveDefiniteTest(
self.assertTrue(operator._can_use_cholesky)
self.assertTrue(operator.is_square)
+ def test_assert_non_singular(self):
+ matrix = [[1., 1.], [1., 1.]]
+ operator = linalg.LinearOperatorFullMatrix(
+ matrix, is_self_adjoint=True, is_positive_definite=True)
+ with self.test_session():
+ # Cholesky decomposition may fail, so the error is not specific to
+ # non-singularity.
+ with self.assertRaisesOpError(""):
+ operator.assert_non_singular().run()
+
+ def test_assert_self_adjoint(self):
+ matrix = [[0., 1.], [0., 1.]]
+ operator = linalg.LinearOperatorFullMatrix(
+ matrix, is_self_adjoint=True, is_positive_definite=True)
+ with self.test_session():
+ with self.assertRaisesOpError("not equal to its adjoint"):
+ operator.assert_self_adjoint().run()
+
+ def test_assert_positive_definite(self):
+ matrix = [[1., 1.], [1., 1.]]
+ operator = linalg.LinearOperatorFullMatrix(
+ matrix, is_self_adjoint=True, is_positive_definite=True)
+ with self.test_session():
+ # Cholesky decomposition may fail, so the error is not specific to
+ # positive-definiteness.
+ with self.assertRaisesOpError(""):
+ operator.assert_positive_definite().run()
+
class NonSquareLinearOperatorFullMatrixTest(
linear_operator_test_util.NonSquareLinearOperatorDerivedClassTest):
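The `test_assert_positive_definite` cases above hinge on the fact that, for a self-adjoint matrix, a Cholesky factorization that succeeds with a strictly positive diagonal certifies positive definiteness; `[[1., 1.], [1., 1.]]` is singular, so the factorization fails. A NumPy stand-in for that check (illustrative only; `is_positive_definite_via_cholesky` is not a name from this commit):

```
import numpy as np

def is_positive_definite_via_cholesky(matrix):
  """NumPy stand-in for the Cholesky check used on self-adjoint operators."""
  try:
    chol = np.linalg.cholesky(matrix)
  except np.linalg.LinAlgError:
    # Mirrors the "Cholesky decomposition was not success..." failure that
    # the tests above expect from the TF op.
    return False
  return bool(np.all(np.diag(chol) > 0))

print(is_positive_definite_via_cholesky(np.array([[2., 1.], [1., 2.]])))  # True
print(is_positive_definite_via_cholesky(np.array([[1., 1.], [1., 1.]])))  # False
```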
diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator.py b/tensorflow/contrib/linalg/python/ops/linear_operator.py
index 8d0a1d7de2..605ab1511d 100644
--- a/tensorflow/contrib/linalg/python/ops/linear_operator.py
+++ b/tensorflow/contrib/linalg/python/ops/linear_operator.py
@@ -21,12 +21,16 @@ from __future__ import print_function
import abc
import contextlib
+import numpy as np
+
from tensorflow.contrib import framework as contrib_framework
from tensorflow.contrib.linalg.python.ops import linear_operator_util
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import tf_logging as logging
__all__ = ["LinearOperator"]
@@ -456,14 +460,70 @@ class LinearOperator(object):
return self._cached_range_dimension_tensor
def _assert_non_singular(self):
+ """Private default implementation of _assert_non_singular."""
+ logging.warn(
+ "Using (possibly slow) default implementation of assert_non_singular."
+ " Requires conversion to a dense matrix and O(N^3) operations.")
+ if self._can_use_cholesky():
+ return self.assert_positive_definite()
+ else:
+ singular_values = linalg_ops.svd(
+ self._get_cached_dense_matrix(), compute_uv=False)
+ # TODO(langmore) Add .eig and .cond as methods.
+ cond = (math_ops.reduce_max(singular_values, axis=-1) /
+ math_ops.reduce_min(singular_values, axis=-1))
+ return check_ops.assert_less(
+ cond,
+ self._max_condition_number_to_be_non_singular(),
+ message="Singular matrix up to precision epsilon.")
raise NotImplementedError("assert_non_singular is not implemented.")
+ def _max_condition_number_to_be_non_singular(self):
+ """Return the maximum condition number that we consider nonsingular."""
+ with ops.name_scope("max_nonsingular_condition_number"):
+ dtype_eps = np.finfo(self.dtype.as_numpy_dtype).eps
+ eps = math_ops.cast(
+ math_ops.reduce_max([
+ 100.,
+ math_ops.cast(self.range_dimension_tensor(), self.dtype),
+ math_ops.cast(self.domain_dimension_tensor(), self.dtype)
+ ]), self.dtype) * dtype_eps
+ return 1. / eps
+
def assert_non_singular(self, name="assert_non_singular"):
- """Returns an `Op` that asserts this operator is non singular."""
+ """Returns an `Op` that asserts this operator is non singular.
+
+ This operator is considered non-singular if
+
+ ```
+ ConditionNumber < 1 / (max{100, range_dimension, domain_dimension} * eps),
+ eps := np.finfo(self.dtype.as_numpy_dtype).eps
+ ```
+
+ Args:
+ name: A string name to prepend to created ops.
+
+ Returns:
+ An `Assert` `Op` that, when run, will raise an `InvalidArgumentError` if
+ the operator is singular.
+ """
with self._name_scope(name):
return self._assert_non_singular()
def _assert_positive_definite(self):
+ """Default implementation of _assert_positive_definite."""
+ logging.warn(
+ "Using (possibly slow) default implementation of "
+ "assert_positive_definite."
+ " Requires conversion to a dense matrix and O(N^3) operations.")
+ # If the operator is self-adjoint, then checking that the Cholesky
+ # decomposition succeeds and has a positive diagonal is necessary and
+ # sufficient.
+ if self.is_self_adjoint:
+ return check_ops.assert_positive(
+ array_ops.matrix_diag_part(self._get_cached_chol()),
+ message="Matrix was not positive definite.")
+ # We have no generic check for positive definite.
raise NotImplementedError("assert_positive_definite is not implemented.")
def assert_positive_definite(self, name="assert_positive_definite"):
@@ -477,16 +537,35 @@ class LinearOperator(object):
name: A name to give this `Op`.
Returns:
- An `Op` that asserts this operator is positive definite.
+ An `Assert` `Op` that, when run, will raise an `InvalidArgumentError` if
+ the operator is not positive definite.
"""
with self._name_scope(name):
return self._assert_positive_definite()
def _assert_self_adjoint(self):
- raise NotImplementedError("assert_self_adjoint is not implemented.")
+ dense = self._get_cached_dense_matrix()
+ logging.warn(
+ "Using (possibly slow) default implementation of assert_self_adjoint."
+ " Requires conversion to a dense matrix.")
+ return check_ops.assert_equal(
+ dense,
+ linear_operator_util.matrix_adjoint(dense),
+ message="Matrix was not equal to its adjoint.")
def assert_self_adjoint(self, name="assert_self_adjoint"):
- """Returns an `Op` that asserts this operator is self-adjoint."""
+ """Returns an `Op` that asserts this operator is self-adjoint.
+
+ Here we check that this operator is *exactly* equal to its Hermitian
+ transpose.
+
+ Args:
+ name: A string name to prepend to created ops.
+
+ Returns:
+ An `Assert` `Op` that, when run, will raise an `InvalidArgumentError` if
+ the operator is not self-adjoint.
+ """
with self._name_scope(name):
return self._assert_self_adjoint()
@@ -526,6 +605,9 @@ class LinearOperator(object):
return self._apply(x, adjoint=adjoint, adjoint_arg=adjoint_arg)
def _determinant(self):
+ logging.warn(
+ "Using (possibly slow) default implementation of determinant."
+ " Requires conversion to a dense matrix and O(N^3) operations.")
if self._can_use_cholesky():
return math_ops.exp(self.log_abs_determinant())
return linalg_ops.matrix_determinant(self._matrix)
@@ -550,6 +632,9 @@ class LinearOperator(object):
return self._determinant()
def _log_abs_determinant(self):
+ logging.warn(
+ "Using (possibly slow) default implementation of determinant."
+ " Requires conversion to a dense matrix and O(N^3) operations.")
if self._can_use_cholesky():
diag = array_ops.matrix_diag_part(self._get_cached_chol())
return 2 * math_ops.reduce_sum(math_ops.log(diag), reduction_indices=[-1])
@@ -576,9 +661,13 @@ class LinearOperator(object):
return self._log_abs_determinant()
def _solve(self, rhs, adjoint=False, adjoint_arg=False):
+ """Default implementation of _solve."""
if self.is_square is False:
raise NotImplementedError(
"Solve is not yet implemented for non-square operators.")
+ logging.warn(
+ "Using (possibly slow) default implementation of solve."
+ " Requires conversion to a dense matrix and O(N^3) operations.")
rhs = linear_operator_util.matrix_adjoint(rhs) if adjoint_arg else rhs
if self._can_use_cholesky():
return linalg_ops.cholesky_solve(self._get_cached_chol(), rhs)
@@ -643,6 +732,8 @@ class LinearOperator(object):
def _to_dense(self):
"""Generic and often inefficient implementation. Override often."""
+ logging.warn("Using (possibly slow) default implementation of to_dense."
+ " Converts by self.matmul(identity).")
if self.batch_shape.is_fully_defined():
batch_shape = self.batch_shape
else:
@@ -663,7 +754,7 @@ class LinearOperator(object):
def _diag_part(self):
"""Generic and often inefficient implementation. Override often."""
- return array_ops.matrix_diag_part(self.to_dense())
+ return array_ops.matrix_diag_part(self._get_cached_dense_matrix())
def diag_part(self, name="diag_part"):
"""Efficiently get the [batch] diagonal part of this operator.
@@ -695,7 +786,7 @@ class LinearOperator(object):
def _add_to_tensor(self, x):
# Override if a more efficient implementation is available.
- return self.to_dense() + x
+ return self._get_cached_dense_matrix() + x
def add_to_tensor(self, x, name="add_to_tensor"):
"""Add matrix represented by this operator to `x`. Equivalent to `A + x`.
@@ -723,8 +814,6 @@ class LinearOperator(object):
return self._cached_dense_matrix
def _get_cached_chol(self):
- if not self._can_use_cholesky():
- return None
if not hasattr(self, "_cached_chol"):
self._cached_chol = linalg_ops.cholesky(self._get_cached_dense_matrix())
return self._cached_chol
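Finally, a NumPy restatement of the condition-number criterion behind the SVD branch of the new `_assert_non_singular` default: the operator is treated as singular once its condition number reaches `1 / (max{100, M, N} * eps)`. This is a sketch against a plain dense matrix; `looks_non_singular` is an illustrative name, not part of the commit.

```
import numpy as np

def max_condition_number_to_be_non_singular(matrix):
  # 1 / (max{100, M, N} * eps), as in the new LinearOperator default.
  m, n = matrix.shape[-2:]
  eps = np.finfo(matrix.dtype).eps
  return 1. / (max(100., m, n) * eps)

def looks_non_singular(matrix):
  # SVD branch: condition number = largest / smallest singular value.
  s = np.linalg.svd(matrix, compute_uv=False)
  cond = s.max(axis=-1) / s.min(axis=-1)
  return cond < max_condition_number_to_be_non_singular(matrix)

print(looks_non_singular(np.eye(3, dtype=np.float32)))  # True: cond == 1
badly_conditioned = np.diag(np.logspace(-6., 6., 3)).astype(np.float32)
print(looks_non_singular(badly_conditioned))            # False: cond ~ 1e12
```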