aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Ian Langmore <langmore@google.com>2018-09-17 15:46:30 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-17 15:49:55 -0700
commitd5f4c3aa59aebc88f42a186a30ef6200857194ca (patch)
treea20cce71bcf259362f06c8975b6d8c2ca08fba6e
parent3365cd1cc7bf3dcb781c76652132119bf82133e6 (diff)
Remove tensorflow/contrib/linalg library. linalg remains in core.
PiperOrigin-RevId: 213352573
-rw-r--r--CODEOWNERS1
-rw-r--r--tensorflow/contrib/BUILD1
-rw-r--r--tensorflow/contrib/__init__.py1
-rw-r--r--tensorflow/contrib/cmake/python_modules.txt3
-rw-r--r--tensorflow/contrib/cmake/tf_tests.cmake1
-rw-r--r--tensorflow/contrib/distributions/BUILD54
-rw-r--r--tensorflow/contrib/linalg/BUILD44
-rw-r--r--tensorflow/contrib/linalg/__init__.py58
-rw-r--r--tensorflow/contrib/linalg/python/__init__.py19
-rw-r--r--tensorflow/contrib/linalg/python/kernel_tests/linear_operator_addition_test.py412
-rw-r--r--tensorflow/contrib/linalg/python/ops/linear_operator_addition.py432
11 files changed, 26 insertions, 1000 deletions
diff --git a/CODEOWNERS b/CODEOWNERS
index b612bccffb..94cc865479 100644
--- a/CODEOWNERS
+++ b/CODEOWNERS
@@ -41,7 +41,6 @@
/tensorflow/contrib/labeled_tensor/ @shoyer
/tensorflow/contrib/layers/ @fchollet @martinwicke
/tensorflow/contrib/learn/ @martinwicke @ispirmustafa @alextp
-/tensorflow/contrib/linalg/ @langmore
/tensorflow/contrib/linear_optimizer/ @petrosmol @andreasst @katsiapis
/tensorflow/contrib/lookup/ @ysuematsu @andreasst
/tensorflow/contrib/losses/ @alextp @ispirmustafa
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index d98a24994c..e1af52cd96 100644
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -60,7 +60,6 @@ py_library(
"//tensorflow/contrib/learn",
"//tensorflow/contrib/legacy_seq2seq:seq2seq_py",
"//tensorflow/contrib/libsvm",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/contrib/linear_optimizer:sdca_estimator_py",
"//tensorflow/contrib/linear_optimizer:sdca_ops_py",
"//tensorflow/contrib/lite/python:lite",
diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py
index 9478e42b46..e71b0e0ae3 100644
--- a/tensorflow/contrib/__init__.py
+++ b/tensorflow/contrib/__init__.py
@@ -63,7 +63,6 @@ from tensorflow.contrib import labeled_tensor
from tensorflow.contrib import layers
from tensorflow.contrib import learn
from tensorflow.contrib import legacy_seq2seq
-from tensorflow.contrib import linalg
from tensorflow.contrib import linear_optimizer
from tensorflow.contrib import lookup
from tensorflow.contrib import losses
diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt
index fb871acae9..1c432b6e0b 100644
--- a/tensorflow/contrib/cmake/python_modules.txt
+++ b/tensorflow/contrib/cmake/python_modules.txt
@@ -273,9 +273,6 @@ tensorflow/contrib/libsvm
tensorflow/contrib/libsvm/python
tensorflow/contrib/libsvm/python/kernel_tests
tensorflow/contrib/libsvm/python/ops
-tensorflow/contrib/linalg
-tensorflow/contrib/linalg/python
-tensorflow/contrib/linalg/python/ops
tensorflow/contrib/linear_optimizer
tensorflow/contrib/linear_optimizer/kernels
tensorflow/contrib/linear_optimizer/kernels/g3doc
diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake
index 2c878c1716..ed31351d9e 100644
--- a/tensorflow/contrib/cmake/tf_tests.cmake
+++ b/tensorflow/contrib/cmake/tf_tests.cmake
@@ -183,7 +183,6 @@ if (tensorflow_BUILD_PYTHON_TESTS)
file(GLOB_RECURSE tf_test_src_py
${tf_test_src_py}
"${tensorflow_source_dir}/tensorflow/contrib/legacy_seq2seq/*_test.py"
- "${tensorflow_source_dir}/tensorflow/contrib/linalg/*_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/graph_editor/*_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/bayesflow/*_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/framework/*_test.py"
diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD
index 9aadc634da..3ff7da4f89 100644
--- a/tensorflow/contrib/distributions/BUILD
+++ b/tensorflow/contrib/distributions/BUILD
@@ -25,7 +25,6 @@ py_library(
"`tf.contrib.distributions` to `tfp.distributions`."),
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:check_ops",
"//tensorflow/python:clip_ops",
@@ -61,7 +60,6 @@ py_library(
":bijectors_py",
"//tensorflow/contrib/framework:framework_py",
"//tensorflow/contrib/learn",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:check_ops",
"//tensorflow/python:control_flow_ops",
@@ -706,8 +704,8 @@ cuda_py_test(
":bijectors_py",
":distributions_py",
"//third_party/py/numpy",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
@@ -722,8 +720,8 @@ cuda_py_test(
additional_deps = [
":distributions_py",
"//third_party/py/numpy",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/ops/linalg",
],
shard_count = 4,
tags = ["noasan"], # times out, http://b/78588814
@@ -739,8 +737,8 @@ cuda_py_test(
additional_deps = [
":distributions_py",
"//third_party/py/numpy",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
@@ -794,8 +792,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -831,8 +829,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -852,8 +850,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -871,8 +869,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -907,8 +905,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -926,10 +924,10 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
@@ -945,8 +943,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -964,8 +962,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -983,8 +981,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -1002,8 +1000,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -1021,8 +1019,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -1040,8 +1038,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -1075,8 +1073,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -1126,8 +1124,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -1161,8 +1159,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -1180,8 +1178,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -1201,8 +1199,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -1221,8 +1219,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -1240,8 +1238,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -1259,8 +1257,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -1278,8 +1276,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -1297,8 +1295,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
@@ -1316,8 +1314,8 @@ cuda_py_test(
":distributions_py",
"//third_party/py/numpy",
"@six_archive//:six",
- "//tensorflow/contrib/linalg:linalg_py",
"//tensorflow/python:array_ops",
+ "//tensorflow/python/ops/linalg",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:framework_test_lib",
diff --git a/tensorflow/contrib/linalg/BUILD b/tensorflow/contrib/linalg/BUILD
deleted file mode 100644
index 78b7970069..0000000000
--- a/tensorflow/contrib/linalg/BUILD
+++ /dev/null
@@ -1,44 +0,0 @@
-# Description:
-# Contains classes that provide access to common method of a [batch] matrix,
-# without the need to instantiate the matrix.
-# This allows for exploitation of structure, as well as a generic interface
-# suitable for iterative solvers.
-
-licenses(["notice"]) # Apache 2.0
-
-exports_files(["LICENSE"])
-
-package(default_visibility = ["//tensorflow:__subpackages__"])
-
-load("//tensorflow:tensorflow.bzl", "cuda_py_test")
-
-py_library(
- name = "linalg_py",
- srcs = ["__init__.py"] + glob(["python/ops/*.py"]),
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/python:array_ops",
- "//tensorflow/python:check_ops",
- "//tensorflow/python:framework_for_generated_wrappers",
- "//tensorflow/python:util",
- "//tensorflow/python/ops/linalg",
- "@six_archive//:six",
- ],
-)
-
-cuda_py_test(
- name = "linear_operator_addition_test",
- size = "small",
- srcs = ["python/kernel_tests/linear_operator_addition_test.py"],
- additional_deps = [
- ":linalg_py",
- "//third_party/py/numpy",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:framework",
- "//tensorflow/python:framework_for_generated_wrappers",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:platform_test",
- ],
-)
diff --git a/tensorflow/contrib/linalg/__init__.py b/tensorflow/contrib/linalg/__init__.py
deleted file mode 100644
index cbe4c03e4d..0000000000
--- a/tensorflow/contrib/linalg/__init__.py
+++ /dev/null
@@ -1,58 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Linear algebra libraries.
-
-See the[Contrib Linalg](https://tensorflow.org/api_guides/python/contrib.linalg)
-guide.
-
-@@LinearOperator
-@@LinearOperatorBlockDiag
-@@LinearOperatorCirculant
-@@LinearOperatorCirculant2D
-@@LinearOperatorCirculant3D
-@@LinearOperatorDiag
-@@LinearOperatorIdentity
-@@LinearOperatorScaledIdentity
-@@LinearOperatorFullMatrix
-@@LinearOperatorKronecker
-@@LinearOperatorLowerTriangular
-@@LinearOperatorLowRankUpdate
-@@LinearOperatorComposition
-@@add_operators
-
-"""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-# pylint: disable=unused-import,wildcard-import,line-too-long,g-importing-member
-
-from tensorflow.contrib.linalg.python.ops.linear_operator_addition import *
-from tensorflow.python.ops.linalg.linear_operator import *
-from tensorflow.python.ops.linalg.linear_operator_block_diag import *
-from tensorflow.python.ops.linalg.linear_operator_circulant import *
-from tensorflow.python.ops.linalg.linear_operator_composition import *
-from tensorflow.python.ops.linalg.linear_operator_diag import *
-from tensorflow.python.ops.linalg.linear_operator_full_matrix import *
-from tensorflow.python.ops.linalg.linear_operator_identity import *
-from tensorflow.python.ops.linalg.linear_operator_kronecker import *
-from tensorflow.python.ops.linalg.linear_operator_low_rank_update import *
-from tensorflow.python.ops.linalg.linear_operator_lower_triangular import *
-
-# pylint: enable=unused-import,wildcard-import,line-too-long,g-importing-member
-
-from tensorflow.python.util.all_util import remove_undocumented
-
-remove_undocumented(__name__)
diff --git a/tensorflow/contrib/linalg/python/__init__.py b/tensorflow/contrib/linalg/python/__init__.py
deleted file mode 100644
index c5ca3a623f..0000000000
--- a/tensorflow/contrib/linalg/python/__init__.py
+++ /dev/null
@@ -1,19 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""ops module."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
diff --git a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_addition_test.py b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_addition_test.py
deleted file mode 100644
index d94ac73654..0000000000
--- a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_addition_test.py
+++ /dev/null
@@ -1,412 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import numpy as np
-
-from tensorflow.contrib.linalg.python.ops import linear_operator_addition
-from tensorflow.python.framework import random_seed
-from tensorflow.python.ops import linalg_ops
-from tensorflow.python.ops.linalg import linalg as linalg_lib
-from tensorflow.python.platform import test
-
-linalg = linalg_lib
-random_seed.set_random_seed(23)
-rng = np.random.RandomState(0)
-
-add_operators = linear_operator_addition.add_operators
-
-
-# pylint: disable=unused-argument
-class _BadAdder(linear_operator_addition._Adder):
- """Adder that will fail if used."""
-
- def can_add(self, op1, op2):
- raise AssertionError("BadAdder.can_add called!")
-
- def _add(self, op1, op2, operator_name, hints):
- raise AssertionError("This line should not be reached")
-
-
-# pylint: enable=unused-argument
-
-
-class LinearOperatorAdditionCorrectnessTest(test.TestCase):
- """Tests correctness of addition with combinations of a few Adders.
-
- Tests here are done with the _DEFAULT_ADDITION_TIERS, which means
- add_operators should reduce all operators resulting in one single operator.
-
- This shows that we are able to correctly combine adders using the tiered
- system. All Adders should be tested separately, and there is no need to test
- every Adder within this class.
- """
-
- def test_one_operator_is_returned_unchanged(self):
- op_a = linalg.LinearOperatorDiag([1., 1.])
- op_sum = add_operators([op_a])
- self.assertEqual(1, len(op_sum))
- self.assertTrue(op_sum[0] is op_a)
-
- def test_at_least_one_operators_required(self):
- with self.assertRaisesRegexp(ValueError, "must contain at least one"):
- add_operators([])
-
- def test_attempting_to_add_numbers_raises(self):
- with self.assertRaisesRegexp(TypeError, "contain only LinearOperator"):
- add_operators([1, 2])
-
- def test_two_diag_operators(self):
- op_a = linalg.LinearOperatorDiag(
- [1., 1.], is_positive_definite=True, name="A")
- op_b = linalg.LinearOperatorDiag(
- [2., 2.], is_positive_definite=True, name="B")
- with self.cached_session():
- op_sum = add_operators([op_a, op_b])
- self.assertEqual(1, len(op_sum))
- op = op_sum[0]
- self.assertTrue(isinstance(op, linalg_lib.LinearOperatorDiag))
- self.assertAllClose([[3., 0.], [0., 3.]], op.to_dense().eval())
- # Adding positive definite operators produces positive def.
- self.assertTrue(op.is_positive_definite)
- # Real diagonal ==> self-adjoint.
- self.assertTrue(op.is_self_adjoint)
- # Positive definite ==> non-singular
- self.assertTrue(op.is_non_singular)
- # Enforce particular name for this simple case
- self.assertEqual("Add/B__A/", op.name)
-
- def test_three_diag_operators(self):
- op1 = linalg.LinearOperatorDiag(
- [1., 1.], is_positive_definite=True, name="op1")
- op2 = linalg.LinearOperatorDiag(
- [2., 2.], is_positive_definite=True, name="op2")
- op3 = linalg.LinearOperatorDiag(
- [3., 3.], is_positive_definite=True, name="op3")
- with self.cached_session():
- op_sum = add_operators([op1, op2, op3])
- self.assertEqual(1, len(op_sum))
- op = op_sum[0]
- self.assertTrue(isinstance(op, linalg_lib.LinearOperatorDiag))
- self.assertAllClose([[6., 0.], [0., 6.]], op.to_dense().eval())
- # Adding positive definite operators produces positive def.
- self.assertTrue(op.is_positive_definite)
- # Real diagonal ==> self-adjoint.
- self.assertTrue(op.is_self_adjoint)
- # Positive definite ==> non-singular
- self.assertTrue(op.is_non_singular)
-
- def test_diag_tril_diag(self):
- op1 = linalg.LinearOperatorDiag(
- [1., 1.], is_non_singular=True, name="diag_a")
- op2 = linalg.LinearOperatorLowerTriangular(
- [[2., 0.], [0., 2.]],
- is_self_adjoint=True,
- is_non_singular=True,
- name="tril")
- op3 = linalg.LinearOperatorDiag(
- [3., 3.], is_non_singular=True, name="diag_b")
- with self.cached_session():
- op_sum = add_operators([op1, op2, op3])
- self.assertEqual(1, len(op_sum))
- op = op_sum[0]
- self.assertTrue(isinstance(op, linalg_lib.LinearOperatorLowerTriangular))
- self.assertAllClose([[6., 0.], [0., 6.]], op.to_dense().eval())
-
- # The diag operators will be self-adjoint (because real and diagonal).
- # The TriL operator has the self-adjoint hint set.
- self.assertTrue(op.is_self_adjoint)
-
- # Even though op1/2/3 are non-singular, this does not imply op is.
- # Since no custom hint was provided, we default to None (unknown).
- self.assertEqual(None, op.is_non_singular)
-
- def test_matrix_diag_tril_diag_uses_custom_name(self):
- op0 = linalg.LinearOperatorFullMatrix(
- [[-1., -1.], [-1., -1.]], name="matrix")
- op1 = linalg.LinearOperatorDiag([1., 1.], name="diag_a")
- op2 = linalg.LinearOperatorLowerTriangular(
- [[2., 0.], [1.5, 2.]], name="tril")
- op3 = linalg.LinearOperatorDiag([3., 3.], name="diag_b")
- with self.cached_session():
- op_sum = add_operators([op0, op1, op2, op3], operator_name="my_operator")
- self.assertEqual(1, len(op_sum))
- op = op_sum[0]
- self.assertTrue(isinstance(op, linalg_lib.LinearOperatorFullMatrix))
- self.assertAllClose([[5., -1.], [0.5, 5.]], op.to_dense().eval())
- self.assertEqual("my_operator", op.name)
-
- def test_incompatible_domain_dimensions_raises(self):
- op1 = linalg.LinearOperatorFullMatrix(rng.rand(2, 3))
- op2 = linalg.LinearOperatorDiag(rng.rand(2, 4))
- with self.assertRaisesRegexp(ValueError, "must.*same domain dimension"):
- add_operators([op1, op2])
-
- def test_incompatible_range_dimensions_raises(self):
- op1 = linalg.LinearOperatorFullMatrix(rng.rand(2, 3))
- op2 = linalg.LinearOperatorDiag(rng.rand(3, 3))
- with self.assertRaisesRegexp(ValueError, "must.*same range dimension"):
- add_operators([op1, op2])
-
- def test_non_broadcastable_batch_shape_raises(self):
- op1 = linalg.LinearOperatorFullMatrix(rng.rand(2, 3, 3))
- op2 = linalg.LinearOperatorDiag(rng.rand(4, 3, 3))
- with self.assertRaisesRegexp(ValueError, "Incompatible shapes"):
- add_operators([op1, op2])
-
-
-class LinearOperatorOrderOfAdditionTest(test.TestCase):
- """Test that the order of addition is done as specified by tiers."""
-
- def test_tier_0_additions_done_in_tier_0(self):
- diag1 = linalg.LinearOperatorDiag([1.])
- diag2 = linalg.LinearOperatorDiag([1.])
- diag3 = linalg.LinearOperatorDiag([1.])
- addition_tiers = [
- [linear_operator_addition._AddAndReturnDiag()],
- [_BadAdder()],
- ]
- # Should not raise since all were added in tier 0, and tier 1 (with the
- # _BadAdder) was never reached.
- op_sum = add_operators([diag1, diag2, diag3], addition_tiers=addition_tiers)
- self.assertEqual(1, len(op_sum))
- self.assertTrue(isinstance(op_sum[0], linalg.LinearOperatorDiag))
-
- def test_tier_1_additions_done_by_tier_1(self):
- diag1 = linalg.LinearOperatorDiag([1.])
- diag2 = linalg.LinearOperatorDiag([1.])
- tril = linalg.LinearOperatorLowerTriangular([[1.]])
- addition_tiers = [
- [linear_operator_addition._AddAndReturnDiag()],
- [linear_operator_addition._AddAndReturnTriL()],
- [_BadAdder()],
- ]
- # Should not raise since all were added by tier 1, and the
- # _BadAdder) was never reached.
- op_sum = add_operators([diag1, diag2, tril], addition_tiers=addition_tiers)
- self.assertEqual(1, len(op_sum))
- self.assertTrue(isinstance(op_sum[0], linalg.LinearOperatorLowerTriangular))
-
- def test_tier_1_additions_done_by_tier_1_with_order_flipped(self):
- diag1 = linalg.LinearOperatorDiag([1.])
- diag2 = linalg.LinearOperatorDiag([1.])
- tril = linalg.LinearOperatorLowerTriangular([[1.]])
- addition_tiers = [
- [linear_operator_addition._AddAndReturnTriL()],
- [linear_operator_addition._AddAndReturnDiag()],
- [_BadAdder()],
- ]
- # Tier 0 could convert to TriL, and this converted everything to TriL,
- # including the Diags.
- # Tier 1 was never used.
- # Tier 2 was never used (therefore, _BadAdder didn't raise).
- op_sum = add_operators([diag1, diag2, tril], addition_tiers=addition_tiers)
- self.assertEqual(1, len(op_sum))
- self.assertTrue(isinstance(op_sum[0], linalg.LinearOperatorLowerTriangular))
-
- def test_cannot_add_everything_so_return_more_than_one_operator(self):
- diag1 = linalg.LinearOperatorDiag([1.])
- diag2 = linalg.LinearOperatorDiag([2.])
- tril5 = linalg.LinearOperatorLowerTriangular([[5.]])
- addition_tiers = [
- [linear_operator_addition._AddAndReturnDiag()],
- ]
- # Tier 0 (the only tier) can only convert to Diag, so it combines the two
- # diags, but the TriL is unchanged.
- # Result should contain two operators, one Diag, one TriL.
- op_sum = add_operators([diag1, diag2, tril5], addition_tiers=addition_tiers)
- self.assertEqual(2, len(op_sum))
- found_diag = False
- found_tril = False
- with self.cached_session():
- for op in op_sum:
- if isinstance(op, linalg.LinearOperatorDiag):
- found_diag = True
- self.assertAllClose([[3.]], op.to_dense().eval())
- if isinstance(op, linalg.LinearOperatorLowerTriangular):
- found_tril = True
- self.assertAllClose([[5.]], op.to_dense().eval())
- self.assertTrue(found_diag and found_tril)
-
- def test_intermediate_tier_is_not_skipped(self):
- diag1 = linalg.LinearOperatorDiag([1.])
- diag2 = linalg.LinearOperatorDiag([1.])
- tril = linalg.LinearOperatorLowerTriangular([[1.]])
- addition_tiers = [
- [linear_operator_addition._AddAndReturnDiag()],
- [_BadAdder()],
- [linear_operator_addition._AddAndReturnTriL()],
- ]
- # tril cannot be added in tier 0, and the intermediate tier 1 with the
- # BadAdder will catch it and raise.
- with self.assertRaisesRegexp(AssertionError, "BadAdder.can_add called"):
- add_operators([diag1, diag2, tril], addition_tiers=addition_tiers)
-
-
-class AddAndReturnScaledIdentityTest(test.TestCase):
-
- def setUp(self):
- self._adder = linear_operator_addition._AddAndReturnScaledIdentity()
-
- def test_identity_plus_identity(self):
- id1 = linalg.LinearOperatorIdentity(num_rows=2)
- id2 = linalg.LinearOperatorIdentity(num_rows=2, batch_shape=[3])
- hints = linear_operator_addition._Hints(
- is_positive_definite=True, is_non_singular=True)
-
- self.assertTrue(self._adder.can_add(id1, id2))
- operator = self._adder.add(id1, id2, "my_operator", hints)
- self.assertTrue(isinstance(operator, linalg.LinearOperatorScaledIdentity))
-
- with self.cached_session():
- self.assertAllClose(2 *
- linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(),
- operator.to_dense().eval())
- self.assertTrue(operator.is_positive_definite)
- self.assertTrue(operator.is_non_singular)
- self.assertEqual("my_operator", operator.name)
-
- def test_identity_plus_scaled_identity(self):
- id1 = linalg.LinearOperatorIdentity(num_rows=2, batch_shape=[3])
- id2 = linalg.LinearOperatorScaledIdentity(num_rows=2, multiplier=2.2)
- hints = linear_operator_addition._Hints(
- is_positive_definite=True, is_non_singular=True)
-
- self.assertTrue(self._adder.can_add(id1, id2))
- operator = self._adder.add(id1, id2, "my_operator", hints)
- self.assertTrue(isinstance(operator, linalg.LinearOperatorScaledIdentity))
-
- with self.cached_session():
- self.assertAllClose(3.2 *
- linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(),
- operator.to_dense().eval())
- self.assertTrue(operator.is_positive_definite)
- self.assertTrue(operator.is_non_singular)
- self.assertEqual("my_operator", operator.name)
-
- def test_scaled_identity_plus_scaled_identity(self):
- id1 = linalg.LinearOperatorScaledIdentity(
- num_rows=2, multiplier=[2.2, 2.2, 2.2])
- id2 = linalg.LinearOperatorScaledIdentity(num_rows=2, multiplier=-1.0)
- hints = linear_operator_addition._Hints(
- is_positive_definite=True, is_non_singular=True)
-
- self.assertTrue(self._adder.can_add(id1, id2))
- operator = self._adder.add(id1, id2, "my_operator", hints)
- self.assertTrue(isinstance(operator, linalg.LinearOperatorScaledIdentity))
-
- with self.cached_session():
- self.assertAllClose(1.2 *
- linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(),
- operator.to_dense().eval())
- self.assertTrue(operator.is_positive_definite)
- self.assertTrue(operator.is_non_singular)
- self.assertEqual("my_operator", operator.name)
-
-
-class AddAndReturnDiagTest(test.TestCase):
-
- def setUp(self):
- self._adder = linear_operator_addition._AddAndReturnDiag()
-
- def test_identity_plus_identity_returns_diag(self):
- id1 = linalg.LinearOperatorIdentity(num_rows=2)
- id2 = linalg.LinearOperatorIdentity(num_rows=2, batch_shape=[3])
- hints = linear_operator_addition._Hints(
- is_positive_definite=True, is_non_singular=True)
-
- self.assertTrue(self._adder.can_add(id1, id2))
- operator = self._adder.add(id1, id2, "my_operator", hints)
- self.assertTrue(isinstance(operator, linalg.LinearOperatorDiag))
-
- with self.cached_session():
- self.assertAllClose(2 *
- linalg_ops.eye(num_rows=2, batch_shape=[3]).eval(),
- operator.to_dense().eval())
- self.assertTrue(operator.is_positive_definite)
- self.assertTrue(operator.is_non_singular)
- self.assertEqual("my_operator", operator.name)
-
- def test_diag_plus_diag(self):
- diag1 = rng.rand(2, 3, 4)
- diag2 = rng.rand(4)
- op1 = linalg.LinearOperatorDiag(diag1)
- op2 = linalg.LinearOperatorDiag(diag2)
- hints = linear_operator_addition._Hints(
- is_positive_definite=True, is_non_singular=True)
-
- self.assertTrue(self._adder.can_add(op1, op2))
- operator = self._adder.add(op1, op2, "my_operator", hints)
- self.assertTrue(isinstance(operator, linalg.LinearOperatorDiag))
-
- with self.cached_session():
- self.assertAllClose(
- linalg.LinearOperatorDiag(diag1 + diag2).to_dense().eval(),
- operator.to_dense().eval())
- self.assertTrue(operator.is_positive_definite)
- self.assertTrue(operator.is_non_singular)
- self.assertEqual("my_operator", operator.name)
-
-
-class AddAndReturnTriLTest(test.TestCase):
-
- def setUp(self):
- self._adder = linear_operator_addition._AddAndReturnTriL()
-
- def test_diag_plus_tril(self):
- diag = linalg.LinearOperatorDiag([1., 2.])
- tril = linalg.LinearOperatorLowerTriangular([[10., 0.], [30., 0.]])
- hints = linear_operator_addition._Hints(
- is_positive_definite=True, is_non_singular=True)
-
- self.assertTrue(self._adder.can_add(diag, diag))
- self.assertTrue(self._adder.can_add(diag, tril))
- operator = self._adder.add(diag, tril, "my_operator", hints)
- self.assertTrue(isinstance(operator, linalg.LinearOperatorLowerTriangular))
-
- with self.cached_session():
- self.assertAllClose([[11., 0.], [30., 2.]], operator.to_dense().eval())
- self.assertTrue(operator.is_positive_definite)
- self.assertTrue(operator.is_non_singular)
- self.assertEqual("my_operator", operator.name)
-
-
-class AddAndReturnMatrixTest(test.TestCase):
-
- def setUp(self):
- self._adder = linear_operator_addition._AddAndReturnMatrix()
-
- def test_diag_plus_diag(self):
- diag1 = linalg.LinearOperatorDiag([1., 2.])
- diag2 = linalg.LinearOperatorDiag([-1., 3.])
- hints = linear_operator_addition._Hints(
- is_positive_definite=False, is_non_singular=False)
-
- self.assertTrue(self._adder.can_add(diag1, diag2))
- operator = self._adder.add(diag1, diag2, "my_operator", hints)
- self.assertTrue(isinstance(operator, linalg.LinearOperatorFullMatrix))
-
- with self.cached_session():
- self.assertAllClose([[0., 0.], [0., 5.]], operator.to_dense().eval())
- self.assertFalse(operator.is_positive_definite)
- self.assertFalse(operator.is_non_singular)
- self.assertEqual("my_operator", operator.name)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/linalg/python/ops/linear_operator_addition.py b/tensorflow/contrib/linalg/python/ops/linear_operator_addition.py
deleted file mode 100644
index 86130a2c07..0000000000
--- a/tensorflow/contrib/linalg/python/ops/linear_operator_addition.py
+++ /dev/null
@@ -1,432 +0,0 @@
-# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Add one or more `LinearOperators` efficiently."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import abc
-
-import six
-
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import check_ops
-from tensorflow.python.ops.linalg import linear_operator
-from tensorflow.python.ops.linalg import linear_operator_diag
-from tensorflow.python.ops.linalg import linear_operator_full_matrix
-from tensorflow.python.ops.linalg import linear_operator_identity
-from tensorflow.python.ops.linalg import linear_operator_lower_triangular
-
-__all__ = []
-
-
-def add_operators(operators,
- operator_name=None,
- addition_tiers=None,
- name=None):
- """Efficiently add one or more linear operators.
-
- Given operators `[A1, A2,...]`, this `Op` returns a possibly shorter list of
- operators `[B1, B2,...]` such that
-
- ```sum_k Ak.matmul(x) = sum_k Bk.matmul(x).```
-
- The operators `Bk` result by adding some of the `Ak`, as allowed by
- `addition_tiers`.
-
- Example of efficient adding of diagonal operators.
-
- ```python
- A1 = LinearOperatorDiag(diag=[1., 1.], name="A1")
- A2 = LinearOperatorDiag(diag=[2., 2.], name="A2")
-
- # Use two tiers, the first contains an Adder that returns Diag. Since both
- # A1 and A2 are Diag, they can use this Adder. The second tier will not be
- # used.
- addition_tiers = [
- [_AddAndReturnDiag()],
- [_AddAndReturnMatrix()]]
- B_list = add_operators([A1, A2], addition_tiers=addition_tiers)
-
- len(B_list)
- ==> 1
-
- B_list[0].__class__.__name__
- ==> 'LinearOperatorDiag'
-
- B_list[0].to_dense()
- ==> [[3., 0.],
- [0., 3.]]
-
- B_list[0].name
- ==> 'Add/A1__A2/'
- ```
-
- Args:
- operators: Iterable of `LinearOperator` objects with same `dtype`, domain
- and range dimensions, and broadcastable batch shapes.
- operator_name: String name for returned `LinearOperator`. Defaults to
- concatenation of "Add/A__B/" that indicates the order of addition steps.
- addition_tiers: List tiers, like `[tier_0, tier_1, ...]`, where `tier_i`
- is a list of `Adder` objects. This function attempts to do all additions
- in tier `i` before trying tier `i + 1`.
- name: A name for this `Op`. Defaults to `add_operators`.
-
- Returns:
- Subclass of `LinearOperator`. Class and order of addition may change as new
- (and better) addition strategies emerge.
-
- Raises:
- ValueError: If `operators` argument is empty.
- ValueError: If shapes are incompatible.
- """
- # Default setting
- if addition_tiers is None:
- addition_tiers = _DEFAULT_ADDITION_TIERS
-
- # Argument checking.
- check_ops.assert_proper_iterable(operators)
- operators = list(reversed(operators))
- if len(operators) < 1:
- raise ValueError(
- "Argument 'operators' must contain at least one operator. "
- "Found: %s" % operators)
- if not all(
- isinstance(op, linear_operator.LinearOperator) for op in operators):
- raise TypeError(
- "Argument 'operators' must contain only LinearOperator instances. "
- "Found: %s" % operators)
- _static_check_for_same_dimensions(operators)
- _static_check_for_broadcastable_batch_shape(operators)
-
- graph_parents = []
- for operator in operators:
- graph_parents.extend(operator.graph_parents)
-
- with ops.name_scope(name or "add_operators", values=graph_parents):
-
- # Additions done in one of the tiers. Try tier 0, 1,...
- ops_to_try_at_next_tier = list(operators)
- for tier in addition_tiers:
- ops_to_try_at_this_tier = ops_to_try_at_next_tier
- ops_to_try_at_next_tier = []
- while ops_to_try_at_this_tier:
- op1 = ops_to_try_at_this_tier.pop()
- op2, adder = _pop_a_match_at_tier(op1, ops_to_try_at_this_tier, tier)
- if op2 is not None:
- # Will try to add the result of this again at this same tier.
- new_operator = adder.add(op1, op2, operator_name)
- ops_to_try_at_this_tier.append(new_operator)
- else:
- ops_to_try_at_next_tier.append(op1)
-
- return ops_to_try_at_next_tier
-
-
-def _pop_a_match_at_tier(op1, operator_list, tier):
- # Search from the back of list to the front in order to create nice default
- # order of operations.
- for i in range(1, len(operator_list) + 1):
- op2 = operator_list[-i]
- for adder in tier:
- if adder.can_add(op1, op2):
- return operator_list.pop(-i), adder
- return None, None
-
-
-def _infer_hints_allowing_override(op1, op2, hints):
- """Infer hints from op1 and op2. hints argument is an override.
-
- Args:
- op1: LinearOperator
- op2: LinearOperator
- hints: _Hints object holding "is_X" boolean hints to use for returned
- operator.
- If some hint is None, try to set using op1 and op2. If the
- hint is provided, ignore op1 and op2 hints. This allows an override
- of previous hints, but does not allow forbidden hints (e.g. you still
- cannot say a real diagonal operator is not self-adjoint.
-
- Returns:
- _Hints object.
- """
- hints = hints or _Hints()
- # If A, B are self-adjoint, then so is A + B.
- if hints.is_self_adjoint is None:
- is_self_adjoint = op1.is_self_adjoint and op2.is_self_adjoint
- else:
- is_self_adjoint = hints.is_self_adjoint
-
- # If A, B are positive definite, then so is A + B.
- if hints.is_positive_definite is None:
- is_positive_definite = op1.is_positive_definite and op2.is_positive_definite
- else:
- is_positive_definite = hints.is_positive_definite
-
- # A positive definite operator is always non-singular.
- if is_positive_definite and hints.is_positive_definite is None:
- is_non_singular = True
- else:
- is_non_singular = hints.is_non_singular
-
- return _Hints(
- is_non_singular=is_non_singular,
- is_self_adjoint=is_self_adjoint,
- is_positive_definite=is_positive_definite)
-
-
-def _static_check_for_same_dimensions(operators):
- """ValueError if operators determined to have different dimensions."""
- if len(operators) < 2:
- return
-
- domain_dimensions = [(op.name, op.domain_dimension.value) for op in operators
- if op.domain_dimension.value is not None]
- if len(set(value for name, value in domain_dimensions)) > 1:
- raise ValueError("Operators must have the same domain dimension. Found: %s"
- % domain_dimensions)
-
- range_dimensions = [(op.name, op.range_dimension.value) for op in operators
- if op.range_dimension.value is not None]
- if len(set(value for name, value in range_dimensions)) > 1:
- raise ValueError("Operators must have the same range dimension. Found: %s" %
- range_dimensions)
-
-
-def _static_check_for_broadcastable_batch_shape(operators):
- """ValueError if operators determined to have non-broadcastable shapes."""
- if len(operators) < 2:
- return
-
- # This will fail if they cannot be broadcast together.
- batch_shape = operators[0].batch_shape
- for op in operators[1:]:
- batch_shape = array_ops.broadcast_static_shape(batch_shape, op.batch_shape)
-
-
-class _Hints(object):
- """Holds 'is_X' flags that every LinearOperator is initialized with."""
-
- def __init__(self,
- is_non_singular=None,
- is_positive_definite=None,
- is_self_adjoint=None):
- self.is_non_singular = is_non_singular
- self.is_positive_definite = is_positive_definite
- self.is_self_adjoint = is_self_adjoint
-
-
-################################################################################
-# Classes to add two linear operators.
-################################################################################
-
-
-@six.add_metaclass(abc.ABCMeta)
-class _Adder(object):
- """Abstract base class to add two operators.
-
- Each `Adder` acts independently, adding everything it can, paying no attention
- as to whether another `Adder` could have done the addition more efficiently.
- """
-
- @property
- def name(self):
- return self.__class__.__name__
-
- @abc.abstractmethod
- def can_add(self, op1, op2):
- """Returns `True` if this `Adder` can add `op1` and `op2`. Else `False`."""
- pass
-
- @abc.abstractmethod
- def _add(self, op1, op2, operator_name, hints):
- # Derived classes can assume op1 and op2 have been validated, e.g. they have
- # the same dtype, and their domain/range dimensions match.
- pass
-
- def add(self, op1, op2, operator_name, hints=None):
- """Return new `LinearOperator` acting like `op1 + op2`.
-
- Args:
- op1: `LinearOperator`
- op2: `LinearOperator`, with `shape` and `dtype` such that adding to
- `op1` is allowed.
- operator_name: `String` name to give to returned `LinearOperator`
- hints: `_Hints` object. Returned `LinearOperator` will be created with
- these hints.
-
- Returns:
- `LinearOperator`
- """
- updated_hints = _infer_hints_allowing_override(op1, op2, hints)
-
- if operator_name is None:
- operator_name = "Add/" + op1.name + "__" + op2.name + "/"
-
- values = op1.graph_parents + op2.graph_parents
- scope_name = self.name
- if scope_name.startswith("_"):
- scope_name = scope_name[1:]
- with ops.name_scope(scope_name, values=values):
- return self._add(op1, op2, operator_name, updated_hints)
-
-
-class _AddAndReturnScaledIdentity(_Adder):
- """Handles additions resulting in an Identity family member.
-
- The Identity (`LinearOperatorScaledIdentity`, `LinearOperatorIdentity`) family
- is closed under addition. This `Adder` respects that, and returns an Identity
- """
-
- def can_add(self, op1, op2):
- types = {_type(op1), _type(op2)}
- return not types.difference(_IDENTITY_FAMILY)
-
- def _add(self, op1, op2, operator_name, hints):
- # Will build a LinearOperatorScaledIdentity.
-
- if _type(op1) == _SCALED_IDENTITY:
- multiplier_1 = op1.multiplier
- else:
- multiplier_1 = array_ops.ones(op1.batch_shape_tensor(), dtype=op1.dtype)
-
- if _type(op2) == _SCALED_IDENTITY:
- multiplier_2 = op2.multiplier
- else:
- multiplier_2 = array_ops.ones(op2.batch_shape_tensor(), dtype=op2.dtype)
-
- return linear_operator_identity.LinearOperatorScaledIdentity(
- num_rows=op1.range_dimension_tensor(),
- multiplier=multiplier_1 + multiplier_2,
- is_non_singular=hints.is_non_singular,
- is_self_adjoint=hints.is_self_adjoint,
- is_positive_definite=hints.is_positive_definite,
- name=operator_name)
-
-
-class _AddAndReturnDiag(_Adder):
- """Handles additions resulting in a Diag operator."""
-
- def can_add(self, op1, op2):
- types = {_type(op1), _type(op2)}
- return not types.difference(_DIAG_LIKE)
-
- def _add(self, op1, op2, operator_name, hints):
- return linear_operator_diag.LinearOperatorDiag(
- diag=op1.diag_part() + op2.diag_part(),
- is_non_singular=hints.is_non_singular,
- is_self_adjoint=hints.is_self_adjoint,
- is_positive_definite=hints.is_positive_definite,
- name=operator_name)
-
-
-class _AddAndReturnTriL(_Adder):
- """Handles additions resulting in a TriL operator."""
-
- def can_add(self, op1, op2):
- types = {_type(op1), _type(op2)}
- return not types.difference(_DIAG_LIKE.union({_TRIL}))
-
- def _add(self, op1, op2, operator_name, hints):
- if _type(op1) in _EFFICIENT_ADD_TO_TENSOR:
- op_add_to_tensor, op_other = op1, op2
- else:
- op_add_to_tensor, op_other = op2, op1
-
- return linear_operator_lower_triangular.LinearOperatorLowerTriangular(
- tril=op_add_to_tensor.add_to_tensor(op_other.to_dense()),
- is_non_singular=hints.is_non_singular,
- is_self_adjoint=hints.is_self_adjoint,
- is_positive_definite=hints.is_positive_definite,
- name=operator_name)
-
-
-class _AddAndReturnMatrix(_Adder):
- """"Handles additions resulting in a `LinearOperatorFullMatrix`."""
-
- def can_add(self, op1, op2): # pylint: disable=unused-argument
- return isinstance(op1, linear_operator.LinearOperator) and isinstance(
- op2, linear_operator.LinearOperator)
-
- def _add(self, op1, op2, operator_name, hints):
- if _type(op1) in _EFFICIENT_ADD_TO_TENSOR:
- op_add_to_tensor, op_other = op1, op2
- else:
- op_add_to_tensor, op_other = op2, op1
- return linear_operator_full_matrix.LinearOperatorFullMatrix(
- matrix=op_add_to_tensor.add_to_tensor(op_other.to_dense()),
- is_non_singular=hints.is_non_singular,
- is_self_adjoint=hints.is_self_adjoint,
- is_positive_definite=hints.is_positive_definite,
- name=operator_name)
-
-
-################################################################################
-# Constants designating types of LinearOperators
-################################################################################
-
-# Type name constants for LinearOperator classes.
-_IDENTITY = "identity"
-_SCALED_IDENTITY = "scaled_identity"
-_DIAG = "diag"
-_TRIL = "tril"
-_MATRIX = "matrix"
-
-# Groups of operators.
-_DIAG_LIKE = {_DIAG, _IDENTITY, _SCALED_IDENTITY}
-_IDENTITY_FAMILY = {_IDENTITY, _SCALED_IDENTITY}
-# operators with an efficient .add_to_tensor() method.
-_EFFICIENT_ADD_TO_TENSOR = _DIAG_LIKE
-
-
-def _type(operator):
- """Returns the type name constant (e.g. _TRIL) for operator."""
- if isinstance(operator, linear_operator_diag.LinearOperatorDiag):
- return _DIAG
- if isinstance(operator,
- linear_operator_lower_triangular.LinearOperatorLowerTriangular):
- return _TRIL
- if isinstance(operator, linear_operator_full_matrix.LinearOperatorFullMatrix):
- return _MATRIX
- if isinstance(operator, linear_operator_identity.LinearOperatorIdentity):
- return _IDENTITY
- if isinstance(operator,
- linear_operator_identity.LinearOperatorScaledIdentity):
- return _SCALED_IDENTITY
- raise TypeError("Operator type unknown: %s" % operator)
-
-
-################################################################################
-# Addition tiers:
-# We attempt to use Adders in tier K before K+1.
-#
-# Organize tiers to
-# (i) reduce O(..) complexity of forming final operator, and
-# (ii) produce the "most efficient" final operator.
-# Dev notes:
-# * Results of addition at tier K will be added at tier K or higher.
-# * Tiers may change, and we warn the user that it may change.
-################################################################################
-
-# Note that the final tier, _AddAndReturnMatrix, will convert everything to a
-# dense matrix. So it is sometimes very inefficient.
-_DEFAULT_ADDITION_TIERS = [
- [_AddAndReturnScaledIdentity()],
- [_AddAndReturnDiag()],
- [_AddAndReturnTriL()],
- [_AddAndReturnMatrix()],
-]