aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Ian Langmore <langmore@google.com>2016-12-01 13:35:52 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-12-01 14:00:57 -0800
commit00954adfe5c7334de56a6cae1f0ad2f83988fcf3 (patch)
tree2a7fa44a1b6a52837628bb5940b69a918eb83ac9
parentb52307ca3f507508860ad2e5188cbdf87f9429e2 (diff)
LinearOperatorTriL, a lower triangular linear operator.
linear_operator_util.py added as well Use linear_operator_util in LinearOperatorDiag, and a couple bug-fixes. Change: 140770818
-rw-r--r--tensorflow/contrib/linalg/BUILD26
-rw-r--r--tensorflow/contrib/linalg/__init__.py2
-rw-r--r--tensorflow/contrib/linalg/python/kernel_tests/linear_operator_diag_test.py15
-rw-r--r--tensorflow/contrib/linalg/python/kernel_tests/linear_operator_tril_test.py104
-rw-r--r--tensorflow/contrib/linalg/python/kernel_tests/linear_operator_util_test.py90
-rw-r--r--tensorflow/contrib/linalg/python/ops/linear_operator_diag.py51
-rw-r--r--tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py2
-rw-r--r--tensorflow/contrib/linalg/python/ops/linear_operator_tril.py207
-rw-r--r--tensorflow/contrib/linalg/python/ops/linear_operator_util.py72
9 files changed, 529 insertions, 40 deletions
diff --git a/tensorflow/contrib/linalg/BUILD b/tensorflow/contrib/linalg/BUILD
index e3ed248dd5..20f4970f09 100644
--- a/tensorflow/contrib/linalg/BUILD
+++ b/tensorflow/contrib/linalg/BUILD
@@ -35,6 +35,32 @@ cuda_py_tests(
shard_count = 5,
)
+cuda_py_tests(
+ name = "linear_operator_tril_test",
+ size = "medium",
+ srcs = ["python/kernel_tests/linear_operator_tril_test.py"],
+ additional_deps = [
+ ":linalg_py",
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ ],
+ shard_count = 5,
+)
+
+cuda_py_tests(
+ name = "linear_operator_util_test",
+ size = "small",
+ srcs = ["python/kernel_tests/linear_operator_util_test.py"],
+ additional_deps = [
+ ":linalg_py",
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ ],
+ shard_count = 5,
+)
+
py_library(
name = "linalg_py",
srcs = ["__init__.py"] + glob(["python/ops/*.py"]),
diff --git a/tensorflow/contrib/linalg/__init__.py b/tensorflow/contrib/linalg/__init__.py
index 3f73581bc3..d15ed052f2 100644
--- a/tensorflow/contrib/linalg/__init__.py
+++ b/tensorflow/contrib/linalg/__init__.py
@@ -30,6 +30,7 @@ Subclasses of `LinearOperator` provide a access to common methods on a
### Individual operators
@@LinearOperatorDiag
+@@LinearOperatorTriL
"""
from __future__ import absolute_import
@@ -40,5 +41,6 @@ from __future__ import print_function
from tensorflow.contrib.linalg.python.ops.linear_operator import *
from tensorflow.contrib.linalg.python.ops.linear_operator_diag import *
+from tensorflow.contrib.linalg.python.ops.linear_operator_tril import *
# pylint: enable=unused-import,wildcard-import,line-too-long,g-importing-member
diff --git a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_diag_test.py b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_diag_test.py
index d03fb1d66f..09e7f880e0 100644
--- a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_diag_test.py
+++ b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_diag_test.py
@@ -38,6 +38,7 @@ class LinearOperatorDiagTest(
if dtype.is_complex:
diag = tf.complex(
diag, tf.random_normal(diag_shape, dtype=dtype.real_dtype))
+
diag_ph = tf.placeholder(dtype=dtype)
if use_placeholder:
@@ -45,14 +46,14 @@ class LinearOperatorDiagTest(
# diag is random and we want the same value used for both mat and
# feed_dict.
diag = diag.eval()
- mat = tf.matrix_diag(diag)
operator = linalg.LinearOperatorDiag(diag_ph)
feed_dict = {diag_ph: diag}
else:
- mat = tf.matrix_diag(diag)
operator = linalg.LinearOperatorDiag(diag)
feed_dict = None
+ mat = tf.matrix_diag(diag)
+
return operator, mat, feed_dict
def test_assert_positive_definite_raises_for_zero_eigenvalue(self):
@@ -60,6 +61,9 @@ class LinearOperatorDiagTest(
with self.test_session():
diag = [1.0, 0.0]
operator = linalg.LinearOperatorDiag(diag)
+
+ # is_self_adjoint should be auto-set for real diag.
+ self.assertTrue(operator.is_self_adjoint)
with self.assertRaisesOpError("non-positive.*not positive definite"):
operator.assert_positive_definite().run()
@@ -69,6 +73,9 @@ class LinearOperatorDiagTest(
diag_y = [0., 0.] # Imaginary eigenvalues should not matter.
diag = tf.complex(diag_x, diag_y)
operator = linalg.LinearOperatorDiag(diag)
+
+ # is_self_adjoint should not be auto-set for complex diag.
+ self.assertTrue(operator.is_self_adjoint is None)
with self.assertRaisesOpError("non-positive real.*not positive definite"):
operator.assert_positive_definite().run()
@@ -84,7 +91,7 @@ class LinearOperatorDiagTest(
# Singlular matrix with one positive eigenvalue and one zero eigenvalue.
with self.test_session():
diag = [1.0, 0.0]
- operator = linalg.LinearOperatorDiag(diag)
+ operator = linalg.LinearOperatorDiag(diag, is_self_adjoint=True)
with self.assertRaisesOpError("Singular operator"):
operator.assert_non_singular().run()
@@ -124,7 +131,7 @@ class LinearOperatorDiagTest(
# This LinearOperatorDiag will be brodacast to (2, 2, 3, 3) during solve
# and apply with 'x' as the argument.
diag = tf.random_uniform(shape=(2, 1, 3))
- operator = linalg.LinearOperatorDiag(diag)
+ operator = linalg.LinearOperatorDiag(diag, is_self_adjoint=True)
self.assertAllEqual((2, 1, 3, 3), operator.shape)
# Create a batch matrix with the broadcast shape of operator.
diff --git a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_tril_test.py b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_tril_test.py
new file mode 100644
index 0000000000..35f1c4a48c
--- /dev/null
+++ b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_tril_test.py
@@ -0,0 +1,104 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.contrib.linalg.python.ops import linear_operator_test_util
+
+
+linalg = tf.contrib.linalg
+tf.set_random_seed(23)
+
+
+class LinearOperatorTriLTest(
+ linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
+ """Most tests done in the base class LinearOperatorDerivedClassTest."""
+
+ @property
+ def _dtypes_to_test(self):
+ # TODO(langmore) Test complex types once supported by
+ # matrix_triangular_solve.
+ return [tf.float32, tf.float64]
+
+ def _operator_and_mat_and_feed_dict(self, shape, dtype, use_placeholder):
+ shape = list(shape)
+ diag_shape = shape[:-1]
+
+ # Upper triangle will be ignored.
+ # Use a diagonal that ensures this matrix is well conditioned.
+ tril = tf.random_normal(shape=shape, dtype=dtype.real_dtype)
+ diag = tf.random_uniform(
+ shape=diag_shape, dtype=dtype.real_dtype, minval=2., maxval=3.)
+ if dtype.is_complex:
+ tril = tf.complex(
+ tril, tf.random_normal(shape, dtype=dtype.real_dtype))
+ diag = tf.complex(
+ diag, tf.random_uniform(
+ shape=diag_shape, dtype=dtype.real_dtype, minval=2., maxval=3.))
+
+ tril = tf.matrix_set_diag(tril, diag)
+
+ tril_ph = tf.placeholder(dtype=dtype)
+
+ if use_placeholder:
+ # Evaluate the tril here because (i) you cannot feed a tensor, and (ii)
+ # tril is random and we want the same value used for both mat and
+ # feed_dict.
+ tril = tril.eval()
+ operator = linalg.LinearOperatorTriL(tril_ph)
+ feed_dict = {tril_ph: tril}
+ else:
+ operator = linalg.LinearOperatorTriL(tril)
+ feed_dict = None
+
+ mat = tf.matrix_band_part(tril, -1, 0)
+
+ return operator, mat, feed_dict
+
+ def test_assert_positive_definite(self):
+ # Singlular matrix with one positive eigenvalue and one negative eigenvalue.
+ with self.test_session():
+ tril = [[1., 0.], [1., -1.]]
+ operator = linalg.LinearOperatorTriL(tril)
+ with self.assertRaisesOpError("was not positive definite"):
+ operator.assert_positive_definite().run()
+
+ def test_assert_non_singular(self):
+ # Singlular matrix with one positive eigenvalue and one zero eigenvalue.
+ with self.test_session():
+ tril = [[1., 0.], [1., 0.]]
+ operator = linalg.LinearOperatorTriL(tril)
+ with self.assertRaisesOpError("Singular operator"):
+ operator.assert_non_singular().run()
+
+ def test_is_x_flags(self):
+ # Matrix with one two positive eigenvalues.
+ tril = [[1., 0.], [1., 1.]]
+ operator = linalg.LinearOperatorTriL(
+ tril,
+ is_positive_definite=True,
+ is_non_singular=True,
+ is_self_adjoint=False)
+ self.assertTrue(operator.is_positive_definite)
+ self.assertTrue(operator.is_non_singular)
+ self.assertFalse(operator.is_self_adjoint)
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_util_test.py b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_util_test.py
new file mode 100644
index 0000000000..8e439070cc
--- /dev/null
+++ b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_util_test.py
@@ -0,0 +1,90 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import tensorflow as tf
+
+from tensorflow.contrib.linalg.python.ops import linear_operator_util
+
+
+linalg = tf.contrib.linalg
+tf.set_random_seed(23)
+
+
+class AssertZeroImagPartTest(tf.test.TestCase):
+
+ def test_real_tensor_doesnt_raise(self):
+ x = tf.convert_to_tensor([0., 2, 3])
+ with self.test_session():
+ # Should not raise.
+ linear_operator_util.assert_zero_imag_part(x, message="ABC123").run()
+
+ def test_complex_tensor_with_imag_zero_doesnt_raise(self):
+ x = tf.convert_to_tensor([1., 0, 3])
+ y = tf.convert_to_tensor([0., 0, 0])
+ z = tf.complex(x, y)
+ with self.test_session():
+ # Should not raise.
+ linear_operator_util.assert_zero_imag_part(z, message="ABC123").run()
+
+ def test_complex_tensor_with_nonzero_imag_raises(self):
+ x = tf.convert_to_tensor([1., 2, 0])
+ y = tf.convert_to_tensor([1., 2, 0])
+ z = tf.complex(x, y)
+ with self.test_session():
+ with self.assertRaisesOpError("ABC123"):
+ linear_operator_util.assert_zero_imag_part(z, message="ABC123").run()
+
+
+class AssertNoEntriesWithModulusZeroTest(tf.test.TestCase):
+
+ def test_nonzero_real_tensor_doesnt_raise(self):
+ x = tf.convert_to_tensor([1., 2, 3])
+ with self.test_session():
+ # Should not raise.
+ linear_operator_util.assert_no_entries_with_modulus_zero(
+ x, message="ABC123").run()
+
+ def test_nonzero_complex_tensor_doesnt_raise(self):
+ x = tf.convert_to_tensor([1., 0, 3])
+ y = tf.convert_to_tensor([1., 2, 0])
+ z = tf.complex(x, y)
+ with self.test_session():
+ # Should not raise.
+ linear_operator_util.assert_no_entries_with_modulus_zero(
+ z, message="ABC123").run()
+
+ def test_zero_real_tensor_raises(self):
+ x = tf.convert_to_tensor([1., 0, 3])
+ with self.test_session():
+ with self.assertRaisesOpError("ABC123"):
+ linear_operator_util.assert_no_entries_with_modulus_zero(
+ x, message="ABC123").run()
+
+ def test_zero_complex_tensor_raises(self):
+ x = tf.convert_to_tensor([1., 2, 0])
+ y = tf.convert_to_tensor([1., 2, 0])
+ z = tf.complex(x, y)
+ with self.test_session():
+ with self.assertRaisesOpError("ABC123"):
+ linear_operator_util.assert_no_entries_with_modulus_zero(
+ z, message="ABC123").run()
+
+
+if __name__ == "__main__":
+ tf.test.main()
diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_diag.py b/tensorflow/contrib/linalg/python/ops/linear_operator_diag.py
index f65ed9a6c8..58a891710c 100644
--- a/tensorflow/contrib/linalg/python/ops/linear_operator_diag.py
+++ b/tensorflow/contrib/linalg/python/ops/linear_operator_diag.py
@@ -19,11 +19,11 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.linalg.python.ops import linear_operator
+from tensorflow.contrib.linalg.python.ops import linear_operator_util
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
-from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
__all__ = ["LinearOperatorDiag",]
@@ -111,7 +111,7 @@ class LinearOperatorDiag(linear_operator.LinearOperator):
def __init__(self,
diag,
is_non_singular=None,
- is_self_adjoint=True,
+ is_self_adjoint=None,
is_positive_definite=None,
name="LinearOperatorDiag"):
"""Initialize a `LinearOperatorDiag`.
@@ -119,11 +119,10 @@ class LinearOperatorDiag(linear_operator.LinearOperator):
Args:
diag: Shape `[B1,...,Bb, N]` `Tensor` with `b >= 0` `N >= 0`.
The diagonal of the operator. Allowed dtypes: `float32`, `float64`,
- `complex64`, `complex128`.
+ `complex64`, `complex128`.
is_non_singular: Expect that this operator is non-singular.
is_self_adjoint: Expect that this operator is equal to its hermitian
- transpose. Since this is a real (not complex) diagonal operator, it is
- always self adjoint.
+ transpose. If `diag.dtype` is real, this is auto-set to `True`.
is_positive_definite: Expect that this operator is positive definite,
meaning the real part of all eigenvalues is positive. We do not require
the operator to be self-adjoint to be positive-definite. See:
@@ -133,7 +132,7 @@ class LinearOperatorDiag(linear_operator.LinearOperator):
Raises:
TypeError: If `diag.dtype` is not an allowed type.
- ValueError: If `is_self_adjoint` is not `True`.
+ ValueError: If `diag.dtype` is real, and `is_self_adjoint` is not `True`.
"""
allowed_dtypes = [
@@ -146,8 +145,13 @@ class LinearOperatorDiag(linear_operator.LinearOperator):
raise TypeError(
"Argument diag must have dtype in %s. Found: %s"
% (allowed_dtypes, dtype))
- if dtype.is_floating and not is_self_adjoint:
- raise ValueError("A real diagonal operator is always self adjoint.")
+
+ # Check and auto-set hints.
+ if not dtype.is_complex:
+ if is_self_adjoint is False:
+ raise ValueError("A real diagonal operator is always self adjoint.")
+ else:
+ is_self_adjoint = True
super(LinearOperatorDiag, self).__init__(
dtype=dtype,
@@ -168,17 +172,9 @@ class LinearOperatorDiag(linear_operator.LinearOperator):
return array_ops.concat(0, (d_shape, [k]))
def _assert_non_singular(self):
- if self.dtype.is_complex:
- should_be_nonzero = math_ops.complex_abs(self._diag)
- else:
- should_be_nonzero = self._diag
-
- nonzero_diag = math_ops.reduce_all(
- math_ops.logical_not(math_ops.equal(should_be_nonzero, 0)))
-
- return control_flow_ops.Assert(
- nonzero_diag,
- data=["Singular operator: diag contained zero values.", self._diag])
+ return linear_operator_util.assert_no_entries_with_modulus_zero(
+ self._diag,
+ message="Singular operator: Diagonal contained zero values.")
def _assert_positive_definite(self):
if self.dtype.is_complex:
@@ -195,7 +191,7 @@ class LinearOperatorDiag(linear_operator.LinearOperator):
message=message)
def _assert_self_adjoint(self):
- return _assert_imag_part_zero(
+ return linear_operator_util.assert_zero_imag_part(
self._diag,
message=(
"This diagonal operator contained non-zero imaginary values. "
@@ -225,18 +221,3 @@ class LinearOperatorDiag(linear_operator.LinearOperator):
x_diag = array_ops.matrix_diag_part(x)
new_diag = self._diag + x_diag
return array_ops.matrix_set_diag(x, new_diag)
-
-
-def _assert_imag_part_zero(x, message=None):
- """Assert that floating or complex 'x' is real."""
- dtype = x.dtype.base_dtype
- if dtype.is_floating:
- return control_flow_ops.no_op()
-
- if not dtype.is_complex:
- raise TypeError(
- "imag_part_zero only handles float or complex types. Found: %s"
- % dtype)
-
- zero = ops.convert_to_tensor(0, dtype=dtype.real_dtype)
- return check_ops.assert_equal(zero, math_ops.imag(x), message=message)
diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py b/tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py
index 20136bfbd0..5f0f1e9bb1 100644
--- a/tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py
+++ b/tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py
@@ -48,7 +48,7 @@ class LinearOperatorDerivedClassTest(tf.test.TestCase):
@property
def _dtypes_to_test(self):
- # TODO(langmore) Test tf.float16 once tf.matrix_diag works in 16bit.
+ # TODO(langmore) Test tf.float16 once tf.matrix_solve works in 16bit.
return [tf.float32, tf.float64, tf.complex64, tf.complex128]
@abc.abstractproperty
diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_tril.py b/tensorflow/contrib/linalg/python/ops/linear_operator_tril.py
new file mode 100644
index 0000000000..ce54fa3c20
--- /dev/null
+++ b/tensorflow/contrib/linalg/python/ops/linear_operator_tril.py
@@ -0,0 +1,207 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""`LinearOperator` acting like a lower triangular matrix."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.linalg.python.ops import linear_operator
+from tensorflow.contrib.linalg.python.ops import linear_operator_util
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import math_ops
+
+__all__ = ["LinearOperatorTriL",]
+
+
+class LinearOperatorTriL(linear_operator.LinearOperator):
+ """`LinearOperator` acting like a [batch] square lower triangular matrix.
+
+ This operator acts like a [batch] matrix `A` with shape
+ `[B1,...,Bb, N, N]` for some `b >= 0`. The first `b` indices index a
+ batch member. For every batch index `(i1,...,ib)`, `A[i1,...,ib, : :]` is
+ an `N x N` matrix.
+
+ `LinearOperatorTriL` is initialized with a `Tensor` having dimensions
+ `[B1,...,Bb, N, N]`. The upper triangle of the last two dimensions is ignored.
+
+ ```python
+ # Create a 2 x 2 lower-triangular linear operator.
+ tril = [[1., 2.], [3., 4.]]
+ operator = LinearOperatorTriL(tril)
+
+ # The upper triangle is ignored.
+ operator.to_dense()
+ ==> [[1., 0.]
+ [3., 4.]]
+
+ operator.shape
+ ==> [2, 2]
+
+ operator.log_determinant()
+ ==> scalar Tensor
+
+ x = ... Shape [2, 4] Tensor
+ operator.apply(x)
+ ==> Shape [2, 4] Tensor
+
+ # Create a [2, 3] batch of 4 x 4 linear operators.
+ tril = tf.random_normal(shape=[2, 3, 4, 4])
+ operator = LinearOperatorTriL(tril)
+
+ # Create a shape [2, 1, 4, 2] vector. Note that this shape is compatible
+ # since the batch dimensions, [2, 1], are brodcast to
+ # operator.batch_shape = [2, 3].
+ y = tf.random_normal(shape=[2, 1, 4, 2])
+ x = operator.solve(y)
+ ==> operator.apply(x) = y
+ ```
+
+ ### Shape compatibility
+
+ This operator acts on [batch] matrix with compatible shape.
+ `x` is a batch matrix with compatible shape for `apply` and `solve` if
+
+ ```
+ operator.shape = [B1,...,Bb] + [N, N], with b >= 0
+ x.shape = [B1,...,Bb] + [N, R], with R >= 0.
+ ```
+
+ ### Performance
+
+ Suppose `operator` is a `LinearOperatorTriL` of shape `[N, N]`,
+ and `x.shape = [N, R]`. Then
+
+ * `operator.apply(x)` involves `N^2 * R` multiplications.
+ * `operator.solve(x)` involves `N * R` size `N` back-substitutions.
+ * `operator.determinant()` involves a size `N` `reduce_prod`.
+
+ If instead `operator` and `x` have shape `[B1,...,Bb, N, N]` and
+ `[B1,...,Bb, N, R]`, every operation increases in complexity by `B1*...*Bb`.
+
+ ### Matrix property hints
+
+ This `LinearOperator` is initialized with boolean flags of the form `is_X`,
+ for `X = non_singular, self_adjoint` etc...
+ These have the following meaning
+ * If `is_X == True`, callers should expect the operator to have the
+ property `X`. This is a promise that should be fulfilled, but is *not* a
+ runtime assert. For example, finite floating point precision may result
+ in these promises being violated.
+ * If `is_X == False`, callers should expect the operator to not have `X`.
+ * If `is_X == None` (the default), callers should have no expectation either
+ way.
+ """
+
+ def __init__(self,
+ tril,
+ is_non_singular=None,
+ is_self_adjoint=None,
+ is_positive_definite=None,
+ name="LinearOperatorTriL"):
+ """Initialize a `LinearOperatorTriL`.
+
+ Args:
+ tril: Shape `[B1,...,Bb, N, N]` with `b >= 0`, `N >= 0`.
+ The lower triangular part of `tril` defines this operator. The strictly
+ upper triangle is ignored. Allowed dtypes: `float32`, `float64`.
+ is_non_singular: Expect that this operator is non-singular.
+ This operator is non-singular if and only if its diagonal elements are
+ all non-zero.
+ is_self_adjoint: Expect that this operator is equal to its hermitian
+ transpose. This operator is self-adjoint only if it is diagonal with
+ real-valued diagonal entries. In this case it is advised to use
+ `LinearOperatorDiag`.
+ is_positive_definite: Expect that this operator is positive definite,
+ meaning the real part of all eigenvalues is positive. We do not require
+ the operator to be self-adjoint to be positive-definite. See:
+ https://en.wikipedia.org/wiki/Positive-definite_matrix
+ #Extension_for_non_symmetric_matrices
+ name: A name for this `LinearOperator`.
+
+ Raises:
+ TypeError: If `diag.dtype` is not an allowed type.
+ """
+
+ # TODO(langmore) Add complex types once matrix_triangular_solve works for
+ # them.
+ allowed_dtypes = [dtypes.float32, dtypes.float64]
+
+ with ops.name_scope(name, values=[tril]):
+ self._tril = array_ops.matrix_band_part(tril, -1, 0)
+ self._diag = array_ops.matrix_diag_part(self._tril)
+
+ dtype = self._tril.dtype
+ if dtype not in allowed_dtypes:
+ raise TypeError(
+ "Argument diag must have dtype in %s. Found: %s"
+ % (allowed_dtypes, dtype))
+
+ super(LinearOperatorTriL, self).__init__(
+ dtype=self._tril.dtype,
+ graph_parents=[self._tril],
+ is_non_singular=is_non_singular,
+ is_self_adjoint=is_self_adjoint,
+ is_positive_definite=is_positive_definite,
+ name=name)
+
+ def _shape(self):
+ return self._tril.get_shape()
+
+ def _shape_dynamic(self):
+ return array_ops.shape(self._tril)
+
+ def _assert_non_singular(self):
+ return linear_operator_util.assert_no_entries_with_modulus_zero(
+ self._diag,
+ message="Singular operator: Diagonal contained zero values.")
+
+ def _assert_positive_definite(self):
+ if self.dtype.is_complex:
+ message = (
+ "Diagonal operator had diagonal entries with non-positive real part, "
+ "thus was not positive definite.")
+ else:
+ message = (
+ "Real diagonal operator had non-positive diagonal entries, "
+ "thus was not positive definite.")
+
+ return check_ops.assert_positive(
+ math_ops.real(self._diag),
+ message=message)
+
+ def _apply(self, x, adjoint=False):
+ return math_ops.matmul(self._tril, x, adjoint_a=adjoint)
+
+ def _determinant(self):
+ return math_ops.reduce_prod(self._diag, reduction_indices=[-1])
+
+ def _log_abs_determinant(self):
+ return math_ops.reduce_sum(
+ math_ops.log(math_ops.abs(self._diag)), reduction_indices=[-1])
+
+ def _solve(self, rhs, adjoint=False):
+ return linalg_ops.matrix_triangular_solve(
+ self._tril, rhs, lower=True, adjoint=adjoint)
+
+ def _to_dense(self):
+ return self._tril
+
+ def _add_to_tensor(self, x):
+ return self._tril + x
diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_util.py b/tensorflow/contrib/linalg/python/ops/linear_operator_util.py
new file mode 100644
index 0000000000..06140ef4a2
--- /dev/null
+++ b/tensorflow/contrib/linalg/python/ops/linear_operator_util.py
@@ -0,0 +1,72 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Internal utilities for `LinearOperator` classes."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+
+
+def assert_no_entries_with_modulus_zero(
+ x, message=None, name="assert_no_entries_with_modulus_zero"):
+ """Returns `Op` that asserts Tensor `x` has no entries with modulus zero.
+
+ Args:
+ x: Numeric `Tensor`, real, integer, or complex.
+ message: A string message to prepend to failure message.
+ name: A name to give this `Op`.
+
+ Returns:
+ An `Op` that asserts `x` has no entries with modulus zero.
+ """
+ with ops.name_scope(name, values=[x]):
+ x = ops.convert_to_tensor(x, name="x")
+ dtype = x.dtype.base_dtype
+
+ if dtype.is_complex:
+ should_be_nonzero = math_ops.complex_abs(x)
+ else:
+ should_be_nonzero = math_ops.abs(x)
+
+ zero = ops.convert_to_tensor(0, dtype=dtype.real_dtype)
+
+ return check_ops.assert_less(zero, should_be_nonzero, message=message)
+
+
+def assert_zero_imag_part(x, message=None, name="assert_zero_imag_part"):
+ """Returns `Op` that asserts Tensor `x` has no non-zero imaginary parts.
+
+ Args:
+ x: Numeric `Tensor`, real, integer, or complex.
+ message: A string message to prepend to failure message.
+ name: A name to give this `Op`.
+
+ Returns:
+ An `Op` that asserts `x` has no entries with modulus zero.
+ """
+ with ops.name_scope(name, values=[x]):
+ x = ops.convert_to_tensor(x, name="x")
+ dtype = x.dtype.base_dtype
+
+ if dtype.is_floating:
+ return control_flow_ops.no_op()
+
+ zero = ops.convert_to_tensor(0, dtype=dtype.real_dtype)
+ return check_ops.assert_equal(zero, math_ops.imag(x), message=message)