author    Ian Langmore <langmore@google.com>  2018-09-12 09:45:53 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-09-12 09:49:50 -0700
commit    26509bf4e202c09da4f0b00d43ebddf87368a0f2 (patch)
tree      200e9080e06200dc5512239d25fd2f2c7b18a899
parent    1c4fceab7dc09cab18c0def098320d6c52d2e514 (diff)
Add linear_operator_addition to tensorflow/python/. A subsequent CL
will remove this from contrib. linear_operator_addition is hidden from
the public API.

PiperOrigin-RevId: 212655087
-rw-r--r--  tensorflow/python/kernel_tests/linalg/BUILD                              |  16
-rw-r--r--  tensorflow/python/kernel_tests/linalg/linear_operator_addition_test.py   | 412
-rw-r--r--  tensorflow/python/ops/linalg/linear_operator_addition.py                 | 432
3 files changed, 860 insertions(+), 0 deletions(-)
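For orientation, here is a minimal usage sketch (illustration only, not part of the patch), based on the add_operators docstring and tests added below; the import paths follow the in-tree locations this CL introduces:

```python
# Minimal sketch of the new internal helper; mirrors the docstring example
# in linear_operator_addition.py below.  The module is deliberately not
# exported in the public API, so it is imported by its module path.
from tensorflow.python.ops.linalg import linalg as linalg_lib
from tensorflow.python.ops.linalg import linear_operator_addition

op_a = linalg_lib.LinearOperatorDiag([1., 1.], name="A")
op_b = linalg_lib.LinearOperatorDiag([2., 2.], name="B")

# With the default addition tiers, two Diag operators are combined into a
# single LinearOperatorDiag whose diagonal is the elementwise sum, i.e. the
# dense form is [[3., 0.], [0., 3.]].
op_sum = linear_operator_addition.add_operators([op_a, op_b])
assert len(op_sum) == 1
```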
diff --git a/tensorflow/python/kernel_tests/linalg/BUILD b/tensorflow/python/kernel_tests/linalg/BUILD
index f4ec3e3996..be2e31cb5a 100644
--- a/tensorflow/python/kernel_tests/linalg/BUILD
+++ b/tensorflow/python/kernel_tests/linalg/BUILD
@@ -25,6 +25,22 @@ cuda_py_test(
)
cuda_py_test(
+ name = "linear_operator_addition_test",
+ size = "small",
+ srcs = ["linear_operator_addition_test.py"],
+ additional_deps = [
+ "//tensorflow/python/ops/linalg",
+ "//third_party/py/numpy",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+cuda_py_test(
name = "linear_operator_block_diag_test",
size = "medium",
srcs = ["linear_operator_block_diag_test.py"],
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_addition_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_addition_test.py
new file mode 100644
index 0000000000..7c79fedf65
--- /dev/null
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_addition_test.py
@@ -0,0 +1,412 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import random_seed
+from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops.linalg import linalg as linalg_lib
+from tensorflow.python.ops.linalg import linear_operator_addition
+from tensorflow.python.platform import test
+
+linalg = linalg_lib
+random_seed.set_random_seed(23)
+rng = np.random.RandomState(0)
+
+add_operators = linear_operator_addition.add_operators
+
+
+# pylint: disable=unused-argument
+class _BadAdder(linear_operator_addition._Adder):
+ """Adder that will fail if used."""
+
+ def can_add(self, op1, op2):
+ raise AssertionError("BadAdder.can_add called!")
+
+ def _add(self, op1, op2, operator_name, hints):
+ raise AssertionError("This line should not be reached")
+
+
+# pylint: enable=unused-argument
+
+
+class LinearOperatorAdditionCorrectnessTest(test.TestCase):
+ """Tests correctness of addition with combinations of a few Adders.
+
+ Tests here are done with the _DEFAULT_ADDITION_TIERS, which means
+ add_operators should reduce the full list of operators to one single operator.
+
+ This shows that we are able to correctly combine adders using the tiered
+ system. All Adders should be tested separately, and there is no need to test
+ every Adder within this class.
+ """
+
+ def test_one_operator_is_returned_unchanged(self):
+ op_a = linalg.LinearOperatorDiag([1., 1.])
+ op_sum = add_operators([op_a])
+ self.assertEqual(1, len(op_sum))
+ self.assertIs(op_sum[0], op_a)
+
+ def test_at_least_one_operators_required(self):
+ with self.assertRaisesRegexp(ValueError, "must contain at least one"):
+ add_operators([])
+
+ def test_attempting_to_add_numbers_raises(self):
+ with self.assertRaisesRegexp(TypeError, "contain only LinearOperator"):
+ add_operators([1, 2])
+
+ def test_two_diag_operators(self):
+ op_a = linalg.LinearOperatorDiag(
+ [1., 1.], is_positive_definite=True, name="A")
+ op_b = linalg.LinearOperatorDiag(
+ [2., 2.], is_positive_definite=True, name="B")
+ with self.test_session():
+ op_sum = add_operators([op_a, op_b])
+ self.assertEqual(1, len(op_sum))
+ op = op_sum[0]
+ self.assertIsInstance(op, linalg_lib.LinearOperatorDiag)
+ self.assertAllClose([[3., 0.], [0., 3.]], op.to_dense().eval())
+ # Adding positive definite operators produces positive def.
+ self.assertTrue(op.is_positive_definite)
+ # Real diagonal ==> self-adjoint.
+ self.assertTrue(op.is_self_adjoint)
+ # Positive definite ==> non-singular
+ self.assertTrue(op.is_non_singular)
+ # Enforce particular name for this simple case
+ self.assertEqual("Add/B__A/", op.name)
+
+ def test_three_diag_operators(self):
+ op1 = linalg.LinearOperatorDiag(
+ [1., 1.], is_positive_definite=True, name="op1")
+ op2 = linalg.LinearOperatorDiag(
+ [2., 2.], is_positive_definite=True, name="op2")
+ op3 = linalg.LinearOperatorDiag(
+ [3., 3.], is_positive_definite=True, name="op3")
+ with self.test_session():
+ op_sum = add_operators([op1, op2, op3])
+ self.assertEqual(1, len(op_sum))
+ op = op_sum[0]
+ self.assertTrue(isinstance(op, linalg_lib.LinearOperatorDiag))
+ self.assertAllClose([[6., 0.], [0., 6.]], op.to_dense().eval())
+ # Adding positive definite operators produces positive def.
+ self.assertTrue(op.is_positive_definite)
+ # Real diagonal ==> self-adjoint.
+ self.assertTrue(op.is_self_adjoint)
+ # Positive definite ==> non-singular
+ self.assertTrue(op.is_non_singular)
+
+ def test_diag_tril_diag(self):
+ op1 = linalg.LinearOperatorDiag(
+ [1., 1.], is_non_singular=True, name="diag_a")
+ op2 = linalg.LinearOperatorLowerTriangular(
+ [[2., 0.], [0., 2.]],
+ is_self_adjoint=True,
+ is_non_singular=True,
+ name="tril")
+ op3 = linalg.LinearOperatorDiag(
+ [3., 3.], is_non_singular=True, name="diag_b")
+ with self.test_session():
+ op_sum = add_operators([op1, op2, op3])
+ self.assertEqual(1, len(op_sum))
+ op = op_sum[0]
+ self.assertIsInstance(op, linalg_lib.LinearOperatorLowerTriangular)
+ self.assertAllClose([[6., 0.], [0., 6.]], op.to_dense().eval())
+
+ # The diag operators will be self-adjoint (because real and diagonal).
+ # The TriL operator has the self-adjoint hint set.
+ self.assertTrue(op.is_self_adjoint)
+
+ # Even though op1/2/3 are non-singular, this does not imply op is.
+ # Since no custom hint was provided, we default to None (unknown).
+ self.assertEqual(None, op.is_non_singular)
+
+ def test_matrix_diag_tril_diag_uses_custom_name(self):
+ op0 = linalg.LinearOperatorFullMatrix(
+ [[-1., -1.], [-1., -1.]], name="matrix")
+ op1 = linalg.LinearOperatorDiag([1., 1.], name="diag_a")
+ op2 = linalg.LinearOperatorLowerTriangular(
+ [[2., 0.], [1.5, 2.]], name="tril")
+ op3 = linalg.LinearOperatorDiag([3., 3.], name="diag_b")
+ with self.test_session():
+ op_sum = add_operators([op0, op1, op2, op3], operator_name="my_operator")
+ self.assertEqual(1, len(op_sum))
+ op = op_sum[0]
+ self.assertIsInstance(op, linalg_lib.LinearOperatorFullMatrix)
+ self.assertAllClose([[5., -1.], [0.5, 5.]], op.to_dense().eval())
+ self.assertEqual("my_operator", op.name)
+
+ def test_incompatible_domain_dimensions_raises(self):
+ op1 = linalg.LinearOperatorFullMatrix(rng.rand(2, 3))
+ op2 = linalg.LinearOperatorDiag(rng.rand(2, 4))
+ with self.assertRaisesRegexp(ValueError, "must.*same domain dimension"):
+ add_operators([op1, op2])
+
+ def test_incompatible_range_dimensions_raises(self):
+ op1 = linalg.LinearOperatorFullMatrix(rng.rand(2, 3))
+ op2 = linalg.LinearOperatorDiag(rng.rand(3, 3))
+ with self.assertRaisesRegexp(ValueError, "must.*same range dimension"):
+ add_operators([op1, op2])
+
+ def test_non_broadcastable_batch_shape_raises(self):
+ op1 = linalg.LinearOperatorFullMatrix(rng.rand(2, 3, 3))
+ op2 = linalg.LinearOperatorDiag(rng.rand(4, 3, 3))
+ with self.assertRaisesRegexp(ValueError, "Incompatible shapes"):
+ add_operators([op1, op2])
+
+
+class LinearOperatorOrderOfAdditionTest(test.TestCase):
+ """Test that the order of addition is done as specified by tiers."""
+
+ def test_tier_0_additions_done_in_tier_0(self):
+ diag1 = linalg.LinearOperatorDiag([1.])
+ diag2 = linalg.LinearOperatorDiag([1.])
+ diag3 = linalg.LinearOperatorDiag([1.])
+ addition_tiers = [
+ [linear_operator_addition._AddAndReturnDiag()],
+ [_BadAdder()],
+ ]
+ # Should not raise since all were added in tier 0, and tier 1 (with the
+ # _BadAdder) was never reached.
+ op_sum = add_operators([diag1, diag2, diag3], addition_tiers=addition_tiers)
+ self.assertEqual(1, len(op_sum))
+ self.assertIsInstance(op_sum[0], linalg.LinearOperatorDiag)
+
+ def test_tier_1_additions_done_by_tier_1(self):
+ diag1 = linalg.LinearOperatorDiag([1.])
+ diag2 = linalg.LinearOperatorDiag([1.])
+ tril = linalg.LinearOperatorLowerTriangular([[1.]])
+ addition_tiers = [
+ [linear_operator_addition._AddAndReturnDiag()],
+ [linear_operator_addition._AddAndReturnTriL()],
+ [_BadAdder()],
+ ]
+ # Should not raise since all were added by tier 1, and tier 2 (with the
+ # _BadAdder) was never reached.
+ op_sum = add_operators([diag1, diag2, tril], addition_tiers=addition_tiers)
+ self.assertEqual(1, len(op_sum))
+ self.assertIsInstance(op_sum[0], linalg.LinearOperatorLowerTriangular)
+
+ def test_tier_1_additions_done_by_tier_1_with_order_flipped(self):
+ diag1 = linalg.LinearOperatorDiag([1.])
+ diag2 = linalg.LinearOperatorDiag([1.])
+ tril = linalg.LinearOperatorLowerTriangular([[1.]])
+ addition_tiers = [
+ [linear_operator_addition._AddAndReturnTriL()],
+ [linear_operator_addition._AddAndReturnDiag()],
+ [_BadAdder()],
+ ]
+ # Tier 0 could convert to TriL, and this converted everything to TriL,
+ # including the Diags.
+ # Tier 1 was never used.
+ # Tier 2 was never used (therefore, _BadAdder didn't raise).
+ op_sum = add_operators([diag1, diag2, tril], addition_tiers=addition_tiers)
+ self.assertEqual(1, len(op_sum))
+ self.assertIsInstance(op_sum[0], linalg.LinearOperatorLowerTriangular)
+
+ def test_cannot_add_everything_so_return_more_than_one_operator(self):
+ diag1 = linalg.LinearOperatorDiag([1.])
+ diag2 = linalg.LinearOperatorDiag([2.])
+ tril5 = linalg.LinearOperatorLowerTriangular([[5.]])
+ addition_tiers = [
+ [linear_operator_addition._AddAndReturnDiag()],
+ ]
+ # Tier 0 (the only tier) can only convert to Diag, so it combines the two
+ # diags, but the TriL is unchanged.
+ # Result should contain two operators, one Diag, one TriL.
+ op_sum = add_operators([diag1, diag2, tril5], addition_tiers=addition_tiers)
+ self.assertEqual(2, len(op_sum))
+ found_diag = False
+ found_tril = False
+ with self.test_session():
+ for op in op_sum:
+ if isinstance(op, linalg.LinearOperatorDiag):
+ found_diag = True
+ self.assertAllClose([[3.]], op.to_dense().eval())
+ if isinstance(op, linalg.LinearOperatorLowerTriangular):
+ found_tril = True
+ self.assertAllClose([[5.]], op.to_dense().eval())
+ self.assertTrue(found_diag and found_tril)
+
+ def test_intermediate_tier_is_not_skipped(self):
+ diag1 = linalg.LinearOperatorDiag([1.])
+ diag2 = linalg.LinearOperatorDiag([1.])
+ tril = linalg.LinearOperatorLowerTriangular([[1.]])
+ addition_tiers = [
+ [linear_operator_addition._AddAndReturnDiag()],
+ [_BadAdder()],
+ [linear_operator_addition._AddAndReturnTriL()],
+ ]
+ # tril cannot be added in tier 0, and the intermediate tier 1 with the
+ # BadAdder will catch it and raise.
+ with self.assertRaisesRegexp(AssertionError, "BadAdder.can_add called"):
+ add_operators([diag1, diag2, tril], addition_tiers=addition_tiers)
+
+
+class AddAndReturnScaledIdentityTest(test.TestCase):
+
+ def setUp(self):
+ self._adder = linear_operator_addition._AddAndReturnScaledIdentity()
+
+ def test_identity_plus_identity(self):
+ id1 = linalg.LinearOperatorIdentity(num_rows=2)
+ id2 = linalg.LinearOperatorIdentity(num_rows=2, batch_shape=[3])
+ hints = linear_operator_addition._Hints(
+ is_positive_definite=True, is_non_singular=True)
+
+ self.assertTrue(self._adder.can_add(id1, id2))
+ operator = self._adder.add(id1, id2, "my_operator", hints)
+ self.assertIsInstance(operator, linalg.LinearOperatorScaledIdentity)
+
+ with self.test_session():
+ self.assertAllClose(2 *
+ linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(),
+ operator.to_dense().eval())
+ self.assertTrue(operator.is_positive_definite)
+ self.assertTrue(operator.is_non_singular)
+ self.assertEqual("my_operator", operator.name)
+
+ def test_identity_plus_scaled_identity(self):
+ id1 = linalg.LinearOperatorIdentity(num_rows=2, batch_shape=[3])
+ id2 = linalg.LinearOperatorScaledIdentity(num_rows=2, multiplier=2.2)
+ hints = linear_operator_addition._Hints(
+ is_positive_definite=True, is_non_singular=True)
+
+ self.assertTrue(self._adder.can_add(id1, id2))
+ operator = self._adder.add(id1, id2, "my_operator", hints)
+ self.assertIsInstance(operator, linalg.LinearOperatorScaledIdentity)
+
+ with self.test_session():
+ self.assertAllClose(3.2 *
+ linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(),
+ operator.to_dense().eval())
+ self.assertTrue(operator.is_positive_definite)
+ self.assertTrue(operator.is_non_singular)
+ self.assertEqual("my_operator", operator.name)
+
+ def test_scaled_identity_plus_scaled_identity(self):
+ id1 = linalg.LinearOperatorScaledIdentity(
+ num_rows=2, multiplier=[2.2, 2.2, 2.2])
+ id2 = linalg.LinearOperatorScaledIdentity(num_rows=2, multiplier=-1.0)
+ hints = linear_operator_addition._Hints(
+ is_positive_definite=True, is_non_singular=True)
+
+ self.assertTrue(self._adder.can_add(id1, id2))
+ operator = self._adder.add(id1, id2, "my_operator", hints)
+ self.assertIsInstance(operator, linalg.LinearOperatorScaledIdentity)
+
+ with self.test_session():
+ self.assertAllClose(1.2 *
+ linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(),
+ operator.to_dense().eval())
+ self.assertTrue(operator.is_positive_definite)
+ self.assertTrue(operator.is_non_singular)
+ self.assertEqual("my_operator", operator.name)
+
+
+class AddAndReturnDiagTest(test.TestCase):
+
+ def setUp(self):
+ self._adder = linear_operator_addition._AddAndReturnDiag()
+
+ def test_identity_plus_identity_returns_diag(self):
+ id1 = linalg.LinearOperatorIdentity(num_rows=2)
+ id2 = linalg.LinearOperatorIdentity(num_rows=2, batch_shape=[3])
+ hints = linear_operator_addition._Hints(
+ is_positive_definite=True, is_non_singular=True)
+
+ self.assertTrue(self._adder.can_add(id1, id2))
+ operator = self._adder.add(id1, id2, "my_operator", hints)
+ self.assertIsInstance(operator, linalg.LinearOperatorDiag)
+
+ with self.test_session():
+ self.assertAllClose(2 *
+ linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(),
+ operator.to_dense().eval())
+ self.assertTrue(operator.is_positive_definite)
+ self.assertTrue(operator.is_non_singular)
+ self.assertEqual("my_operator", operator.name)
+
+ def test_diag_plus_diag(self):
+ diag1 = rng.rand(2, 3, 4)
+ diag2 = rng.rand(4)
+ op1 = linalg.LinearOperatorDiag(diag1)
+ op2 = linalg.LinearOperatorDiag(diag2)
+ hints = linear_operator_addition._Hints(
+ is_positive_definite=True, is_non_singular=True)
+
+ self.assertTrue(self._adder.can_add(op1, op2))
+ operator = self._adder.add(op1, op2, "my_operator", hints)
+ self.assertIsInstance(operator, linalg.LinearOperatorDiag)
+
+ with self.test_session():
+ self.assertAllClose(
+ linalg.LinearOperatorDiag(diag1 + diag2).to_dense().eval(),
+ operator.to_dense().eval())
+ self.assertTrue(operator.is_positive_definite)
+ self.assertTrue(operator.is_non_singular)
+ self.assertEqual("my_operator", operator.name)
+
+
+class AddAndReturnTriLTest(test.TestCase):
+
+ def setUp(self):
+ self._adder = linear_operator_addition._AddAndReturnTriL()
+
+ def test_diag_plus_tril(self):
+ diag = linalg.LinearOperatorDiag([1., 2.])
+ tril = linalg.LinearOperatorLowerTriangular([[10., 0.], [30., 0.]])
+ hints = linear_operator_addition._Hints(
+ is_positive_definite=True, is_non_singular=True)
+
+ self.assertTrue(self._adder.can_add(diag, diag))
+ self.assertTrue(self._adder.can_add(diag, tril))
+ operator = self._adder.add(diag, tril, "my_operator", hints)
+ self.assertIsInstance(operator, linalg.LinearOperatorLowerTriangular)
+
+ with self.test_session():
+ self.assertAllClose([[11., 0.], [30., 2.]], operator.to_dense().eval())
+ self.assertTrue(operator.is_positive_definite)
+ self.assertTrue(operator.is_non_singular)
+ self.assertEqual("my_operator", operator.name)
+
+
+class AddAndReturnMatrixTest(test.TestCase):
+
+ def setUp(self):
+ self._adder = linear_operator_addition._AddAndReturnMatrix()
+
+ def test_diag_plus_diag(self):
+ diag1 = linalg.LinearOperatorDiag([1., 2.])
+ diag2 = linalg.LinearOperatorDiag([-1., 3.])
+ hints = linear_operator_addition._Hints(
+ is_positive_definite=False, is_non_singular=False)
+
+ self.assertTrue(self._adder.can_add(diag1, diag2))
+ operator = self._adder.add(diag1, diag2, "my_operator", hints)
+ self.assertIsInstance(operator, linalg.LinearOperatorFullMatrix)
+
+ with self.test_session():
+ self.assertAllClose([[0., 0.], [0., 5.]], operator.to_dense().eval())
+ self.assertFalse(operator.is_positive_definite)
+ self.assertFalse(operator.is_non_singular)
+ self.assertEqual("my_operator", operator.name)
+
+
+if __name__ == "__main__":
+ test.main()
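As a condensed illustration of the tier behavior the tests above exercise (mirroring test_cannot_add_everything_so_return_more_than_one_operator; not part of the patch):

```python
# If the supplied tiers only know how to produce Diag operators, the two
# Diags are merged but the TriL cannot be absorbed, so two operators come
# back instead of one.
from tensorflow.python.ops.linalg import linalg as linalg_lib
from tensorflow.python.ops.linalg import linear_operator_addition

diag1 = linalg_lib.LinearOperatorDiag([1.])
diag2 = linalg_lib.LinearOperatorDiag([2.])
tril = linalg_lib.LinearOperatorLowerTriangular([[5.]])

tiers = [[linear_operator_addition._AddAndReturnDiag()]]
result = linear_operator_addition.add_operators(
    [diag1, diag2, tril], addition_tiers=tiers)
assert len(result) == 2
```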
diff --git a/tensorflow/python/ops/linalg/linear_operator_addition.py b/tensorflow/python/ops/linalg/linear_operator_addition.py
new file mode 100644
index 0000000000..86130a2c07
--- /dev/null
+++ b/tensorflow/python/ops/linalg/linear_operator_addition.py
@@ -0,0 +1,432 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Add one or more `LinearOperators` efficiently."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+
+import six
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops.linalg import linear_operator
+from tensorflow.python.ops.linalg import linear_operator_diag
+from tensorflow.python.ops.linalg import linear_operator_full_matrix
+from tensorflow.python.ops.linalg import linear_operator_identity
+from tensorflow.python.ops.linalg import linear_operator_lower_triangular
+
+__all__ = []
+
+
+def add_operators(operators,
+ operator_name=None,
+ addition_tiers=None,
+ name=None):
+ """Efficiently add one or more linear operators.
+
+ Given operators `[A1, A2,...]`, this `Op` returns a possibly shorter list of
+ operators `[B1, B2,...]` such that
+
+ ```sum_k Ak.matmul(x) = sum_k Bk.matmul(x).```
+
+ The operators `Bk` result by adding some of the `Ak`, as allowed by
+ `addition_tiers`.
+
+ Example of efficient adding of diagonal operators.
+
+ ```python
+ A1 = LinearOperatorDiag(diag=[1., 1.], name="A1")
+ A2 = LinearOperatorDiag(diag=[2., 2.], name="A2")
+
+ # Use two tiers, the first contains an Adder that returns Diag. Since both
+ # A1 and A2 are Diag, they can use this Adder. The second tier will not be
+ # used.
+ addition_tiers = [
+ [_AddAndReturnDiag()],
+ [_AddAndReturnMatrix()]]
+ B_list = add_operators([A1, A2], addition_tiers=addition_tiers)
+
+ len(B_list)
+ ==> 1
+
+ B_list[0].__class__.__name__
+ ==> 'LinearOperatorDiag'
+
+ B_list[0].to_dense()
+ ==> [[3., 0.],
+ [0., 3.]]
+
+ B_list[0].name
+ ==> 'Add/A1__A2/'
+ ```
+
+ Args:
+ operators: Iterable of `LinearOperator` objects with same `dtype`, domain
+ and range dimensions, and broadcastable batch shapes.
+ operator_name: String name for the returned `LinearOperator`. Defaults to
+ a name of the form "Add/A__B/", which indicates the order of addition steps.
+ addition_tiers: List tiers, like `[tier_0, tier_1, ...]`, where `tier_i`
+ is a list of `Adder` objects. This function attempts to do all additions
+ in tier `i` before trying tier `i + 1`.
+ name: A name for this `Op`. Defaults to `add_operators`.
+
+ Returns:
+ A list of `LinearOperator` subclasses. The returned classes, and the order of
+ addition, may change as new (and better) addition strategies emerge.
+
+ Raises:
+ ValueError: If `operators` argument is empty.
+ ValueError: If shapes are incompatible.
+ """
+ # Default setting
+ if addition_tiers is None:
+ addition_tiers = _DEFAULT_ADDITION_TIERS
+
+ # Argument checking.
+ check_ops.assert_proper_iterable(operators)
+ operators = list(reversed(operators))
+ if len(operators) < 1:
+ raise ValueError(
+ "Argument 'operators' must contain at least one operator. "
+ "Found: %s" % operators)
+ if not all(
+ isinstance(op, linear_operator.LinearOperator) for op in operators):
+ raise TypeError(
+ "Argument 'operators' must contain only LinearOperator instances. "
+ "Found: %s" % operators)
+ _static_check_for_same_dimensions(operators)
+ _static_check_for_broadcastable_batch_shape(operators)
+
+ graph_parents = []
+ for operator in operators:
+ graph_parents.extend(operator.graph_parents)
+
+ with ops.name_scope(name or "add_operators", values=graph_parents):
+
+ # Additions done in one of the tiers. Try tier 0, 1,...
+ ops_to_try_at_next_tier = list(operators)
+ for tier in addition_tiers:
+ ops_to_try_at_this_tier = ops_to_try_at_next_tier
+ ops_to_try_at_next_tier = []
+ while ops_to_try_at_this_tier:
+ op1 = ops_to_try_at_this_tier.pop()
+ op2, adder = _pop_a_match_at_tier(op1, ops_to_try_at_this_tier, tier)
+ if op2 is not None:
+ # Will try to add the result of this again at this same tier.
+ new_operator = adder.add(op1, op2, operator_name)
+ ops_to_try_at_this_tier.append(new_operator)
+ else:
+ ops_to_try_at_next_tier.append(op1)
+
+ return ops_to_try_at_next_tier
+
+
+def _pop_a_match_at_tier(op1, operator_list, tier):
+ # Search from the back of the list to the front in order to create a nice
+ # default order of operations.
+ for i in range(1, len(operator_list) + 1):
+ op2 = operator_list[-i]
+ for adder in tier:
+ if adder.can_add(op1, op2):
+ return operator_list.pop(-i), adder
+ return None, None
+
+
+def _infer_hints_allowing_override(op1, op2, hints):
+ """Infer hints from op1 and op2. hints argument is an override.
+
+ Args:
+ op1: LinearOperator
+ op2: LinearOperator
+ hints: _Hints object holding "is_X" boolean hints to use for returned
+ operator.
+ If some hint is None, try to set using op1 and op2. If the
+ hint is provided, ignore op1 and op2 hints. This allows an override
+ of previous hints, but does not allow forbidden hints (e.g. you still
+ cannot say a real diagonal operator is not self-adjoint).
+
+ Returns:
+ _Hints object.
+ """
+ hints = hints or _Hints()
+ # If A, B are self-adjoint, then so is A + B.
+ if hints.is_self_adjoint is None:
+ is_self_adjoint = op1.is_self_adjoint and op2.is_self_adjoint
+ else:
+ is_self_adjoint = hints.is_self_adjoint
+
+ # If A, B are positive definite, then so is A + B.
+ if hints.is_positive_definite is None:
+ is_positive_definite = op1.is_positive_definite and op2.is_positive_definite
+ else:
+ is_positive_definite = hints.is_positive_definite
+
+ # A positive definite operator is always non-singular.
+ if is_positive_definite and hints.is_positive_definite is None:
+ is_non_singular = True
+ else:
+ is_non_singular = hints.is_non_singular
+
+ return _Hints(
+ is_non_singular=is_non_singular,
+ is_self_adjoint=is_self_adjoint,
+ is_positive_definite=is_positive_definite)
+
+
+def _static_check_for_same_dimensions(operators):
+ """ValueError if operators determined to have different dimensions."""
+ if len(operators) < 2:
+ return
+
+ domain_dimensions = [(op.name, op.domain_dimension.value) for op in operators
+ if op.domain_dimension.value is not None]
+ if len(set(value for name, value in domain_dimensions)) > 1:
+ raise ValueError("Operators must have the same domain dimension. Found: %s"
+ % domain_dimensions)
+
+ range_dimensions = [(op.name, op.range_dimension.value) for op in operators
+ if op.range_dimension.value is not None]
+ if len(set(value for name, value in range_dimensions)) > 1:
+ raise ValueError("Operators must have the same range dimension. Found: %s" %
+ range_dimensions)
+
+
+def _static_check_for_broadcastable_batch_shape(operators):
+ """ValueError if operators determined to have non-broadcastable shapes."""
+ if len(operators) < 2:
+ return
+
+ # This will fail if they cannot be broadcast together.
+ batch_shape = operators[0].batch_shape
+ for op in operators[1:]:
+ batch_shape = array_ops.broadcast_static_shape(batch_shape, op.batch_shape)
+
+
+class _Hints(object):
+ """Holds 'is_X' flags that every LinearOperator is initialized with."""
+
+ def __init__(self,
+ is_non_singular=None,
+ is_positive_definite=None,
+ is_self_adjoint=None):
+ self.is_non_singular = is_non_singular
+ self.is_positive_definite = is_positive_definite
+ self.is_self_adjoint = is_self_adjoint
+
+
+################################################################################
+# Classes to add two linear operators.
+################################################################################
+
+
+@six.add_metaclass(abc.ABCMeta)
+class _Adder(object):
+ """Abstract base class to add two operators.
+
+ Each `Adder` acts independently, adding everything it can, paying no attention
+ to whether another `Adder` could have done the addition more efficiently.
+ """
+
+ @property
+ def name(self):
+ return self.__class__.__name__
+
+ @abc.abstractmethod
+ def can_add(self, op1, op2):
+ """Returns `True` if this `Adder` can add `op1` and `op2`. Else `False`."""
+ pass
+
+ @abc.abstractmethod
+ def _add(self, op1, op2, operator_name, hints):
+ # Derived classes can assume op1 and op2 have been validated, e.g. they have
+ # the same dtype, and their domain/range dimensions match.
+ pass
+
+ def add(self, op1, op2, operator_name, hints=None):
+ """Return new `LinearOperator` acting like `op1 + op2`.
+
+ Args:
+ op1: `LinearOperator`
+ op2: `LinearOperator`, with `shape` and `dtype` such that adding to
+ `op1` is allowed.
+ operator_name: `String` name to give to returned `LinearOperator`
+ hints: `_Hints` object. Returned `LinearOperator` will be created with
+ these hints.
+
+ Returns:
+ `LinearOperator`
+ """
+ updated_hints = _infer_hints_allowing_override(op1, op2, hints)
+
+ if operator_name is None:
+ operator_name = "Add/" + op1.name + "__" + op2.name + "/"
+
+ values = op1.graph_parents + op2.graph_parents
+ scope_name = self.name
+ if scope_name.startswith("_"):
+ scope_name = scope_name[1:]
+ with ops.name_scope(scope_name, values=values):
+ return self._add(op1, op2, operator_name, updated_hints)
+
+
+class _AddAndReturnScaledIdentity(_Adder):
+ """Handles additions resulting in an Identity family member.
+
+ The Identity (`LinearOperatorScaledIdentity`, `LinearOperatorIdentity`) family
+ is closed under addition. This `Adder` respects that, and returns an Identity family member.
+ """
+
+ def can_add(self, op1, op2):
+ types = {_type(op1), _type(op2)}
+ return not types.difference(_IDENTITY_FAMILY)
+
+ def _add(self, op1, op2, operator_name, hints):
+ # Will build a LinearOperatorScaledIdentity.
+
+ if _type(op1) == _SCALED_IDENTITY:
+ multiplier_1 = op1.multiplier
+ else:
+ multiplier_1 = array_ops.ones(op1.batch_shape_tensor(), dtype=op1.dtype)
+
+ if _type(op2) == _SCALED_IDENTITY:
+ multiplier_2 = op2.multiplier
+ else:
+ multiplier_2 = array_ops.ones(op2.batch_shape_tensor(), dtype=op2.dtype)
+
+ return linear_operator_identity.LinearOperatorScaledIdentity(
+ num_rows=op1.range_dimension_tensor(),
+ multiplier=multiplier_1 + multiplier_2,
+ is_non_singular=hints.is_non_singular,
+ is_self_adjoint=hints.is_self_adjoint,
+ is_positive_definite=hints.is_positive_definite,
+ name=operator_name)
+
+
+class _AddAndReturnDiag(_Adder):
+ """Handles additions resulting in a Diag operator."""
+
+ def can_add(self, op1, op2):
+ types = {_type(op1), _type(op2)}
+ return not types.difference(_DIAG_LIKE)
+
+ def _add(self, op1, op2, operator_name, hints):
+ return linear_operator_diag.LinearOperatorDiag(
+ diag=op1.diag_part() + op2.diag_part(),
+ is_non_singular=hints.is_non_singular,
+ is_self_adjoint=hints.is_self_adjoint,
+ is_positive_definite=hints.is_positive_definite,
+ name=operator_name)
+
+
+class _AddAndReturnTriL(_Adder):
+ """Handles additions resulting in a TriL operator."""
+
+ def can_add(self, op1, op2):
+ types = {_type(op1), _type(op2)}
+ return not types.difference(_DIAG_LIKE.union({_TRIL}))
+
+ def _add(self, op1, op2, operator_name, hints):
+ if _type(op1) in _EFFICIENT_ADD_TO_TENSOR:
+ op_add_to_tensor, op_other = op1, op2
+ else:
+ op_add_to_tensor, op_other = op2, op1
+
+ return linear_operator_lower_triangular.LinearOperatorLowerTriangular(
+ tril=op_add_to_tensor.add_to_tensor(op_other.to_dense()),
+ is_non_singular=hints.is_non_singular,
+ is_self_adjoint=hints.is_self_adjoint,
+ is_positive_definite=hints.is_positive_definite,
+ name=operator_name)
+
+
+class _AddAndReturnMatrix(_Adder):
+ """"Handles additions resulting in a `LinearOperatorFullMatrix`."""
+
+ def can_add(self, op1, op2): # pylint: disable=unused-argument
+ return isinstance(op1, linear_operator.LinearOperator) and isinstance(
+ op2, linear_operator.LinearOperator)
+
+ def _add(self, op1, op2, operator_name, hints):
+ if _type(op1) in _EFFICIENT_ADD_TO_TENSOR:
+ op_add_to_tensor, op_other = op1, op2
+ else:
+ op_add_to_tensor, op_other = op2, op1
+ return linear_operator_full_matrix.LinearOperatorFullMatrix(
+ matrix=op_add_to_tensor.add_to_tensor(op_other.to_dense()),
+ is_non_singular=hints.is_non_singular,
+ is_self_adjoint=hints.is_self_adjoint,
+ is_positive_definite=hints.is_positive_definite,
+ name=operator_name)
+
+
+################################################################################
+# Constants designating types of LinearOperators
+################################################################################
+
+# Type name constants for LinearOperator classes.
+_IDENTITY = "identity"
+_SCALED_IDENTITY = "scaled_identity"
+_DIAG = "diag"
+_TRIL = "tril"
+_MATRIX = "matrix"
+
+# Groups of operators.
+_DIAG_LIKE = {_DIAG, _IDENTITY, _SCALED_IDENTITY}
+_IDENTITY_FAMILY = {_IDENTITY, _SCALED_IDENTITY}
+# operators with an efficient .add_to_tensor() method.
+_EFFICIENT_ADD_TO_TENSOR = _DIAG_LIKE
+
+
+def _type(operator):
+ """Returns the type name constant (e.g. _TRIL) for operator."""
+ if isinstance(operator, linear_operator_diag.LinearOperatorDiag):
+ return _DIAG
+ if isinstance(operator,
+ linear_operator_lower_triangular.LinearOperatorLowerTriangular):
+ return _TRIL
+ if isinstance(operator, linear_operator_full_matrix.LinearOperatorFullMatrix):
+ return _MATRIX
+ if isinstance(operator, linear_operator_identity.LinearOperatorIdentity):
+ return _IDENTITY
+ if isinstance(operator,
+ linear_operator_identity.LinearOperatorScaledIdentity):
+ return _SCALED_IDENTITY
+ raise TypeError("Operator type unknown: %s" % operator)
+
+
+################################################################################
+# Addition tiers:
+# We attempt to use Adders in tier K before K+1.
+#
+# Organize tiers to
+# (i) reduce O(..) complexity of forming final operator, and
+# (ii) produce the "most efficient" final operator.
+# Dev notes:
+# * Results of addition at tier K will be added at tier K or higher.
+# * Tiers may change, and users are warned (in the docstring) that results may change.
+################################################################################
+
+# Note that the final tier, _AddAndReturnMatrix, will convert everything to a
+# dense matrix. So it is sometimes very inefficient.
+_DEFAULT_ADDITION_TIERS = [
+ [_AddAndReturnScaledIdentity()],
+ [_AddAndReturnDiag()],
+ [_AddAndReturnTriL()],
+ [_AddAndReturnMatrix()],
+]
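To illustrate how the default tiers cascade (illustration only, not part of the patch): identity-family operators are merged in tier 0, the result is folded into any Diag in tier 1, and Diag-like operators are absorbed into a TriL in tier 2, so the dense fallback in tier 3 is only reached when nothing cheaper applies.

```python
# Sketch of the default-tier cascade described in the comments above.
from tensorflow.python.ops.linalg import linalg as linalg_lib
from tensorflow.python.ops.linalg import linear_operator_addition

id1 = linalg_lib.LinearOperatorIdentity(num_rows=2)
id2 = linalg_lib.LinearOperatorScaledIdentity(num_rows=2, multiplier=3.)
diag = linalg_lib.LinearOperatorDiag([1., 2.])
tril = linalg_lib.LinearOperatorLowerTriangular([[1., 0.], [4., 1.]])

# Tier 0 merges the identities, tier 1 folds the result into the Diag, and
# tier 2 absorbs the Diag into the TriL, leaving a single
# LinearOperatorLowerTriangular with dense form [[6., 0.], [4., 7.]].
result = linear_operator_addition.add_operators([id1, id2, diag, tril])
assert len(result) == 1
```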