From 00954adfe5c7334de56a6cae1f0ad2f83988fcf3 Mon Sep 17 00:00:00 2001 From: Ian Langmore Date: Thu, 1 Dec 2016 13:35:52 -0800 Subject: LinearOperatorTriL, a lower triangular linear operator. linear_operator_util.py added as well Use linear_operator_util in LinearOperatorDiag, and a couple bug-fixes. Change: 140770818 --- tensorflow/contrib/linalg/BUILD | 26 +++ tensorflow/contrib/linalg/__init__.py | 2 + .../kernel_tests/linear_operator_diag_test.py | 15 +- .../kernel_tests/linear_operator_tril_test.py | 104 +++++++++++ .../kernel_tests/linear_operator_util_test.py | 90 +++++++++ .../linalg/python/ops/linear_operator_diag.py | 51 ++--- .../linalg/python/ops/linear_operator_test_util.py | 2 +- .../linalg/python/ops/linear_operator_tril.py | 207 +++++++++++++++++++++ .../linalg/python/ops/linear_operator_util.py | 72 +++++++ 9 files changed, 529 insertions(+), 40 deletions(-) create mode 100644 tensorflow/contrib/linalg/python/kernel_tests/linear_operator_tril_test.py create mode 100644 tensorflow/contrib/linalg/python/kernel_tests/linear_operator_util_test.py create mode 100644 tensorflow/contrib/linalg/python/ops/linear_operator_tril.py create mode 100644 tensorflow/contrib/linalg/python/ops/linear_operator_util.py diff --git a/tensorflow/contrib/linalg/BUILD b/tensorflow/contrib/linalg/BUILD index e3ed248dd5..20f4970f09 100644 --- a/tensorflow/contrib/linalg/BUILD +++ b/tensorflow/contrib/linalg/BUILD @@ -35,6 +35,32 @@ cuda_py_tests( shard_count = 5, ) +cuda_py_tests( + name = "linear_operator_tril_test", + size = "medium", + srcs = ["python/kernel_tests/linear_operator_tril_test.py"], + additional_deps = [ + ":linalg_py", + "//tensorflow:tensorflow_py", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], + shard_count = 5, +) + +cuda_py_tests( + name = "linear_operator_util_test", + size = "small", + srcs = ["python/kernel_tests/linear_operator_util_test.py"], + additional_deps = [ + ":linalg_py", + "//tensorflow:tensorflow_py", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:platform_test", + ], + shard_count = 5, +) + py_library( name = "linalg_py", srcs = ["__init__.py"] + glob(["python/ops/*.py"]), diff --git a/tensorflow/contrib/linalg/__init__.py b/tensorflow/contrib/linalg/__init__.py index 3f73581bc3..d15ed052f2 100644 --- a/tensorflow/contrib/linalg/__init__.py +++ b/tensorflow/contrib/linalg/__init__.py @@ -30,6 +30,7 @@ Subclasses of `LinearOperator` provide a access to common methods on a ### Individual operators @@LinearOperatorDiag +@@LinearOperatorTriL """ from __future__ import absolute_import @@ -40,5 +41,6 @@ from __future__ import print_function from tensorflow.contrib.linalg.python.ops.linear_operator import * from tensorflow.contrib.linalg.python.ops.linear_operator_diag import * +from tensorflow.contrib.linalg.python.ops.linear_operator_tril import * # pylint: enable=unused-import,wildcard-import,line-too-long,g-importing-member diff --git a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_diag_test.py b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_diag_test.py index d03fb1d66f..09e7f880e0 100644 --- a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_diag_test.py +++ b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_diag_test.py @@ -38,6 +38,7 @@ class LinearOperatorDiagTest( if dtype.is_complex: diag = tf.complex( diag, tf.random_normal(diag_shape, dtype=dtype.real_dtype)) + diag_ph = tf.placeholder(dtype=dtype) if use_placeholder: @@ -45,14 +46,14 @@ class LinearOperatorDiagTest( # diag is random and we want the same value used for both mat and # feed_dict. diag = diag.eval() - mat = tf.matrix_diag(diag) operator = linalg.LinearOperatorDiag(diag_ph) feed_dict = {diag_ph: diag} else: - mat = tf.matrix_diag(diag) operator = linalg.LinearOperatorDiag(diag) feed_dict = None + mat = tf.matrix_diag(diag) + return operator, mat, feed_dict def test_assert_positive_definite_raises_for_zero_eigenvalue(self): @@ -60,6 +61,9 @@ class LinearOperatorDiagTest( with self.test_session(): diag = [1.0, 0.0] operator = linalg.LinearOperatorDiag(diag) + + # is_self_adjoint should be auto-set for real diag. + self.assertTrue(operator.is_self_adjoint) with self.assertRaisesOpError("non-positive.*not positive definite"): operator.assert_positive_definite().run() @@ -69,6 +73,9 @@ class LinearOperatorDiagTest( diag_y = [0., 0.] # Imaginary eigenvalues should not matter. diag = tf.complex(diag_x, diag_y) operator = linalg.LinearOperatorDiag(diag) + + # is_self_adjoint should not be auto-set for complex diag. + self.assertTrue(operator.is_self_adjoint is None) with self.assertRaisesOpError("non-positive real.*not positive definite"): operator.assert_positive_definite().run() @@ -84,7 +91,7 @@ class LinearOperatorDiagTest( # Singlular matrix with one positive eigenvalue and one zero eigenvalue. with self.test_session(): diag = [1.0, 0.0] - operator = linalg.LinearOperatorDiag(diag) + operator = linalg.LinearOperatorDiag(diag, is_self_adjoint=True) with self.assertRaisesOpError("Singular operator"): operator.assert_non_singular().run() @@ -124,7 +131,7 @@ class LinearOperatorDiagTest( # This LinearOperatorDiag will be brodacast to (2, 2, 3, 3) during solve # and apply with 'x' as the argument. diag = tf.random_uniform(shape=(2, 1, 3)) - operator = linalg.LinearOperatorDiag(diag) + operator = linalg.LinearOperatorDiag(diag, is_self_adjoint=True) self.assertAllEqual((2, 1, 3, 3), operator.shape) # Create a batch matrix with the broadcast shape of operator. diff --git a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_tril_test.py b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_tril_test.py new file mode 100644 index 0000000000..35f1c4a48c --- /dev/null +++ b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_tril_test.py @@ -0,0 +1,104 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +from tensorflow.contrib.linalg.python.ops import linear_operator_test_util + + +linalg = tf.contrib.linalg +tf.set_random_seed(23) + + +class LinearOperatorTriLTest( + linear_operator_test_util.SquareLinearOperatorDerivedClassTest): + """Most tests done in the base class LinearOperatorDerivedClassTest.""" + + @property + def _dtypes_to_test(self): + # TODO(langmore) Test complex types once supported by + # matrix_triangular_solve. + return [tf.float32, tf.float64] + + def _operator_and_mat_and_feed_dict(self, shape, dtype, use_placeholder): + shape = list(shape) + diag_shape = shape[:-1] + + # Upper triangle will be ignored. + # Use a diagonal that ensures this matrix is well conditioned. + tril = tf.random_normal(shape=shape, dtype=dtype.real_dtype) + diag = tf.random_uniform( + shape=diag_shape, dtype=dtype.real_dtype, minval=2., maxval=3.) + if dtype.is_complex: + tril = tf.complex( + tril, tf.random_normal(shape, dtype=dtype.real_dtype)) + diag = tf.complex( + diag, tf.random_uniform( + shape=diag_shape, dtype=dtype.real_dtype, minval=2., maxval=3.)) + + tril = tf.matrix_set_diag(tril, diag) + + tril_ph = tf.placeholder(dtype=dtype) + + if use_placeholder: + # Evaluate the tril here because (i) you cannot feed a tensor, and (ii) + # tril is random and we want the same value used for both mat and + # feed_dict. + tril = tril.eval() + operator = linalg.LinearOperatorTriL(tril_ph) + feed_dict = {tril_ph: tril} + else: + operator = linalg.LinearOperatorTriL(tril) + feed_dict = None + + mat = tf.matrix_band_part(tril, -1, 0) + + return operator, mat, feed_dict + + def test_assert_positive_definite(self): + # Singlular matrix with one positive eigenvalue and one negative eigenvalue. + with self.test_session(): + tril = [[1., 0.], [1., -1.]] + operator = linalg.LinearOperatorTriL(tril) + with self.assertRaisesOpError("was not positive definite"): + operator.assert_positive_definite().run() + + def test_assert_non_singular(self): + # Singlular matrix with one positive eigenvalue and one zero eigenvalue. + with self.test_session(): + tril = [[1., 0.], [1., 0.]] + operator = linalg.LinearOperatorTriL(tril) + with self.assertRaisesOpError("Singular operator"): + operator.assert_non_singular().run() + + def test_is_x_flags(self): + # Matrix with one two positive eigenvalues. + tril = [[1., 0.], [1., 1.]] + operator = linalg.LinearOperatorTriL( + tril, + is_positive_definite=True, + is_non_singular=True, + is_self_adjoint=False) + self.assertTrue(operator.is_positive_definite) + self.assertTrue(operator.is_non_singular) + self.assertFalse(operator.is_self_adjoint) + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_util_test.py b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_util_test.py new file mode 100644 index 0000000000..8e439070cc --- /dev/null +++ b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_util_test.py @@ -0,0 +1,90 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf + +from tensorflow.contrib.linalg.python.ops import linear_operator_util + + +linalg = tf.contrib.linalg +tf.set_random_seed(23) + + +class AssertZeroImagPartTest(tf.test.TestCase): + + def test_real_tensor_doesnt_raise(self): + x = tf.convert_to_tensor([0., 2, 3]) + with self.test_session(): + # Should not raise. + linear_operator_util.assert_zero_imag_part(x, message="ABC123").run() + + def test_complex_tensor_with_imag_zero_doesnt_raise(self): + x = tf.convert_to_tensor([1., 0, 3]) + y = tf.convert_to_tensor([0., 0, 0]) + z = tf.complex(x, y) + with self.test_session(): + # Should not raise. + linear_operator_util.assert_zero_imag_part(z, message="ABC123").run() + + def test_complex_tensor_with_nonzero_imag_raises(self): + x = tf.convert_to_tensor([1., 2, 0]) + y = tf.convert_to_tensor([1., 2, 0]) + z = tf.complex(x, y) + with self.test_session(): + with self.assertRaisesOpError("ABC123"): + linear_operator_util.assert_zero_imag_part(z, message="ABC123").run() + + +class AssertNoEntriesWithModulusZeroTest(tf.test.TestCase): + + def test_nonzero_real_tensor_doesnt_raise(self): + x = tf.convert_to_tensor([1., 2, 3]) + with self.test_session(): + # Should not raise. + linear_operator_util.assert_no_entries_with_modulus_zero( + x, message="ABC123").run() + + def test_nonzero_complex_tensor_doesnt_raise(self): + x = tf.convert_to_tensor([1., 0, 3]) + y = tf.convert_to_tensor([1., 2, 0]) + z = tf.complex(x, y) + with self.test_session(): + # Should not raise. + linear_operator_util.assert_no_entries_with_modulus_zero( + z, message="ABC123").run() + + def test_zero_real_tensor_raises(self): + x = tf.convert_to_tensor([1., 0, 3]) + with self.test_session(): + with self.assertRaisesOpError("ABC123"): + linear_operator_util.assert_no_entries_with_modulus_zero( + x, message="ABC123").run() + + def test_zero_complex_tensor_raises(self): + x = tf.convert_to_tensor([1., 2, 0]) + y = tf.convert_to_tensor([1., 2, 0]) + z = tf.complex(x, y) + with self.test_session(): + with self.assertRaisesOpError("ABC123"): + linear_operator_util.assert_no_entries_with_modulus_zero( + z, message="ABC123").run() + + +if __name__ == "__main__": + tf.test.main() diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_diag.py b/tensorflow/contrib/linalg/python/ops/linear_operator_diag.py index f65ed9a6c8..58a891710c 100644 --- a/tensorflow/contrib/linalg/python/ops/linear_operator_diag.py +++ b/tensorflow/contrib/linalg/python/ops/linear_operator_diag.py @@ -19,11 +19,11 @@ from __future__ import division from __future__ import print_function from tensorflow.contrib.linalg.python.ops import linear_operator +from tensorflow.contrib.linalg.python.ops import linear_operator_util from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.ops import check_ops -from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops __all__ = ["LinearOperatorDiag",] @@ -111,7 +111,7 @@ class LinearOperatorDiag(linear_operator.LinearOperator): def __init__(self, diag, is_non_singular=None, - is_self_adjoint=True, + is_self_adjoint=None, is_positive_definite=None, name="LinearOperatorDiag"): """Initialize a `LinearOperatorDiag`. @@ -119,11 +119,10 @@ class LinearOperatorDiag(linear_operator.LinearOperator): Args: diag: Shape `[B1,...,Bb, N]` `Tensor` with `b >= 0` `N >= 0`. The diagonal of the operator. Allowed dtypes: `float32`, `float64`, - `complex64`, `complex128`. + `complex64`, `complex128`. is_non_singular: Expect that this operator is non-singular. is_self_adjoint: Expect that this operator is equal to its hermitian - transpose. Since this is a real (not complex) diagonal operator, it is - always self adjoint. + transpose. If `diag.dtype` is real, this is auto-set to `True`. is_positive_definite: Expect that this operator is positive definite, meaning the real part of all eigenvalues is positive. We do not require the operator to be self-adjoint to be positive-definite. See: @@ -133,7 +132,7 @@ class LinearOperatorDiag(linear_operator.LinearOperator): Raises: TypeError: If `diag.dtype` is not an allowed type. - ValueError: If `is_self_adjoint` is not `True`. + ValueError: If `diag.dtype` is real, and `is_self_adjoint` is not `True`. """ allowed_dtypes = [ @@ -146,8 +145,13 @@ class LinearOperatorDiag(linear_operator.LinearOperator): raise TypeError( "Argument diag must have dtype in %s. Found: %s" % (allowed_dtypes, dtype)) - if dtype.is_floating and not is_self_adjoint: - raise ValueError("A real diagonal operator is always self adjoint.") + + # Check and auto-set hints. + if not dtype.is_complex: + if is_self_adjoint is False: + raise ValueError("A real diagonal operator is always self adjoint.") + else: + is_self_adjoint = True super(LinearOperatorDiag, self).__init__( dtype=dtype, @@ -168,17 +172,9 @@ class LinearOperatorDiag(linear_operator.LinearOperator): return array_ops.concat(0, (d_shape, [k])) def _assert_non_singular(self): - if self.dtype.is_complex: - should_be_nonzero = math_ops.complex_abs(self._diag) - else: - should_be_nonzero = self._diag - - nonzero_diag = math_ops.reduce_all( - math_ops.logical_not(math_ops.equal(should_be_nonzero, 0))) - - return control_flow_ops.Assert( - nonzero_diag, - data=["Singular operator: diag contained zero values.", self._diag]) + return linear_operator_util.assert_no_entries_with_modulus_zero( + self._diag, + message="Singular operator: Diagonal contained zero values.") def _assert_positive_definite(self): if self.dtype.is_complex: @@ -195,7 +191,7 @@ class LinearOperatorDiag(linear_operator.LinearOperator): message=message) def _assert_self_adjoint(self): - return _assert_imag_part_zero( + return linear_operator_util.assert_zero_imag_part( self._diag, message=( "This diagonal operator contained non-zero imaginary values. " @@ -225,18 +221,3 @@ class LinearOperatorDiag(linear_operator.LinearOperator): x_diag = array_ops.matrix_diag_part(x) new_diag = self._diag + x_diag return array_ops.matrix_set_diag(x, new_diag) - - -def _assert_imag_part_zero(x, message=None): - """Assert that floating or complex 'x' is real.""" - dtype = x.dtype.base_dtype - if dtype.is_floating: - return control_flow_ops.no_op() - - if not dtype.is_complex: - raise TypeError( - "imag_part_zero only handles float or complex types. Found: %s" - % dtype) - - zero = ops.convert_to_tensor(0, dtype=dtype.real_dtype) - return check_ops.assert_equal(zero, math_ops.imag(x), message=message) diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py b/tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py index 20136bfbd0..5f0f1e9bb1 100644 --- a/tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py +++ b/tensorflow/contrib/linalg/python/ops/linear_operator_test_util.py @@ -48,7 +48,7 @@ class LinearOperatorDerivedClassTest(tf.test.TestCase): @property def _dtypes_to_test(self): - # TODO(langmore) Test tf.float16 once tf.matrix_diag works in 16bit. + # TODO(langmore) Test tf.float16 once tf.matrix_solve works in 16bit. return [tf.float32, tf.float64, tf.complex64, tf.complex128] @abc.abstractproperty diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_tril.py b/tensorflow/contrib/linalg/python/ops/linear_operator_tril.py new file mode 100644 index 0000000000..ce54fa3c20 --- /dev/null +++ b/tensorflow/contrib/linalg/python/ops/linear_operator_tril.py @@ -0,0 +1,207 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""`LinearOperator` acting like a lower triangular matrix.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.linalg.python.ops import linear_operator +from tensorflow.contrib.linalg.python.ops import linear_operator_util +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops import math_ops + +__all__ = ["LinearOperatorTriL",] + + +class LinearOperatorTriL(linear_operator.LinearOperator): + """`LinearOperator` acting like a [batch] square lower triangular matrix. + + This operator acts like a [batch] matrix `A` with shape + `[B1,...,Bb, N, N]` for some `b >= 0`. The first `b` indices index a + batch member. For every batch index `(i1,...,ib)`, `A[i1,...,ib, : :]` is + an `N x N` matrix. + + `LinearOperatorTriL` is initialized with a `Tensor` having dimensions + `[B1,...,Bb, N, N]`. The upper triangle of the last two dimensions is ignored. + + ```python + # Create a 2 x 2 lower-triangular linear operator. + tril = [[1., 2.], [3., 4.]] + operator = LinearOperatorTriL(tril) + + # The upper triangle is ignored. + operator.to_dense() + ==> [[1., 0.] + [3., 4.]] + + operator.shape + ==> [2, 2] + + operator.log_determinant() + ==> scalar Tensor + + x = ... Shape [2, 4] Tensor + operator.apply(x) + ==> Shape [2, 4] Tensor + + # Create a [2, 3] batch of 4 x 4 linear operators. + tril = tf.random_normal(shape=[2, 3, 4, 4]) + operator = LinearOperatorTriL(tril) + + # Create a shape [2, 1, 4, 2] vector. Note that this shape is compatible + # since the batch dimensions, [2, 1], are brodcast to + # operator.batch_shape = [2, 3]. + y = tf.random_normal(shape=[2, 1, 4, 2]) + x = operator.solve(y) + ==> operator.apply(x) = y + ``` + + ### Shape compatibility + + This operator acts on [batch] matrix with compatible shape. + `x` is a batch matrix with compatible shape for `apply` and `solve` if + + ``` + operator.shape = [B1,...,Bb] + [N, N], with b >= 0 + x.shape = [B1,...,Bb] + [N, R], with R >= 0. + ``` + + ### Performance + + Suppose `operator` is a `LinearOperatorTriL` of shape `[N, N]`, + and `x.shape = [N, R]`. Then + + * `operator.apply(x)` involves `N^2 * R` multiplications. + * `operator.solve(x)` involves `N * R` size `N` back-substitutions. + * `operator.determinant()` involves a size `N` `reduce_prod`. + + If instead `operator` and `x` have shape `[B1,...,Bb, N, N]` and + `[B1,...,Bb, N, R]`, every operation increases in complexity by `B1*...*Bb`. + + ### Matrix property hints + + This `LinearOperator` is initialized with boolean flags of the form `is_X`, + for `X = non_singular, self_adjoint` etc... + These have the following meaning + * If `is_X == True`, callers should expect the operator to have the + property `X`. This is a promise that should be fulfilled, but is *not* a + runtime assert. For example, finite floating point precision may result + in these promises being violated. + * If `is_X == False`, callers should expect the operator to not have `X`. + * If `is_X == None` (the default), callers should have no expectation either + way. + """ + + def __init__(self, + tril, + is_non_singular=None, + is_self_adjoint=None, + is_positive_definite=None, + name="LinearOperatorTriL"): + """Initialize a `LinearOperatorTriL`. + + Args: + tril: Shape `[B1,...,Bb, N, N]` with `b >= 0`, `N >= 0`. + The lower triangular part of `tril` defines this operator. The strictly + upper triangle is ignored. Allowed dtypes: `float32`, `float64`. + is_non_singular: Expect that this operator is non-singular. + This operator is non-singular if and only if its diagonal elements are + all non-zero. + is_self_adjoint: Expect that this operator is equal to its hermitian + transpose. This operator is self-adjoint only if it is diagonal with + real-valued diagonal entries. In this case it is advised to use + `LinearOperatorDiag`. + is_positive_definite: Expect that this operator is positive definite, + meaning the real part of all eigenvalues is positive. We do not require + the operator to be self-adjoint to be positive-definite. See: + https://en.wikipedia.org/wiki/Positive-definite_matrix + #Extension_for_non_symmetric_matrices + name: A name for this `LinearOperator`. + + Raises: + TypeError: If `diag.dtype` is not an allowed type. + """ + + # TODO(langmore) Add complex types once matrix_triangular_solve works for + # them. + allowed_dtypes = [dtypes.float32, dtypes.float64] + + with ops.name_scope(name, values=[tril]): + self._tril = array_ops.matrix_band_part(tril, -1, 0) + self._diag = array_ops.matrix_diag_part(self._tril) + + dtype = self._tril.dtype + if dtype not in allowed_dtypes: + raise TypeError( + "Argument diag must have dtype in %s. Found: %s" + % (allowed_dtypes, dtype)) + + super(LinearOperatorTriL, self).__init__( + dtype=self._tril.dtype, + graph_parents=[self._tril], + is_non_singular=is_non_singular, + is_self_adjoint=is_self_adjoint, + is_positive_definite=is_positive_definite, + name=name) + + def _shape(self): + return self._tril.get_shape() + + def _shape_dynamic(self): + return array_ops.shape(self._tril) + + def _assert_non_singular(self): + return linear_operator_util.assert_no_entries_with_modulus_zero( + self._diag, + message="Singular operator: Diagonal contained zero values.") + + def _assert_positive_definite(self): + if self.dtype.is_complex: + message = ( + "Diagonal operator had diagonal entries with non-positive real part, " + "thus was not positive definite.") + else: + message = ( + "Real diagonal operator had non-positive diagonal entries, " + "thus was not positive definite.") + + return check_ops.assert_positive( + math_ops.real(self._diag), + message=message) + + def _apply(self, x, adjoint=False): + return math_ops.matmul(self._tril, x, adjoint_a=adjoint) + + def _determinant(self): + return math_ops.reduce_prod(self._diag, reduction_indices=[-1]) + + def _log_abs_determinant(self): + return math_ops.reduce_sum( + math_ops.log(math_ops.abs(self._diag)), reduction_indices=[-1]) + + def _solve(self, rhs, adjoint=False): + return linalg_ops.matrix_triangular_solve( + self._tril, rhs, lower=True, adjoint=adjoint) + + def _to_dense(self): + return self._tril + + def _add_to_tensor(self, x): + return self._tril + x diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_util.py b/tensorflow/contrib/linalg/python/ops/linear_operator_util.py new file mode 100644 index 0000000000..06140ef4a2 --- /dev/null +++ b/tensorflow/contrib/linalg/python/ops/linear_operator_util.py @@ -0,0 +1,72 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Internal utilities for `LinearOperator` classes.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.framework import ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import math_ops + + +def assert_no_entries_with_modulus_zero( + x, message=None, name="assert_no_entries_with_modulus_zero"): + """Returns `Op` that asserts Tensor `x` has no entries with modulus zero. + + Args: + x: Numeric `Tensor`, real, integer, or complex. + message: A string message to prepend to failure message. + name: A name to give this `Op`. + + Returns: + An `Op` that asserts `x` has no entries with modulus zero. + """ + with ops.name_scope(name, values=[x]): + x = ops.convert_to_tensor(x, name="x") + dtype = x.dtype.base_dtype + + if dtype.is_complex: + should_be_nonzero = math_ops.complex_abs(x) + else: + should_be_nonzero = math_ops.abs(x) + + zero = ops.convert_to_tensor(0, dtype=dtype.real_dtype) + + return check_ops.assert_less(zero, should_be_nonzero, message=message) + + +def assert_zero_imag_part(x, message=None, name="assert_zero_imag_part"): + """Returns `Op` that asserts Tensor `x` has no non-zero imaginary parts. + + Args: + x: Numeric `Tensor`, real, integer, or complex. + message: A string message to prepend to failure message. + name: A name to give this `Op`. + + Returns: + An `Op` that asserts `x` has no entries with modulus zero. + """ + with ops.name_scope(name, values=[x]): + x = ops.convert_to_tensor(x, name="x") + dtype = x.dtype.base_dtype + + if dtype.is_floating: + return control_flow_ops.no_op() + + zero = ops.convert_to_tensor(0, dtype=dtype.real_dtype) + return check_ops.assert_equal(zero, math_ops.imag(x), message=message) -- cgit v1.2.3