diff options
author | 2018-09-12 09:45:53 -0700 | |
---|---|---|
committer | 2018-09-12 09:49:50 -0700 | |
commit | 26509bf4e202c09da4f0b00d43ebddf87368a0f2 (patch) | |
tree | 200e9080e06200dc5512239d25fd2f2c7b18a899 | |
parent | 1c4fceab7dc09cab18c0def098320d6c52d2e514 (diff) |
Add linear_operator_addition to tensorflow/python/. A subsequent CL
will remove this from contrib.
linear_operator_addition is hidden from the public API.
PiperOrigin-RevId: 212655087
3 files changed, 860 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/linalg/BUILD b/tensorflow/python/kernel_tests/linalg/BUILD index f4ec3e3996..be2e31cb5a 100644 --- a/tensorflow/python/kernel_tests/linalg/BUILD +++ b/tensorflow/python/kernel_tests/linalg/BUILD @@ -25,6 +25,22 @@ cuda_py_test( ) cuda_py_test( + name = "linear_operator_addition_test", + size = "small", + srcs = ["linear_operator_addition_test.py"], + additional_deps = [ + "//tensorflow/python/ops/linalg", + "//third_party/py/numpy", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:framework_test_lib", + "//tensorflow/python:math_ops", + "//tensorflow/python:platform_test", + ], +) + +cuda_py_test( name = "linear_operator_block_diag_test", size = "medium", srcs = ["linear_operator_block_diag_test.py"], diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_addition_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_addition_test.py new file mode 100644 index 0000000000..7c79fedf65 --- /dev/null +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_addition_test.py @@ -0,0 +1,412 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import random_seed +from tensorflow.python.ops import linalg_ops +from tensorflow.python.ops.linalg import linalg as linalg_lib +from tensorflow.python.ops.linalg import linear_operator_addition +from tensorflow.python.platform import test + +linalg = linalg_lib +random_seed.set_random_seed(23) +rng = np.random.RandomState(0) + +add_operators = linear_operator_addition.add_operators + + +# pylint: disable=unused-argument +class _BadAdder(linear_operator_addition._Adder): + """Adder that will fail if used.""" + + def can_add(self, op1, op2): + raise AssertionError("BadAdder.can_add called!") + + def _add(self, op1, op2, operator_name, hints): + raise AssertionError("This line should not be reached") + + +# pylint: enable=unused-argument + + +class LinearOperatorAdditionCorrectnessTest(test.TestCase): + """Tests correctness of addition with combinations of a few Adders. + + Tests here are done with the _DEFAULT_ADDITION_TIERS, which means + add_operators should reduce all operators resulting in one single operator. + + This shows that we are able to correctly combine adders using the tiered + system. All Adders should be tested separately, and there is no need to test + every Adder within this class. + """ + + def test_one_operator_is_returned_unchanged(self): + op_a = linalg.LinearOperatorDiag([1., 1.]) + op_sum = add_operators([op_a]) + self.assertEqual(1, len(op_sum)) + self.assertIs(op_sum[0], op_a) + + def test_at_least_one_operators_required(self): + with self.assertRaisesRegexp(ValueError, "must contain at least one"): + add_operators([]) + + def test_attempting_to_add_numbers_raises(self): + with self.assertRaisesRegexp(TypeError, "contain only LinearOperator"): + add_operators([1, 2]) + + def test_two_diag_operators(self): + op_a = linalg.LinearOperatorDiag( + [1., 1.], is_positive_definite=True, name="A") + op_b = linalg.LinearOperatorDiag( + [2., 2.], is_positive_definite=True, name="B") + with self.test_session(): + op_sum = add_operators([op_a, op_b]) + self.assertEqual(1, len(op_sum)) + op = op_sum[0] + self.assertIsInstance(op, linalg_lib.LinearOperatorDiag) + self.assertAllClose([[3., 0.], [0., 3.]], op.to_dense().eval()) + # Adding positive definite operators produces positive def. + self.assertTrue(op.is_positive_definite) + # Real diagonal ==> self-adjoint. + self.assertTrue(op.is_self_adjoint) + # Positive definite ==> non-singular + self.assertTrue(op.is_non_singular) + # Enforce particular name for this simple case + self.assertEqual("Add/B__A/", op.name) + + def test_three_diag_operators(self): + op1 = linalg.LinearOperatorDiag( + [1., 1.], is_positive_definite=True, name="op1") + op2 = linalg.LinearOperatorDiag( + [2., 2.], is_positive_definite=True, name="op2") + op3 = linalg.LinearOperatorDiag( + [3., 3.], is_positive_definite=True, name="op3") + with self.test_session(): + op_sum = add_operators([op1, op2, op3]) + self.assertEqual(1, len(op_sum)) + op = op_sum[0] + self.assertTrue(isinstance(op, linalg_lib.LinearOperatorDiag)) + self.assertAllClose([[6., 0.], [0., 6.]], op.to_dense().eval()) + # Adding positive definite operators produces positive def. + self.assertTrue(op.is_positive_definite) + # Real diagonal ==> self-adjoint. + self.assertTrue(op.is_self_adjoint) + # Positive definite ==> non-singular + self.assertTrue(op.is_non_singular) + + def test_diag_tril_diag(self): + op1 = linalg.LinearOperatorDiag( + [1., 1.], is_non_singular=True, name="diag_a") + op2 = linalg.LinearOperatorLowerTriangular( + [[2., 0.], [0., 2.]], + is_self_adjoint=True, + is_non_singular=True, + name="tril") + op3 = linalg.LinearOperatorDiag( + [3., 3.], is_non_singular=True, name="diag_b") + with self.test_session(): + op_sum = add_operators([op1, op2, op3]) + self.assertEqual(1, len(op_sum)) + op = op_sum[0] + self.assertIsInstance(op, linalg_lib.LinearOperatorLowerTriangular) + self.assertAllClose([[6., 0.], [0., 6.]], op.to_dense().eval()) + + # The diag operators will be self-adjoint (because real and diagonal). + # The TriL operator has the self-adjoint hint set. + self.assertTrue(op.is_self_adjoint) + + # Even though op1/2/3 are non-singular, this does not imply op is. + # Since no custom hint was provided, we default to None (unknown). + self.assertEqual(None, op.is_non_singular) + + def test_matrix_diag_tril_diag_uses_custom_name(self): + op0 = linalg.LinearOperatorFullMatrix( + [[-1., -1.], [-1., -1.]], name="matrix") + op1 = linalg.LinearOperatorDiag([1., 1.], name="diag_a") + op2 = linalg.LinearOperatorLowerTriangular( + [[2., 0.], [1.5, 2.]], name="tril") + op3 = linalg.LinearOperatorDiag([3., 3.], name="diag_b") + with self.test_session(): + op_sum = add_operators([op0, op1, op2, op3], operator_name="my_operator") + self.assertEqual(1, len(op_sum)) + op = op_sum[0] + self.assertIsInstance(op, linalg_lib.LinearOperatorFullMatrix) + self.assertAllClose([[5., -1.], [0.5, 5.]], op.to_dense().eval()) + self.assertEqual("my_operator", op.name) + + def test_incompatible_domain_dimensions_raises(self): + op1 = linalg.LinearOperatorFullMatrix(rng.rand(2, 3)) + op2 = linalg.LinearOperatorDiag(rng.rand(2, 4)) + with self.assertRaisesRegexp(ValueError, "must.*same domain dimension"): + add_operators([op1, op2]) + + def test_incompatible_range_dimensions_raises(self): + op1 = linalg.LinearOperatorFullMatrix(rng.rand(2, 3)) + op2 = linalg.LinearOperatorDiag(rng.rand(3, 3)) + with self.assertRaisesRegexp(ValueError, "must.*same range dimension"): + add_operators([op1, op2]) + + def test_non_broadcastable_batch_shape_raises(self): + op1 = linalg.LinearOperatorFullMatrix(rng.rand(2, 3, 3)) + op2 = linalg.LinearOperatorDiag(rng.rand(4, 3, 3)) + with self.assertRaisesRegexp(ValueError, "Incompatible shapes"): + add_operators([op1, op2]) + + +class LinearOperatorOrderOfAdditionTest(test.TestCase): + """Test that the order of addition is done as specified by tiers.""" + + def test_tier_0_additions_done_in_tier_0(self): + diag1 = linalg.LinearOperatorDiag([1.]) + diag2 = linalg.LinearOperatorDiag([1.]) + diag3 = linalg.LinearOperatorDiag([1.]) + addition_tiers = [ + [linear_operator_addition._AddAndReturnDiag()], + [_BadAdder()], + ] + # Should not raise since all were added in tier 0, and tier 1 (with the + # _BadAdder) was never reached. + op_sum = add_operators([diag1, diag2, diag3], addition_tiers=addition_tiers) + self.assertEqual(1, len(op_sum)) + self.assertIsInstance(op_sum[0], linalg.LinearOperatorDiag) + + def test_tier_1_additions_done_by_tier_1(self): + diag1 = linalg.LinearOperatorDiag([1.]) + diag2 = linalg.LinearOperatorDiag([1.]) + tril = linalg.LinearOperatorLowerTriangular([[1.]]) + addition_tiers = [ + [linear_operator_addition._AddAndReturnDiag()], + [linear_operator_addition._AddAndReturnTriL()], + [_BadAdder()], + ] + # Should not raise since all were added by tier 1, and the + # _BadAdder) was never reached. + op_sum = add_operators([diag1, diag2, tril], addition_tiers=addition_tiers) + self.assertEqual(1, len(op_sum)) + self.assertIsInstance(op_sum[0], linalg.LinearOperatorLowerTriangular) + + def test_tier_1_additions_done_by_tier_1_with_order_flipped(self): + diag1 = linalg.LinearOperatorDiag([1.]) + diag2 = linalg.LinearOperatorDiag([1.]) + tril = linalg.LinearOperatorLowerTriangular([[1.]]) + addition_tiers = [ + [linear_operator_addition._AddAndReturnTriL()], + [linear_operator_addition._AddAndReturnDiag()], + [_BadAdder()], + ] + # Tier 0 could convert to TriL, and this converted everything to TriL, + # including the Diags. + # Tier 1 was never used. + # Tier 2 was never used (therefore, _BadAdder didn't raise). + op_sum = add_operators([diag1, diag2, tril], addition_tiers=addition_tiers) + self.assertEqual(1, len(op_sum)) + self.assertIsInstance(op_sum[0], linalg.LinearOperatorLowerTriangular) + + def test_cannot_add_everything_so_return_more_than_one_operator(self): + diag1 = linalg.LinearOperatorDiag([1.]) + diag2 = linalg.LinearOperatorDiag([2.]) + tril5 = linalg.LinearOperatorLowerTriangular([[5.]]) + addition_tiers = [ + [linear_operator_addition._AddAndReturnDiag()], + ] + # Tier 0 (the only tier) can only convert to Diag, so it combines the two + # diags, but the TriL is unchanged. + # Result should contain two operators, one Diag, one TriL. + op_sum = add_operators([diag1, diag2, tril5], addition_tiers=addition_tiers) + self.assertEqual(2, len(op_sum)) + found_diag = False + found_tril = False + with self.test_session(): + for op in op_sum: + if isinstance(op, linalg.LinearOperatorDiag): + found_diag = True + self.assertAllClose([[3.]], op.to_dense().eval()) + if isinstance(op, linalg.LinearOperatorLowerTriangular): + found_tril = True + self.assertAllClose([[5.]], op.to_dense().eval()) + self.assertTrue(found_diag and found_tril) + + def test_intermediate_tier_is_not_skipped(self): + diag1 = linalg.LinearOperatorDiag([1.]) + diag2 = linalg.LinearOperatorDiag([1.]) + tril = linalg.LinearOperatorLowerTriangular([[1.]]) + addition_tiers = [ + [linear_operator_addition._AddAndReturnDiag()], + [_BadAdder()], + [linear_operator_addition._AddAndReturnTriL()], + ] + # tril cannot be added in tier 0, and the intermediate tier 1 with the + # BadAdder will catch it and raise. + with self.assertRaisesRegexp(AssertionError, "BadAdder.can_add called"): + add_operators([diag1, diag2, tril], addition_tiers=addition_tiers) + + +class AddAndReturnScaledIdentityTest(test.TestCase): + + def setUp(self): + self._adder = linear_operator_addition._AddAndReturnScaledIdentity() + + def test_identity_plus_identity(self): + id1 = linalg.LinearOperatorIdentity(num_rows=2) + id2 = linalg.LinearOperatorIdentity(num_rows=2, batch_shape=[3]) + hints = linear_operator_addition._Hints( + is_positive_definite=True, is_non_singular=True) + + self.assertTrue(self._adder.can_add(id1, id2)) + operator = self._adder.add(id1, id2, "my_operator", hints) + self.assertIsInstance(operator, linalg.LinearOperatorScaledIdentity) + + with self.test_session(): + self.assertAllClose(2 * + linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(), + operator.to_dense().eval()) + self.assertTrue(operator.is_positive_definite) + self.assertTrue(operator.is_non_singular) + self.assertEqual("my_operator", operator.name) + + def test_identity_plus_scaled_identity(self): + id1 = linalg.LinearOperatorIdentity(num_rows=2, batch_shape=[3]) + id2 = linalg.LinearOperatorScaledIdentity(num_rows=2, multiplier=2.2) + hints = linear_operator_addition._Hints( + is_positive_definite=True, is_non_singular=True) + + self.assertTrue(self._adder.can_add(id1, id2)) + operator = self._adder.add(id1, id2, "my_operator", hints) + self.assertIsInstance(operator, linalg.LinearOperatorScaledIdentity) + + with self.test_session(): + self.assertAllClose(3.2 * + linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(), + operator.to_dense().eval()) + self.assertTrue(operator.is_positive_definite) + self.assertTrue(operator.is_non_singular) + self.assertEqual("my_operator", operator.name) + + def test_scaled_identity_plus_scaled_identity(self): + id1 = linalg.LinearOperatorScaledIdentity( + num_rows=2, multiplier=[2.2, 2.2, 2.2]) + id2 = linalg.LinearOperatorScaledIdentity(num_rows=2, multiplier=-1.0) + hints = linear_operator_addition._Hints( + is_positive_definite=True, is_non_singular=True) + + self.assertTrue(self._adder.can_add(id1, id2)) + operator = self._adder.add(id1, id2, "my_operator", hints) + self.assertIsInstance(operator, linalg.LinearOperatorScaledIdentity) + + with self.test_session(): + self.assertAllClose(1.2 * + linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(), + operator.to_dense().eval()) + self.assertTrue(operator.is_positive_definite) + self.assertTrue(operator.is_non_singular) + self.assertEqual("my_operator", operator.name) + + +class AddAndReturnDiagTest(test.TestCase): + + def setUp(self): + self._adder = linear_operator_addition._AddAndReturnDiag() + + def test_identity_plus_identity_returns_diag(self): + id1 = linalg.LinearOperatorIdentity(num_rows=2) + id2 = linalg.LinearOperatorIdentity(num_rows=2, batch_shape=[3]) + hints = linear_operator_addition._Hints( + is_positive_definite=True, is_non_singular=True) + + self.assertTrue(self._adder.can_add(id1, id2)) + operator = self._adder.add(id1, id2, "my_operator", hints) + self.assertIsInstance(operator, linalg.LinearOperatorDiag) + + with self.test_session(): + self.assertAllClose(2 * + linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(), + operator.to_dense().eval()) + self.assertTrue(operator.is_positive_definite) + self.assertTrue(operator.is_non_singular) + self.assertEqual("my_operator", operator.name) + + def test_diag_plus_diag(self): + diag1 = rng.rand(2, 3, 4) + diag2 = rng.rand(4) + op1 = linalg.LinearOperatorDiag(diag1) + op2 = linalg.LinearOperatorDiag(diag2) + hints = linear_operator_addition._Hints( + is_positive_definite=True, is_non_singular=True) + + self.assertTrue(self._adder.can_add(op1, op2)) + operator = self._adder.add(op1, op2, "my_operator", hints) + self.assertIsInstance(operator, linalg.LinearOperatorDiag) + + with self.test_session(): + self.assertAllClose( + linalg.LinearOperatorDiag(diag1 + diag2).to_dense().eval(), + operator.to_dense().eval()) + self.assertTrue(operator.is_positive_definite) + self.assertTrue(operator.is_non_singular) + self.assertEqual("my_operator", operator.name) + + +class AddAndReturnTriLTest(test.TestCase): + + def setUp(self): + self._adder = linear_operator_addition._AddAndReturnTriL() + + def test_diag_plus_tril(self): + diag = linalg.LinearOperatorDiag([1., 2.]) + tril = linalg.LinearOperatorLowerTriangular([[10., 0.], [30., 0.]]) + hints = linear_operator_addition._Hints( + is_positive_definite=True, is_non_singular=True) + + self.assertTrue(self._adder.can_add(diag, diag)) + self.assertTrue(self._adder.can_add(diag, tril)) + operator = self._adder.add(diag, tril, "my_operator", hints) + self.assertIsInstance(operator, linalg.LinearOperatorLowerTriangular) + + with self.test_session(): + self.assertAllClose([[11., 0.], [30., 2.]], operator.to_dense().eval()) + self.assertTrue(operator.is_positive_definite) + self.assertTrue(operator.is_non_singular) + self.assertEqual("my_operator", operator.name) + + +class AddAndReturnMatrixTest(test.TestCase): + + def setUp(self): + self._adder = linear_operator_addition._AddAndReturnMatrix() + + def test_diag_plus_diag(self): + diag1 = linalg.LinearOperatorDiag([1., 2.]) + diag2 = linalg.LinearOperatorDiag([-1., 3.]) + hints = linear_operator_addition._Hints( + is_positive_definite=False, is_non_singular=False) + + self.assertTrue(self._adder.can_add(diag1, diag2)) + operator = self._adder.add(diag1, diag2, "my_operator", hints) + self.assertIsInstance(operator, linalg.LinearOperatorFullMatrix) + + with self.test_session(): + self.assertAllClose([[0., 0.], [0., 5.]], operator.to_dense().eval()) + self.assertFalse(operator.is_positive_definite) + self.assertFalse(operator.is_non_singular) + self.assertEqual("my_operator", operator.name) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/ops/linalg/linear_operator_addition.py b/tensorflow/python/ops/linalg/linear_operator_addition.py new file mode 100644 index 0000000000..86130a2c07 --- /dev/null +++ b/tensorflow/python/ops/linalg/linear_operator_addition.py @@ -0,0 +1,432 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Add one or more `LinearOperators` efficiently.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import abc + +import six + +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops.linalg import linear_operator +from tensorflow.python.ops.linalg import linear_operator_diag +from tensorflow.python.ops.linalg import linear_operator_full_matrix +from tensorflow.python.ops.linalg import linear_operator_identity +from tensorflow.python.ops.linalg import linear_operator_lower_triangular + +__all__ = [] + + +def add_operators(operators, + operator_name=None, + addition_tiers=None, + name=None): + """Efficiently add one or more linear operators. + + Given operators `[A1, A2,...]`, this `Op` returns a possibly shorter list of + operators `[B1, B2,...]` such that + + ```sum_k Ak.matmul(x) = sum_k Bk.matmul(x).``` + + The operators `Bk` result by adding some of the `Ak`, as allowed by + `addition_tiers`. + + Example of efficient adding of diagonal operators. + + ```python + A1 = LinearOperatorDiag(diag=[1., 1.], name="A1") + A2 = LinearOperatorDiag(diag=[2., 2.], name="A2") + + # Use two tiers, the first contains an Adder that returns Diag. Since both + # A1 and A2 are Diag, they can use this Adder. The second tier will not be + # used. + addition_tiers = [ + [_AddAndReturnDiag()], + [_AddAndReturnMatrix()]] + B_list = add_operators([A1, A2], addition_tiers=addition_tiers) + + len(B_list) + ==> 1 + + B_list[0].__class__.__name__ + ==> 'LinearOperatorDiag' + + B_list[0].to_dense() + ==> [[3., 0.], + [0., 3.]] + + B_list[0].name + ==> 'Add/A1__A2/' + ``` + + Args: + operators: Iterable of `LinearOperator` objects with same `dtype`, domain + and range dimensions, and broadcastable batch shapes. + operator_name: String name for returned `LinearOperator`. Defaults to + concatenation of "Add/A__B/" that indicates the order of addition steps. + addition_tiers: List tiers, like `[tier_0, tier_1, ...]`, where `tier_i` + is a list of `Adder` objects. This function attempts to do all additions + in tier `i` before trying tier `i + 1`. + name: A name for this `Op`. Defaults to `add_operators`. + + Returns: + Subclass of `LinearOperator`. Class and order of addition may change as new + (and better) addition strategies emerge. + + Raises: + ValueError: If `operators` argument is empty. + ValueError: If shapes are incompatible. + """ + # Default setting + if addition_tiers is None: + addition_tiers = _DEFAULT_ADDITION_TIERS + + # Argument checking. + check_ops.assert_proper_iterable(operators) + operators = list(reversed(operators)) + if len(operators) < 1: + raise ValueError( + "Argument 'operators' must contain at least one operator. " + "Found: %s" % operators) + if not all( + isinstance(op, linear_operator.LinearOperator) for op in operators): + raise TypeError( + "Argument 'operators' must contain only LinearOperator instances. " + "Found: %s" % operators) + _static_check_for_same_dimensions(operators) + _static_check_for_broadcastable_batch_shape(operators) + + graph_parents = [] + for operator in operators: + graph_parents.extend(operator.graph_parents) + + with ops.name_scope(name or "add_operators", values=graph_parents): + + # Additions done in one of the tiers. Try tier 0, 1,... + ops_to_try_at_next_tier = list(operators) + for tier in addition_tiers: + ops_to_try_at_this_tier = ops_to_try_at_next_tier + ops_to_try_at_next_tier = [] + while ops_to_try_at_this_tier: + op1 = ops_to_try_at_this_tier.pop() + op2, adder = _pop_a_match_at_tier(op1, ops_to_try_at_this_tier, tier) + if op2 is not None: + # Will try to add the result of this again at this same tier. + new_operator = adder.add(op1, op2, operator_name) + ops_to_try_at_this_tier.append(new_operator) + else: + ops_to_try_at_next_tier.append(op1) + + return ops_to_try_at_next_tier + + +def _pop_a_match_at_tier(op1, operator_list, tier): + # Search from the back of list to the front in order to create nice default + # order of operations. + for i in range(1, len(operator_list) + 1): + op2 = operator_list[-i] + for adder in tier: + if adder.can_add(op1, op2): + return operator_list.pop(-i), adder + return None, None + + +def _infer_hints_allowing_override(op1, op2, hints): + """Infer hints from op1 and op2. hints argument is an override. + + Args: + op1: LinearOperator + op2: LinearOperator + hints: _Hints object holding "is_X" boolean hints to use for returned + operator. + If some hint is None, try to set using op1 and op2. If the + hint is provided, ignore op1 and op2 hints. This allows an override + of previous hints, but does not allow forbidden hints (e.g. you still + cannot say a real diagonal operator is not self-adjoint. + + Returns: + _Hints object. + """ + hints = hints or _Hints() + # If A, B are self-adjoint, then so is A + B. + if hints.is_self_adjoint is None: + is_self_adjoint = op1.is_self_adjoint and op2.is_self_adjoint + else: + is_self_adjoint = hints.is_self_adjoint + + # If A, B are positive definite, then so is A + B. + if hints.is_positive_definite is None: + is_positive_definite = op1.is_positive_definite and op2.is_positive_definite + else: + is_positive_definite = hints.is_positive_definite + + # A positive definite operator is always non-singular. + if is_positive_definite and hints.is_positive_definite is None: + is_non_singular = True + else: + is_non_singular = hints.is_non_singular + + return _Hints( + is_non_singular=is_non_singular, + is_self_adjoint=is_self_adjoint, + is_positive_definite=is_positive_definite) + + +def _static_check_for_same_dimensions(operators): + """ValueError if operators determined to have different dimensions.""" + if len(operators) < 2: + return + + domain_dimensions = [(op.name, op.domain_dimension.value) for op in operators + if op.domain_dimension.value is not None] + if len(set(value for name, value in domain_dimensions)) > 1: + raise ValueError("Operators must have the same domain dimension. Found: %s" + % domain_dimensions) + + range_dimensions = [(op.name, op.range_dimension.value) for op in operators + if op.range_dimension.value is not None] + if len(set(value for name, value in range_dimensions)) > 1: + raise ValueError("Operators must have the same range dimension. Found: %s" % + range_dimensions) + + +def _static_check_for_broadcastable_batch_shape(operators): + """ValueError if operators determined to have non-broadcastable shapes.""" + if len(operators) < 2: + return + + # This will fail if they cannot be broadcast together. + batch_shape = operators[0].batch_shape + for op in operators[1:]: + batch_shape = array_ops.broadcast_static_shape(batch_shape, op.batch_shape) + + +class _Hints(object): + """Holds 'is_X' flags that every LinearOperator is initialized with.""" + + def __init__(self, + is_non_singular=None, + is_positive_definite=None, + is_self_adjoint=None): + self.is_non_singular = is_non_singular + self.is_positive_definite = is_positive_definite + self.is_self_adjoint = is_self_adjoint + + +################################################################################ +# Classes to add two linear operators. +################################################################################ + + +@six.add_metaclass(abc.ABCMeta) +class _Adder(object): + """Abstract base class to add two operators. + + Each `Adder` acts independently, adding everything it can, paying no attention + as to whether another `Adder` could have done the addition more efficiently. + """ + + @property + def name(self): + return self.__class__.__name__ + + @abc.abstractmethod + def can_add(self, op1, op2): + """Returns `True` if this `Adder` can add `op1` and `op2`. Else `False`.""" + pass + + @abc.abstractmethod + def _add(self, op1, op2, operator_name, hints): + # Derived classes can assume op1 and op2 have been validated, e.g. they have + # the same dtype, and their domain/range dimensions match. + pass + + def add(self, op1, op2, operator_name, hints=None): + """Return new `LinearOperator` acting like `op1 + op2`. + + Args: + op1: `LinearOperator` + op2: `LinearOperator`, with `shape` and `dtype` such that adding to + `op1` is allowed. + operator_name: `String` name to give to returned `LinearOperator` + hints: `_Hints` object. Returned `LinearOperator` will be created with + these hints. + + Returns: + `LinearOperator` + """ + updated_hints = _infer_hints_allowing_override(op1, op2, hints) + + if operator_name is None: + operator_name = "Add/" + op1.name + "__" + op2.name + "/" + + values = op1.graph_parents + op2.graph_parents + scope_name = self.name + if scope_name.startswith("_"): + scope_name = scope_name[1:] + with ops.name_scope(scope_name, values=values): + return self._add(op1, op2, operator_name, updated_hints) + + +class _AddAndReturnScaledIdentity(_Adder): + """Handles additions resulting in an Identity family member. + + The Identity (`LinearOperatorScaledIdentity`, `LinearOperatorIdentity`) family + is closed under addition. This `Adder` respects that, and returns an Identity + """ + + def can_add(self, op1, op2): + types = {_type(op1), _type(op2)} + return not types.difference(_IDENTITY_FAMILY) + + def _add(self, op1, op2, operator_name, hints): + # Will build a LinearOperatorScaledIdentity. + + if _type(op1) == _SCALED_IDENTITY: + multiplier_1 = op1.multiplier + else: + multiplier_1 = array_ops.ones(op1.batch_shape_tensor(), dtype=op1.dtype) + + if _type(op2) == _SCALED_IDENTITY: + multiplier_2 = op2.multiplier + else: + multiplier_2 = array_ops.ones(op2.batch_shape_tensor(), dtype=op2.dtype) + + return linear_operator_identity.LinearOperatorScaledIdentity( + num_rows=op1.range_dimension_tensor(), + multiplier=multiplier_1 + multiplier_2, + is_non_singular=hints.is_non_singular, + is_self_adjoint=hints.is_self_adjoint, + is_positive_definite=hints.is_positive_definite, + name=operator_name) + + +class _AddAndReturnDiag(_Adder): + """Handles additions resulting in a Diag operator.""" + + def can_add(self, op1, op2): + types = {_type(op1), _type(op2)} + return not types.difference(_DIAG_LIKE) + + def _add(self, op1, op2, operator_name, hints): + return linear_operator_diag.LinearOperatorDiag( + diag=op1.diag_part() + op2.diag_part(), + is_non_singular=hints.is_non_singular, + is_self_adjoint=hints.is_self_adjoint, + is_positive_definite=hints.is_positive_definite, + name=operator_name) + + +class _AddAndReturnTriL(_Adder): + """Handles additions resulting in a TriL operator.""" + + def can_add(self, op1, op2): + types = {_type(op1), _type(op2)} + return not types.difference(_DIAG_LIKE.union({_TRIL})) + + def _add(self, op1, op2, operator_name, hints): + if _type(op1) in _EFFICIENT_ADD_TO_TENSOR: + op_add_to_tensor, op_other = op1, op2 + else: + op_add_to_tensor, op_other = op2, op1 + + return linear_operator_lower_triangular.LinearOperatorLowerTriangular( + tril=op_add_to_tensor.add_to_tensor(op_other.to_dense()), + is_non_singular=hints.is_non_singular, + is_self_adjoint=hints.is_self_adjoint, + is_positive_definite=hints.is_positive_definite, + name=operator_name) + + +class _AddAndReturnMatrix(_Adder): + """"Handles additions resulting in a `LinearOperatorFullMatrix`.""" + + def can_add(self, op1, op2): # pylint: disable=unused-argument + return isinstance(op1, linear_operator.LinearOperator) and isinstance( + op2, linear_operator.LinearOperator) + + def _add(self, op1, op2, operator_name, hints): + if _type(op1) in _EFFICIENT_ADD_TO_TENSOR: + op_add_to_tensor, op_other = op1, op2 + else: + op_add_to_tensor, op_other = op2, op1 + return linear_operator_full_matrix.LinearOperatorFullMatrix( + matrix=op_add_to_tensor.add_to_tensor(op_other.to_dense()), + is_non_singular=hints.is_non_singular, + is_self_adjoint=hints.is_self_adjoint, + is_positive_definite=hints.is_positive_definite, + name=operator_name) + + +################################################################################ +# Constants designating types of LinearOperators +################################################################################ + +# Type name constants for LinearOperator classes. +_IDENTITY = "identity" +_SCALED_IDENTITY = "scaled_identity" +_DIAG = "diag" +_TRIL = "tril" +_MATRIX = "matrix" + +# Groups of operators. +_DIAG_LIKE = {_DIAG, _IDENTITY, _SCALED_IDENTITY} +_IDENTITY_FAMILY = {_IDENTITY, _SCALED_IDENTITY} +# operators with an efficient .add_to_tensor() method. +_EFFICIENT_ADD_TO_TENSOR = _DIAG_LIKE + + +def _type(operator): + """Returns the type name constant (e.g. _TRIL) for operator.""" + if isinstance(operator, linear_operator_diag.LinearOperatorDiag): + return _DIAG + if isinstance(operator, + linear_operator_lower_triangular.LinearOperatorLowerTriangular): + return _TRIL + if isinstance(operator, linear_operator_full_matrix.LinearOperatorFullMatrix): + return _MATRIX + if isinstance(operator, linear_operator_identity.LinearOperatorIdentity): + return _IDENTITY + if isinstance(operator, + linear_operator_identity.LinearOperatorScaledIdentity): + return _SCALED_IDENTITY + raise TypeError("Operator type unknown: %s" % operator) + + +################################################################################ +# Addition tiers: +# We attempt to use Adders in tier K before K+1. +# +# Organize tiers to +# (i) reduce O(..) complexity of forming final operator, and +# (ii) produce the "most efficient" final operator. +# Dev notes: +# * Results of addition at tier K will be added at tier K or higher. +# * Tiers may change, and we warn the user that it may change. +################################################################################ + +# Note that the final tier, _AddAndReturnMatrix, will convert everything to a +# dense matrix. So it is sometimes very inefficient. +_DEFAULT_ADDITION_TIERS = [ + [_AddAndReturnScaledIdentity()], + [_AddAndReturnDiag()], + [_AddAndReturnTriL()], + [_AddAndReturnMatrix()], +] |