From d5f4c3aa59aebc88f42a186a30ef6200857194ca Mon Sep 17 00:00:00 2001 From: Ian Langmore Date: Mon, 17 Sep 2018 15:46:30 -0700 Subject: Remove tensorflow/contrib/linalg library. linalg remains in core. PiperOrigin-RevId: 213352573 --- CODEOWNERS | 1 - tensorflow/contrib/BUILD | 1 - tensorflow/contrib/__init__.py | 1 - tensorflow/contrib/cmake/python_modules.txt | 3 - tensorflow/contrib/cmake/tf_tests.cmake | 1 - tensorflow/contrib/distributions/BUILD | 54 ++- tensorflow/contrib/linalg/BUILD | 44 --- tensorflow/contrib/linalg/__init__.py | 58 --- tensorflow/contrib/linalg/python/__init__.py | 19 - .../kernel_tests/linear_operator_addition_test.py | 412 -------------------- .../linalg/python/ops/linear_operator_addition.py | 432 --------------------- 11 files changed, 26 insertions(+), 1000 deletions(-) delete mode 100644 tensorflow/contrib/linalg/BUILD delete mode 100644 tensorflow/contrib/linalg/__init__.py delete mode 100644 tensorflow/contrib/linalg/python/__init__.py delete mode 100644 tensorflow/contrib/linalg/python/kernel_tests/linear_operator_addition_test.py delete mode 100644 tensorflow/contrib/linalg/python/ops/linear_operator_addition.py diff --git a/CODEOWNERS b/CODEOWNERS index b612bccffb..94cc865479 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -41,7 +41,6 @@ /tensorflow/contrib/labeled_tensor/ @shoyer /tensorflow/contrib/layers/ @fchollet @martinwicke /tensorflow/contrib/learn/ @martinwicke @ispirmustafa @alextp -/tensorflow/contrib/linalg/ @langmore /tensorflow/contrib/linear_optimizer/ @petrosmol @andreasst @katsiapis /tensorflow/contrib/lookup/ @ysuematsu @andreasst /tensorflow/contrib/losses/ @alextp @ispirmustafa diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index d98a24994c..e1af52cd96 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -60,7 +60,6 @@ py_library( "//tensorflow/contrib/learn", "//tensorflow/contrib/legacy_seq2seq:seq2seq_py", "//tensorflow/contrib/libsvm", - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/contrib/linear_optimizer:sdca_estimator_py", "//tensorflow/contrib/linear_optimizer:sdca_ops_py", "//tensorflow/contrib/lite/python:lite", diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py index 9478e42b46..e71b0e0ae3 100644 --- a/tensorflow/contrib/__init__.py +++ b/tensorflow/contrib/__init__.py @@ -63,7 +63,6 @@ from tensorflow.contrib import labeled_tensor from tensorflow.contrib import layers from tensorflow.contrib import learn from tensorflow.contrib import legacy_seq2seq -from tensorflow.contrib import linalg from tensorflow.contrib import linear_optimizer from tensorflow.contrib import lookup from tensorflow.contrib import losses diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt index fb871acae9..1c432b6e0b 100644 --- a/tensorflow/contrib/cmake/python_modules.txt +++ b/tensorflow/contrib/cmake/python_modules.txt @@ -273,9 +273,6 @@ tensorflow/contrib/libsvm tensorflow/contrib/libsvm/python tensorflow/contrib/libsvm/python/kernel_tests tensorflow/contrib/libsvm/python/ops -tensorflow/contrib/linalg -tensorflow/contrib/linalg/python -tensorflow/contrib/linalg/python/ops tensorflow/contrib/linear_optimizer tensorflow/contrib/linear_optimizer/kernels tensorflow/contrib/linear_optimizer/kernels/g3doc diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake index 2c878c1716..ed31351d9e 100644 --- a/tensorflow/contrib/cmake/tf_tests.cmake +++ b/tensorflow/contrib/cmake/tf_tests.cmake @@ -183,7 +183,6 @@ if (tensorflow_BUILD_PYTHON_TESTS) file(GLOB_RECURSE tf_test_src_py ${tf_test_src_py} "${tensorflow_source_dir}/tensorflow/contrib/legacy_seq2seq/*_test.py" - "${tensorflow_source_dir}/tensorflow/contrib/linalg/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/graph_editor/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/bayesflow/*_test.py" "${tensorflow_source_dir}/tensorflow/contrib/framework/*_test.py" diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD index 9aadc634da..3ff7da4f89 100644 --- a/tensorflow/contrib/distributions/BUILD +++ b/tensorflow/contrib/distributions/BUILD @@ -25,7 +25,6 @@ py_library( "`tf.contrib.distributions` to `tfp.distributions`."), srcs_version = "PY2AND3", deps = [ - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:array_ops", "//tensorflow/python:check_ops", "//tensorflow/python:clip_ops", @@ -61,7 +60,6 @@ py_library( ":bijectors_py", "//tensorflow/contrib/framework:framework_py", "//tensorflow/contrib/learn", - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:array_ops", "//tensorflow/python:check_ops", "//tensorflow/python:control_flow_ops", @@ -706,8 +704,8 @@ cuda_py_test( ":bijectors_py", ":distributions_py", "//third_party/py/numpy", - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:array_ops", + "//tensorflow/python/ops/linalg", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:math_ops", @@ -722,8 +720,8 @@ cuda_py_test( additional_deps = [ ":distributions_py", "//third_party/py/numpy", - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:client_testlib", + "//tensorflow/python/ops/linalg", ], shard_count = 4, tags = ["noasan"], # times out, http://b/78588814 @@ -739,8 +737,8 @@ cuda_py_test( additional_deps = [ ":distributions_py", "//third_party/py/numpy", - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:array_ops", + "//tensorflow/python/ops/linalg", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:math_ops", @@ -794,8 +792,8 @@ cuda_py_test( ":distributions_py", "//third_party/py/numpy", "@six_archive//:six", - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:array_ops", + "//tensorflow/python/ops/linalg", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", @@ -831,8 +829,8 @@ cuda_py_test( ":distributions_py", "//third_party/py/numpy", "@six_archive//:six", - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:array_ops", + "//tensorflow/python/ops/linalg", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", @@ -852,8 +850,8 @@ cuda_py_test( ":distributions_py", "//third_party/py/numpy", "@six_archive//:six", - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:array_ops", + "//tensorflow/python/ops/linalg", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", @@ -871,8 +869,8 @@ cuda_py_test( ":distributions_py", "//third_party/py/numpy", "@six_archive//:six", - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:array_ops", + "//tensorflow/python/ops/linalg", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", @@ -907,8 +905,8 @@ cuda_py_test( ":distributions_py", "//third_party/py/numpy", "@six_archive//:six", - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:array_ops", + "//tensorflow/python/ops/linalg", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", @@ -926,10 +924,10 @@ cuda_py_test( ":distributions_py", "//third_party/py/numpy", "@six_archive//:six", - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python/ops/linalg", "//tensorflow/python:framework_test_lib", "//tensorflow/python:math_ops", "//tensorflow/python:platform_test", @@ -945,8 +943,8 @@ cuda_py_test( ":distributions_py", "//third_party/py/numpy", "@six_archive//:six", - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:array_ops", + "//tensorflow/python/ops/linalg", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", @@ -964,8 +962,8 @@ cuda_py_test( ":distributions_py", "//third_party/py/numpy", "@six_archive//:six", - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:array_ops", + "//tensorflow/python/ops/linalg", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", @@ -983,8 +981,8 @@ cuda_py_test( ":distributions_py", "//third_party/py/numpy", "@six_archive//:six", - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:array_ops", + "//tensorflow/python/ops/linalg", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", @@ -1002,8 +1000,8 @@ cuda_py_test( ":distributions_py", "//third_party/py/numpy", "@six_archive//:six", - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:array_ops", + "//tensorflow/python/ops/linalg", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", @@ -1021,8 +1019,8 @@ cuda_py_test( ":distributions_py", "//third_party/py/numpy", "@six_archive//:six", - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:array_ops", + "//tensorflow/python/ops/linalg", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", @@ -1040,8 +1038,8 @@ cuda_py_test( ":distributions_py", "//third_party/py/numpy", "@six_archive//:six", - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:array_ops", + "//tensorflow/python/ops/linalg", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", @@ -1075,8 +1073,8 @@ cuda_py_test( ":distributions_py", "//third_party/py/numpy", "@six_archive//:six", - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:array_ops", + "//tensorflow/python/ops/linalg", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", @@ -1126,8 +1124,8 @@ cuda_py_test( ":distributions_py", "//third_party/py/numpy", "@six_archive//:six", - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:array_ops", + "//tensorflow/python/ops/linalg", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", @@ -1161,8 +1159,8 @@ cuda_py_test( ":distributions_py", "//third_party/py/numpy", "@six_archive//:six", - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:array_ops", + "//tensorflow/python/ops/linalg", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", @@ -1180,8 +1178,8 @@ cuda_py_test( ":distributions_py", "//third_party/py/numpy", "@six_archive//:six", - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:array_ops", + "//tensorflow/python/ops/linalg", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", @@ -1201,8 +1199,8 @@ cuda_py_test( ":distributions_py", "//third_party/py/numpy", "@six_archive//:six", - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:array_ops", + "//tensorflow/python/ops/linalg", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", @@ -1221,8 +1219,8 @@ cuda_py_test( ":distributions_py", "//third_party/py/numpy", "@six_archive//:six", - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:array_ops", + "//tensorflow/python/ops/linalg", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", @@ -1240,8 +1238,8 @@ cuda_py_test( ":distributions_py", "//third_party/py/numpy", "@six_archive//:six", - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:array_ops", + "//tensorflow/python/ops/linalg", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", @@ -1259,8 +1257,8 @@ cuda_py_test( ":distributions_py", "//third_party/py/numpy", "@six_archive//:six", - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:array_ops", + "//tensorflow/python/ops/linalg", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", @@ -1278,8 +1276,8 @@ cuda_py_test( ":distributions_py", "//third_party/py/numpy", "@six_archive//:six", - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:array_ops", + "//tensorflow/python/ops/linalg", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", @@ -1297,8 +1295,8 @@ cuda_py_test( ":distributions_py", "//third_party/py/numpy", "@six_archive//:six", - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:array_ops", + "//tensorflow/python/ops/linalg", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", @@ -1316,8 +1314,8 @@ cuda_py_test( ":distributions_py", "//third_party/py/numpy", "@six_archive//:six", - "//tensorflow/contrib/linalg:linalg_py", "//tensorflow/python:array_ops", + "//tensorflow/python/ops/linalg", "//tensorflow/python:client_testlib", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:framework_test_lib", diff --git a/tensorflow/contrib/linalg/BUILD b/tensorflow/contrib/linalg/BUILD deleted file mode 100644 index 78b7970069..0000000000 --- a/tensorflow/contrib/linalg/BUILD +++ /dev/null @@ -1,44 +0,0 @@ -# Description: -# Contains classes that provide access to common method of a [batch] matrix, -# without the need to instantiate the matrix. -# This allows for exploitation of structure, as well as a generic interface -# suitable for iterative solvers. - -licenses(["notice"]) # Apache 2.0 - -exports_files(["LICENSE"]) - -package(default_visibility = ["//tensorflow:__subpackages__"]) - -load("//tensorflow:tensorflow.bzl", "cuda_py_test") - -py_library( - name = "linalg_py", - srcs = ["__init__.py"] + glob(["python/ops/*.py"]), - srcs_version = "PY2AND3", - deps = [ - "//tensorflow/python:array_ops", - "//tensorflow/python:check_ops", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:util", - "//tensorflow/python/ops/linalg", - "@six_archive//:six", - ], -) - -cuda_py_test( - name = "linear_operator_addition_test", - size = "small", - srcs = ["python/kernel_tests/linear_operator_addition_test.py"], - additional_deps = [ - ":linalg_py", - "//third_party/py/numpy", - "//tensorflow/python:array_ops", - "//tensorflow/python:client_testlib", - "//tensorflow/python:framework", - "//tensorflow/python:framework_for_generated_wrappers", - "//tensorflow/python:framework_test_lib", - "//tensorflow/python:math_ops", - "//tensorflow/python:platform_test", - ], -) diff --git a/tensorflow/contrib/linalg/__init__.py b/tensorflow/contrib/linalg/__init__.py deleted file mode 100644 index cbe4c03e4d..0000000000 --- a/tensorflow/contrib/linalg/__init__.py +++ /dev/null @@ -1,58 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Linear algebra libraries. - -See the[Contrib Linalg](https://tensorflow.org/api_guides/python/contrib.linalg) -guide. - -@@LinearOperator -@@LinearOperatorBlockDiag -@@LinearOperatorCirculant -@@LinearOperatorCirculant2D -@@LinearOperatorCirculant3D -@@LinearOperatorDiag -@@LinearOperatorIdentity -@@LinearOperatorScaledIdentity -@@LinearOperatorFullMatrix -@@LinearOperatorKronecker -@@LinearOperatorLowerTriangular -@@LinearOperatorLowRankUpdate -@@LinearOperatorComposition -@@add_operators - -""" -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -# pylint: disable=unused-import,wildcard-import,line-too-long,g-importing-member - -from tensorflow.contrib.linalg.python.ops.linear_operator_addition import * -from tensorflow.python.ops.linalg.linear_operator import * -from tensorflow.python.ops.linalg.linear_operator_block_diag import * -from tensorflow.python.ops.linalg.linear_operator_circulant import * -from tensorflow.python.ops.linalg.linear_operator_composition import * -from tensorflow.python.ops.linalg.linear_operator_diag import * -from tensorflow.python.ops.linalg.linear_operator_full_matrix import * -from tensorflow.python.ops.linalg.linear_operator_identity import * -from tensorflow.python.ops.linalg.linear_operator_kronecker import * -from tensorflow.python.ops.linalg.linear_operator_low_rank_update import * -from tensorflow.python.ops.linalg.linear_operator_lower_triangular import * - -# pylint: enable=unused-import,wildcard-import,line-too-long,g-importing-member - -from tensorflow.python.util.all_util import remove_undocumented - -remove_undocumented(__name__) diff --git a/tensorflow/contrib/linalg/python/__init__.py b/tensorflow/contrib/linalg/python/__init__.py deleted file mode 100644 index c5ca3a623f..0000000000 --- a/tensorflow/contrib/linalg/python/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""ops module.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function diff --git a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_addition_test.py b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_addition_test.py deleted file mode 100644 index d94ac73654..0000000000 --- a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_addition_test.py +++ /dev/null @@ -1,412 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import numpy as np - -from tensorflow.contrib.linalg.python.ops import linear_operator_addition -from tensorflow.python.framework import random_seed -from tensorflow.python.ops import linalg_ops -from tensorflow.python.ops.linalg import linalg as linalg_lib -from tensorflow.python.platform import test - -linalg = linalg_lib -random_seed.set_random_seed(23) -rng = np.random.RandomState(0) - -add_operators = linear_operator_addition.add_operators - - -# pylint: disable=unused-argument -class _BadAdder(linear_operator_addition._Adder): - """Adder that will fail if used.""" - - def can_add(self, op1, op2): - raise AssertionError("BadAdder.can_add called!") - - def _add(self, op1, op2, operator_name, hints): - raise AssertionError("This line should not be reached") - - -# pylint: enable=unused-argument - - -class LinearOperatorAdditionCorrectnessTest(test.TestCase): - """Tests correctness of addition with combinations of a few Adders. - - Tests here are done with the _DEFAULT_ADDITION_TIERS, which means - add_operators should reduce all operators resulting in one single operator. - - This shows that we are able to correctly combine adders using the tiered - system. All Adders should be tested separately, and there is no need to test - every Adder within this class. - """ - - def test_one_operator_is_returned_unchanged(self): - op_a = linalg.LinearOperatorDiag([1., 1.]) - op_sum = add_operators([op_a]) - self.assertEqual(1, len(op_sum)) - self.assertTrue(op_sum[0] is op_a) - - def test_at_least_one_operators_required(self): - with self.assertRaisesRegexp(ValueError, "must contain at least one"): - add_operators([]) - - def test_attempting_to_add_numbers_raises(self): - with self.assertRaisesRegexp(TypeError, "contain only LinearOperator"): - add_operators([1, 2]) - - def test_two_diag_operators(self): - op_a = linalg.LinearOperatorDiag( - [1., 1.], is_positive_definite=True, name="A") - op_b = linalg.LinearOperatorDiag( - [2., 2.], is_positive_definite=True, name="B") - with self.cached_session(): - op_sum = add_operators([op_a, op_b]) - self.assertEqual(1, len(op_sum)) - op = op_sum[0] - self.assertTrue(isinstance(op, linalg_lib.LinearOperatorDiag)) - self.assertAllClose([[3., 0.], [0., 3.]], op.to_dense().eval()) - # Adding positive definite operators produces positive def. - self.assertTrue(op.is_positive_definite) - # Real diagonal ==> self-adjoint. - self.assertTrue(op.is_self_adjoint) - # Positive definite ==> non-singular - self.assertTrue(op.is_non_singular) - # Enforce particular name for this simple case - self.assertEqual("Add/B__A/", op.name) - - def test_three_diag_operators(self): - op1 = linalg.LinearOperatorDiag( - [1., 1.], is_positive_definite=True, name="op1") - op2 = linalg.LinearOperatorDiag( - [2., 2.], is_positive_definite=True, name="op2") - op3 = linalg.LinearOperatorDiag( - [3., 3.], is_positive_definite=True, name="op3") - with self.cached_session(): - op_sum = add_operators([op1, op2, op3]) - self.assertEqual(1, len(op_sum)) - op = op_sum[0] - self.assertTrue(isinstance(op, linalg_lib.LinearOperatorDiag)) - self.assertAllClose([[6., 0.], [0., 6.]], op.to_dense().eval()) - # Adding positive definite operators produces positive def. - self.assertTrue(op.is_positive_definite) - # Real diagonal ==> self-adjoint. - self.assertTrue(op.is_self_adjoint) - # Positive definite ==> non-singular - self.assertTrue(op.is_non_singular) - - def test_diag_tril_diag(self): - op1 = linalg.LinearOperatorDiag( - [1., 1.], is_non_singular=True, name="diag_a") - op2 = linalg.LinearOperatorLowerTriangular( - [[2., 0.], [0., 2.]], - is_self_adjoint=True, - is_non_singular=True, - name="tril") - op3 = linalg.LinearOperatorDiag( - [3., 3.], is_non_singular=True, name="diag_b") - with self.cached_session(): - op_sum = add_operators([op1, op2, op3]) - self.assertEqual(1, len(op_sum)) - op = op_sum[0] - self.assertTrue(isinstance(op, linalg_lib.LinearOperatorLowerTriangular)) - self.assertAllClose([[6., 0.], [0., 6.]], op.to_dense().eval()) - - # The diag operators will be self-adjoint (because real and diagonal). - # The TriL operator has the self-adjoint hint set. - self.assertTrue(op.is_self_adjoint) - - # Even though op1/2/3 are non-singular, this does not imply op is. - # Since no custom hint was provided, we default to None (unknown). - self.assertEqual(None, op.is_non_singular) - - def test_matrix_diag_tril_diag_uses_custom_name(self): - op0 = linalg.LinearOperatorFullMatrix( - [[-1., -1.], [-1., -1.]], name="matrix") - op1 = linalg.LinearOperatorDiag([1., 1.], name="diag_a") - op2 = linalg.LinearOperatorLowerTriangular( - [[2., 0.], [1.5, 2.]], name="tril") - op3 = linalg.LinearOperatorDiag([3., 3.], name="diag_b") - with self.cached_session(): - op_sum = add_operators([op0, op1, op2, op3], operator_name="my_operator") - self.assertEqual(1, len(op_sum)) - op = op_sum[0] - self.assertTrue(isinstance(op, linalg_lib.LinearOperatorFullMatrix)) - self.assertAllClose([[5., -1.], [0.5, 5.]], op.to_dense().eval()) - self.assertEqual("my_operator", op.name) - - def test_incompatible_domain_dimensions_raises(self): - op1 = linalg.LinearOperatorFullMatrix(rng.rand(2, 3)) - op2 = linalg.LinearOperatorDiag(rng.rand(2, 4)) - with self.assertRaisesRegexp(ValueError, "must.*same domain dimension"): - add_operators([op1, op2]) - - def test_incompatible_range_dimensions_raises(self): - op1 = linalg.LinearOperatorFullMatrix(rng.rand(2, 3)) - op2 = linalg.LinearOperatorDiag(rng.rand(3, 3)) - with self.assertRaisesRegexp(ValueError, "must.*same range dimension"): - add_operators([op1, op2]) - - def test_non_broadcastable_batch_shape_raises(self): - op1 = linalg.LinearOperatorFullMatrix(rng.rand(2, 3, 3)) - op2 = linalg.LinearOperatorDiag(rng.rand(4, 3, 3)) - with self.assertRaisesRegexp(ValueError, "Incompatible shapes"): - add_operators([op1, op2]) - - -class LinearOperatorOrderOfAdditionTest(test.TestCase): - """Test that the order of addition is done as specified by tiers.""" - - def test_tier_0_additions_done_in_tier_0(self): - diag1 = linalg.LinearOperatorDiag([1.]) - diag2 = linalg.LinearOperatorDiag([1.]) - diag3 = linalg.LinearOperatorDiag([1.]) - addition_tiers = [ - [linear_operator_addition._AddAndReturnDiag()], - [_BadAdder()], - ] - # Should not raise since all were added in tier 0, and tier 1 (with the - # _BadAdder) was never reached. - op_sum = add_operators([diag1, diag2, diag3], addition_tiers=addition_tiers) - self.assertEqual(1, len(op_sum)) - self.assertTrue(isinstance(op_sum[0], linalg.LinearOperatorDiag)) - - def test_tier_1_additions_done_by_tier_1(self): - diag1 = linalg.LinearOperatorDiag([1.]) - diag2 = linalg.LinearOperatorDiag([1.]) - tril = linalg.LinearOperatorLowerTriangular([[1.]]) - addition_tiers = [ - [linear_operator_addition._AddAndReturnDiag()], - [linear_operator_addition._AddAndReturnTriL()], - [_BadAdder()], - ] - # Should not raise since all were added by tier 1, and the - # _BadAdder) was never reached. - op_sum = add_operators([diag1, diag2, tril], addition_tiers=addition_tiers) - self.assertEqual(1, len(op_sum)) - self.assertTrue(isinstance(op_sum[0], linalg.LinearOperatorLowerTriangular)) - - def test_tier_1_additions_done_by_tier_1_with_order_flipped(self): - diag1 = linalg.LinearOperatorDiag([1.]) - diag2 = linalg.LinearOperatorDiag([1.]) - tril = linalg.LinearOperatorLowerTriangular([[1.]]) - addition_tiers = [ - [linear_operator_addition._AddAndReturnTriL()], - [linear_operator_addition._AddAndReturnDiag()], - [_BadAdder()], - ] - # Tier 0 could convert to TriL, and this converted everything to TriL, - # including the Diags. - # Tier 1 was never used. - # Tier 2 was never used (therefore, _BadAdder didn't raise). - op_sum = add_operators([diag1, diag2, tril], addition_tiers=addition_tiers) - self.assertEqual(1, len(op_sum)) - self.assertTrue(isinstance(op_sum[0], linalg.LinearOperatorLowerTriangular)) - - def test_cannot_add_everything_so_return_more_than_one_operator(self): - diag1 = linalg.LinearOperatorDiag([1.]) - diag2 = linalg.LinearOperatorDiag([2.]) - tril5 = linalg.LinearOperatorLowerTriangular([[5.]]) - addition_tiers = [ - [linear_operator_addition._AddAndReturnDiag()], - ] - # Tier 0 (the only tier) can only convert to Diag, so it combines the two - # diags, but the TriL is unchanged. - # Result should contain two operators, one Diag, one TriL. - op_sum = add_operators([diag1, diag2, tril5], addition_tiers=addition_tiers) - self.assertEqual(2, len(op_sum)) - found_diag = False - found_tril = False - with self.cached_session(): - for op in op_sum: - if isinstance(op, linalg.LinearOperatorDiag): - found_diag = True - self.assertAllClose([[3.]], op.to_dense().eval()) - if isinstance(op, linalg.LinearOperatorLowerTriangular): - found_tril = True - self.assertAllClose([[5.]], op.to_dense().eval()) - self.assertTrue(found_diag and found_tril) - - def test_intermediate_tier_is_not_skipped(self): - diag1 = linalg.LinearOperatorDiag([1.]) - diag2 = linalg.LinearOperatorDiag([1.]) - tril = linalg.LinearOperatorLowerTriangular([[1.]]) - addition_tiers = [ - [linear_operator_addition._AddAndReturnDiag()], - [_BadAdder()], - [linear_operator_addition._AddAndReturnTriL()], - ] - # tril cannot be added in tier 0, and the intermediate tier 1 with the - # BadAdder will catch it and raise. - with self.assertRaisesRegexp(AssertionError, "BadAdder.can_add called"): - add_operators([diag1, diag2, tril], addition_tiers=addition_tiers) - - -class AddAndReturnScaledIdentityTest(test.TestCase): - - def setUp(self): - self._adder = linear_operator_addition._AddAndReturnScaledIdentity() - - def test_identity_plus_identity(self): - id1 = linalg.LinearOperatorIdentity(num_rows=2) - id2 = linalg.LinearOperatorIdentity(num_rows=2, batch_shape=[3]) - hints = linear_operator_addition._Hints( - is_positive_definite=True, is_non_singular=True) - - self.assertTrue(self._adder.can_add(id1, id2)) - operator = self._adder.add(id1, id2, "my_operator", hints) - self.assertTrue(isinstance(operator, linalg.LinearOperatorScaledIdentity)) - - with self.cached_session(): - self.assertAllClose(2 * - linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(), - operator.to_dense().eval()) - self.assertTrue(operator.is_positive_definite) - self.assertTrue(operator.is_non_singular) - self.assertEqual("my_operator", operator.name) - - def test_identity_plus_scaled_identity(self): - id1 = linalg.LinearOperatorIdentity(num_rows=2, batch_shape=[3]) - id2 = linalg.LinearOperatorScaledIdentity(num_rows=2, multiplier=2.2) - hints = linear_operator_addition._Hints( - is_positive_definite=True, is_non_singular=True) - - self.assertTrue(self._adder.can_add(id1, id2)) - operator = self._adder.add(id1, id2, "my_operator", hints) - self.assertTrue(isinstance(operator, linalg.LinearOperatorScaledIdentity)) - - with self.cached_session(): - self.assertAllClose(3.2 * - linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(), - operator.to_dense().eval()) - self.assertTrue(operator.is_positive_definite) - self.assertTrue(operator.is_non_singular) - self.assertEqual("my_operator", operator.name) - - def test_scaled_identity_plus_scaled_identity(self): - id1 = linalg.LinearOperatorScaledIdentity( - num_rows=2, multiplier=[2.2, 2.2, 2.2]) - id2 = linalg.LinearOperatorScaledIdentity(num_rows=2, multiplier=-1.0) - hints = linear_operator_addition._Hints( - is_positive_definite=True, is_non_singular=True) - - self.assertTrue(self._adder.can_add(id1, id2)) - operator = self._adder.add(id1, id2, "my_operator", hints) - self.assertTrue(isinstance(operator, linalg.LinearOperatorScaledIdentity)) - - with self.cached_session(): - self.assertAllClose(1.2 * - linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(), - operator.to_dense().eval()) - self.assertTrue(operator.is_positive_definite) - self.assertTrue(operator.is_non_singular) - self.assertEqual("my_operator", operator.name) - - -class AddAndReturnDiagTest(test.TestCase): - - def setUp(self): - self._adder = linear_operator_addition._AddAndReturnDiag() - - def test_identity_plus_identity_returns_diag(self): - id1 = linalg.LinearOperatorIdentity(num_rows=2) - id2 = linalg.LinearOperatorIdentity(num_rows=2, batch_shape=[3]) - hints = linear_operator_addition._Hints( - is_positive_definite=True, is_non_singular=True) - - self.assertTrue(self._adder.can_add(id1, id2)) - operator = self._adder.add(id1, id2, "my_operator", hints) - self.assertTrue(isinstance(operator, linalg.LinearOperatorDiag)) - - with self.cached_session(): - self.assertAllClose(2 * - linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(), - operator.to_dense().eval()) - self.assertTrue(operator.is_positive_definite) - self.assertTrue(operator.is_non_singular) - self.assertEqual("my_operator", operator.name) - - def test_diag_plus_diag(self): - diag1 = rng.rand(2, 3, 4) - diag2 = rng.rand(4) - op1 = linalg.LinearOperatorDiag(diag1) - op2 = linalg.LinearOperatorDiag(diag2) - hints = linear_operator_addition._Hints( - is_positive_definite=True, is_non_singular=True) - - self.assertTrue(self._adder.can_add(op1, op2)) - operator = self._adder.add(op1, op2, "my_operator", hints) - self.assertTrue(isinstance(operator, linalg.LinearOperatorDiag)) - - with self.cached_session(): - self.assertAllClose( - linalg.LinearOperatorDiag(diag1 + diag2).to_dense().eval(), - operator.to_dense().eval()) - self.assertTrue(operator.is_positive_definite) - self.assertTrue(operator.is_non_singular) - self.assertEqual("my_operator", operator.name) - - -class AddAndReturnTriLTest(test.TestCase): - - def setUp(self): - self._adder = linear_operator_addition._AddAndReturnTriL() - - def test_diag_plus_tril(self): - diag = linalg.LinearOperatorDiag([1., 2.]) - tril = linalg.LinearOperatorLowerTriangular([[10., 0.], [30., 0.]]) - hints = linear_operator_addition._Hints( - is_positive_definite=True, is_non_singular=True) - - self.assertTrue(self._adder.can_add(diag, diag)) - self.assertTrue(self._adder.can_add(diag, tril)) - operator = self._adder.add(diag, tril, "my_operator", hints) - self.assertTrue(isinstance(operator, linalg.LinearOperatorLowerTriangular)) - - with self.cached_session(): - self.assertAllClose([[11., 0.], [30., 2.]], operator.to_dense().eval()) - self.assertTrue(operator.is_positive_definite) - self.assertTrue(operator.is_non_singular) - self.assertEqual("my_operator", operator.name) - - -class AddAndReturnMatrixTest(test.TestCase): - - def setUp(self): - self._adder = linear_operator_addition._AddAndReturnMatrix() - - def test_diag_plus_diag(self): - diag1 = linalg.LinearOperatorDiag([1., 2.]) - diag2 = linalg.LinearOperatorDiag([-1., 3.]) - hints = linear_operator_addition._Hints( - is_positive_definite=False, is_non_singular=False) - - self.assertTrue(self._adder.can_add(diag1, diag2)) - operator = self._adder.add(diag1, diag2, "my_operator", hints) - self.assertTrue(isinstance(operator, linalg.LinearOperatorFullMatrix)) - - with self.cached_session(): - self.assertAllClose([[0., 0.], [0., 5.]], operator.to_dense().eval()) - self.assertFalse(operator.is_positive_definite) - self.assertFalse(operator.is_non_singular) - self.assertEqual("my_operator", operator.name) - - -if __name__ == "__main__": - test.main() diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_addition.py b/tensorflow/contrib/linalg/python/ops/linear_operator_addition.py deleted file mode 100644 index 86130a2c07..0000000000 --- a/tensorflow/contrib/linalg/python/ops/linear_operator_addition.py +++ /dev/null @@ -1,432 +0,0 @@ -# Copyright 2016 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""Add one or more `LinearOperators` efficiently.""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import abc - -import six - -from tensorflow.python.framework import ops -from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops -from tensorflow.python.ops.linalg import linear_operator -from tensorflow.python.ops.linalg import linear_operator_diag -from tensorflow.python.ops.linalg import linear_operator_full_matrix -from tensorflow.python.ops.linalg import linear_operator_identity -from tensorflow.python.ops.linalg import linear_operator_lower_triangular - -__all__ = [] - - -def add_operators(operators, - operator_name=None, - addition_tiers=None, - name=None): - """Efficiently add one or more linear operators. - - Given operators `[A1, A2,...]`, this `Op` returns a possibly shorter list of - operators `[B1, B2,...]` such that - - ```sum_k Ak.matmul(x) = sum_k Bk.matmul(x).``` - - The operators `Bk` result by adding some of the `Ak`, as allowed by - `addition_tiers`. - - Example of efficient adding of diagonal operators. - - ```python - A1 = LinearOperatorDiag(diag=[1., 1.], name="A1") - A2 = LinearOperatorDiag(diag=[2., 2.], name="A2") - - # Use two tiers, the first contains an Adder that returns Diag. Since both - # A1 and A2 are Diag, they can use this Adder. The second tier will not be - # used. - addition_tiers = [ - [_AddAndReturnDiag()], - [_AddAndReturnMatrix()]] - B_list = add_operators([A1, A2], addition_tiers=addition_tiers) - - len(B_list) - ==> 1 - - B_list[0].__class__.__name__ - ==> 'LinearOperatorDiag' - - B_list[0].to_dense() - ==> [[3., 0.], - [0., 3.]] - - B_list[0].name - ==> 'Add/A1__A2/' - ``` - - Args: - operators: Iterable of `LinearOperator` objects with same `dtype`, domain - and range dimensions, and broadcastable batch shapes. - operator_name: String name for returned `LinearOperator`. Defaults to - concatenation of "Add/A__B/" that indicates the order of addition steps. - addition_tiers: List tiers, like `[tier_0, tier_1, ...]`, where `tier_i` - is a list of `Adder` objects. This function attempts to do all additions - in tier `i` before trying tier `i + 1`. - name: A name for this `Op`. Defaults to `add_operators`. - - Returns: - Subclass of `LinearOperator`. Class and order of addition may change as new - (and better) addition strategies emerge. - - Raises: - ValueError: If `operators` argument is empty. - ValueError: If shapes are incompatible. - """ - # Default setting - if addition_tiers is None: - addition_tiers = _DEFAULT_ADDITION_TIERS - - # Argument checking. - check_ops.assert_proper_iterable(operators) - operators = list(reversed(operators)) - if len(operators) < 1: - raise ValueError( - "Argument 'operators' must contain at least one operator. " - "Found: %s" % operators) - if not all( - isinstance(op, linear_operator.LinearOperator) for op in operators): - raise TypeError( - "Argument 'operators' must contain only LinearOperator instances. " - "Found: %s" % operators) - _static_check_for_same_dimensions(operators) - _static_check_for_broadcastable_batch_shape(operators) - - graph_parents = [] - for operator in operators: - graph_parents.extend(operator.graph_parents) - - with ops.name_scope(name or "add_operators", values=graph_parents): - - # Additions done in one of the tiers. Try tier 0, 1,... - ops_to_try_at_next_tier = list(operators) - for tier in addition_tiers: - ops_to_try_at_this_tier = ops_to_try_at_next_tier - ops_to_try_at_next_tier = [] - while ops_to_try_at_this_tier: - op1 = ops_to_try_at_this_tier.pop() - op2, adder = _pop_a_match_at_tier(op1, ops_to_try_at_this_tier, tier) - if op2 is not None: - # Will try to add the result of this again at this same tier. - new_operator = adder.add(op1, op2, operator_name) - ops_to_try_at_this_tier.append(new_operator) - else: - ops_to_try_at_next_tier.append(op1) - - return ops_to_try_at_next_tier - - -def _pop_a_match_at_tier(op1, operator_list, tier): - # Search from the back of list to the front in order to create nice default - # order of operations. - for i in range(1, len(operator_list) + 1): - op2 = operator_list[-i] - for adder in tier: - if adder.can_add(op1, op2): - return operator_list.pop(-i), adder - return None, None - - -def _infer_hints_allowing_override(op1, op2, hints): - """Infer hints from op1 and op2. hints argument is an override. - - Args: - op1: LinearOperator - op2: LinearOperator - hints: _Hints object holding "is_X" boolean hints to use for returned - operator. - If some hint is None, try to set using op1 and op2. If the - hint is provided, ignore op1 and op2 hints. This allows an override - of previous hints, but does not allow forbidden hints (e.g. you still - cannot say a real diagonal operator is not self-adjoint. - - Returns: - _Hints object. - """ - hints = hints or _Hints() - # If A, B are self-adjoint, then so is A + B. - if hints.is_self_adjoint is None: - is_self_adjoint = op1.is_self_adjoint and op2.is_self_adjoint - else: - is_self_adjoint = hints.is_self_adjoint - - # If A, B are positive definite, then so is A + B. - if hints.is_positive_definite is None: - is_positive_definite = op1.is_positive_definite and op2.is_positive_definite - else: - is_positive_definite = hints.is_positive_definite - - # A positive definite operator is always non-singular. - if is_positive_definite and hints.is_positive_definite is None: - is_non_singular = True - else: - is_non_singular = hints.is_non_singular - - return _Hints( - is_non_singular=is_non_singular, - is_self_adjoint=is_self_adjoint, - is_positive_definite=is_positive_definite) - - -def _static_check_for_same_dimensions(operators): - """ValueError if operators determined to have different dimensions.""" - if len(operators) < 2: - return - - domain_dimensions = [(op.name, op.domain_dimension.value) for op in operators - if op.domain_dimension.value is not None] - if len(set(value for name, value in domain_dimensions)) > 1: - raise ValueError("Operators must have the same domain dimension. Found: %s" - % domain_dimensions) - - range_dimensions = [(op.name, op.range_dimension.value) for op in operators - if op.range_dimension.value is not None] - if len(set(value for name, value in range_dimensions)) > 1: - raise ValueError("Operators must have the same range dimension. Found: %s" % - range_dimensions) - - -def _static_check_for_broadcastable_batch_shape(operators): - """ValueError if operators determined to have non-broadcastable shapes.""" - if len(operators) < 2: - return - - # This will fail if they cannot be broadcast together. - batch_shape = operators[0].batch_shape - for op in operators[1:]: - batch_shape = array_ops.broadcast_static_shape(batch_shape, op.batch_shape) - - -class _Hints(object): - """Holds 'is_X' flags that every LinearOperator is initialized with.""" - - def __init__(self, - is_non_singular=None, - is_positive_definite=None, - is_self_adjoint=None): - self.is_non_singular = is_non_singular - self.is_positive_definite = is_positive_definite - self.is_self_adjoint = is_self_adjoint - - -################################################################################ -# Classes to add two linear operators. -################################################################################ - - -@six.add_metaclass(abc.ABCMeta) -class _Adder(object): - """Abstract base class to add two operators. - - Each `Adder` acts independently, adding everything it can, paying no attention - as to whether another `Adder` could have done the addition more efficiently. - """ - - @property - def name(self): - return self.__class__.__name__ - - @abc.abstractmethod - def can_add(self, op1, op2): - """Returns `True` if this `Adder` can add `op1` and `op2`. Else `False`.""" - pass - - @abc.abstractmethod - def _add(self, op1, op2, operator_name, hints): - # Derived classes can assume op1 and op2 have been validated, e.g. they have - # the same dtype, and their domain/range dimensions match. - pass - - def add(self, op1, op2, operator_name, hints=None): - """Return new `LinearOperator` acting like `op1 + op2`. - - Args: - op1: `LinearOperator` - op2: `LinearOperator`, with `shape` and `dtype` such that adding to - `op1` is allowed. - operator_name: `String` name to give to returned `LinearOperator` - hints: `_Hints` object. Returned `LinearOperator` will be created with - these hints. - - Returns: - `LinearOperator` - """ - updated_hints = _infer_hints_allowing_override(op1, op2, hints) - - if operator_name is None: - operator_name = "Add/" + op1.name + "__" + op2.name + "/" - - values = op1.graph_parents + op2.graph_parents - scope_name = self.name - if scope_name.startswith("_"): - scope_name = scope_name[1:] - with ops.name_scope(scope_name, values=values): - return self._add(op1, op2, operator_name, updated_hints) - - -class _AddAndReturnScaledIdentity(_Adder): - """Handles additions resulting in an Identity family member. - - The Identity (`LinearOperatorScaledIdentity`, `LinearOperatorIdentity`) family - is closed under addition. This `Adder` respects that, and returns an Identity - """ - - def can_add(self, op1, op2): - types = {_type(op1), _type(op2)} - return not types.difference(_IDENTITY_FAMILY) - - def _add(self, op1, op2, operator_name, hints): - # Will build a LinearOperatorScaledIdentity. - - if _type(op1) == _SCALED_IDENTITY: - multiplier_1 = op1.multiplier - else: - multiplier_1 = array_ops.ones(op1.batch_shape_tensor(), dtype=op1.dtype) - - if _type(op2) == _SCALED_IDENTITY: - multiplier_2 = op2.multiplier - else: - multiplier_2 = array_ops.ones(op2.batch_shape_tensor(), dtype=op2.dtype) - - return linear_operator_identity.LinearOperatorScaledIdentity( - num_rows=op1.range_dimension_tensor(), - multiplier=multiplier_1 + multiplier_2, - is_non_singular=hints.is_non_singular, - is_self_adjoint=hints.is_self_adjoint, - is_positive_definite=hints.is_positive_definite, - name=operator_name) - - -class _AddAndReturnDiag(_Adder): - """Handles additions resulting in a Diag operator.""" - - def can_add(self, op1, op2): - types = {_type(op1), _type(op2)} - return not types.difference(_DIAG_LIKE) - - def _add(self, op1, op2, operator_name, hints): - return linear_operator_diag.LinearOperatorDiag( - diag=op1.diag_part() + op2.diag_part(), - is_non_singular=hints.is_non_singular, - is_self_adjoint=hints.is_self_adjoint, - is_positive_definite=hints.is_positive_definite, - name=operator_name) - - -class _AddAndReturnTriL(_Adder): - """Handles additions resulting in a TriL operator.""" - - def can_add(self, op1, op2): - types = {_type(op1), _type(op2)} - return not types.difference(_DIAG_LIKE.union({_TRIL})) - - def _add(self, op1, op2, operator_name, hints): - if _type(op1) in _EFFICIENT_ADD_TO_TENSOR: - op_add_to_tensor, op_other = op1, op2 - else: - op_add_to_tensor, op_other = op2, op1 - - return linear_operator_lower_triangular.LinearOperatorLowerTriangular( - tril=op_add_to_tensor.add_to_tensor(op_other.to_dense()), - is_non_singular=hints.is_non_singular, - is_self_adjoint=hints.is_self_adjoint, - is_positive_definite=hints.is_positive_definite, - name=operator_name) - - -class _AddAndReturnMatrix(_Adder): - """"Handles additions resulting in a `LinearOperatorFullMatrix`.""" - - def can_add(self, op1, op2): # pylint: disable=unused-argument - return isinstance(op1, linear_operator.LinearOperator) and isinstance( - op2, linear_operator.LinearOperator) - - def _add(self, op1, op2, operator_name, hints): - if _type(op1) in _EFFICIENT_ADD_TO_TENSOR: - op_add_to_tensor, op_other = op1, op2 - else: - op_add_to_tensor, op_other = op2, op1 - return linear_operator_full_matrix.LinearOperatorFullMatrix( - matrix=op_add_to_tensor.add_to_tensor(op_other.to_dense()), - is_non_singular=hints.is_non_singular, - is_self_adjoint=hints.is_self_adjoint, - is_positive_definite=hints.is_positive_definite, - name=operator_name) - - -################################################################################ -# Constants designating types of LinearOperators -################################################################################ - -# Type name constants for LinearOperator classes. -_IDENTITY = "identity" -_SCALED_IDENTITY = "scaled_identity" -_DIAG = "diag" -_TRIL = "tril" -_MATRIX = "matrix" - -# Groups of operators. -_DIAG_LIKE = {_DIAG, _IDENTITY, _SCALED_IDENTITY} -_IDENTITY_FAMILY = {_IDENTITY, _SCALED_IDENTITY} -# operators with an efficient .add_to_tensor() method. -_EFFICIENT_ADD_TO_TENSOR = _DIAG_LIKE - - -def _type(operator): - """Returns the type name constant (e.g. _TRIL) for operator.""" - if isinstance(operator, linear_operator_diag.LinearOperatorDiag): - return _DIAG - if isinstance(operator, - linear_operator_lower_triangular.LinearOperatorLowerTriangular): - return _TRIL - if isinstance(operator, linear_operator_full_matrix.LinearOperatorFullMatrix): - return _MATRIX - if isinstance(operator, linear_operator_identity.LinearOperatorIdentity): - return _IDENTITY - if isinstance(operator, - linear_operator_identity.LinearOperatorScaledIdentity): - return _SCALED_IDENTITY - raise TypeError("Operator type unknown: %s" % operator) - - -################################################################################ -# Addition tiers: -# We attempt to use Adders in tier K before K+1. -# -# Organize tiers to -# (i) reduce O(..) complexity of forming final operator, and -# (ii) produce the "most efficient" final operator. -# Dev notes: -# * Results of addition at tier K will be added at tier K or higher. -# * Tiers may change, and we warn the user that it may change. -################################################################################ - -# Note that the final tier, _AddAndReturnMatrix, will convert everything to a -# dense matrix. So it is sometimes very inefficient. -_DEFAULT_ADDITION_TIERS = [ - [_AddAndReturnScaledIdentity()], - [_AddAndReturnDiag()], - [_AddAndReturnTriL()], - [_AddAndReturnMatrix()], -] -- cgit v1.2.3