path: root/tensorflow/contrib/linalg
author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-04-13 15:23:08 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>   2018-04-13 15:25:38 -0700
commit    a77dcb5e56dbbbcc3383cb0b39cd79dd88135635 (patch)
tree      1716ba71ef9e71d643c15f763e681b1d94413ed3 /tensorflow/contrib/linalg
parent    1298c3240aa9f36b79ea7f0e772edfff87381771 (diff)
Add broadcasting to all LinearOperators.
This will broadcast in cases where batch shapes are not equal (it tries to determine statically whether this is the case). The broadcasting is not as efficient as doing it in C++, but it at least makes the API completely broadcastable.
PiperOrigin-RevId: 192832919
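As a rough illustration of what this enables (a minimal sketch, not taken from the change itself; it uses the present-day tf.linalg and tf.random names, which may differ from the contrib-era aliases in this tree), blocks with different batch shapes can now be combined and their batch dimensions are broadcast:

    import tensorflow as tf

    # Two diagonal blocks whose batch shapes differ: (2, 1) vs. a scalar batch ().
    block_a = tf.linalg.LinearOperatorDiag(
        tf.random.uniform([2, 1, 2], minval=1., maxval=2.))
    block_b = tf.linalg.LinearOperatorDiag(
        tf.random.uniform([3], minval=1., maxval=2.))

    # The block-diagonal operator broadcasts the blocks' batch shapes to (2, 1),
    # giving an operator of shape (2, 1, 5, 5).
    operator = tf.linalg.LinearOperatorBlockDiag([block_a, block_b])
    x = tf.ones([2, 1, 5, 1])
    y = operator.matmul(x)  # shape (2, 1, 5, 1)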
Diffstat (limited to 'tensorflow/contrib/linalg')
-rw-r--r-- tensorflow/contrib/linalg/BUILD                                                   |  2
-rw-r--r-- tensorflow/contrib/linalg/python/kernel_tests/linear_operator_block_diag_test.py | 67
2 files changed, 3 insertions, 66 deletions
diff --git a/tensorflow/contrib/linalg/BUILD b/tensorflow/contrib/linalg/BUILD
index a7812f74d1..8b7ff75ba5 100644
--- a/tensorflow/contrib/linalg/BUILD
+++ b/tensorflow/contrib/linalg/BUILD
@@ -58,6 +58,6 @@ cuda_py_test(
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
],
- shard_count = 4,
+ shard_count = 5,
tags = ["noasan"],
)
diff --git a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_block_diag_test.py b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_block_diag_test.py
index cc1a047d6a..e7407ede11 100644
--- a/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_block_diag_test.py
+++ b/tensorflow/contrib/linalg/python/kernel_tests/linear_operator_block_diag_test.py
@@ -76,6 +76,8 @@ class SquareLinearOperatorBlockDiagTest(
build_info((1, 1)),
build_info((1, 3, 3)),
build_info((5, 5), blocks=[(2, 2), (3, 3)]),
+ build_info((3, 7, 7), blocks=[(1, 2, 2), (3, 2, 2), (1, 3, 3)]),
+ build_info((2, 1, 5, 5), blocks=[(2, 1, 2, 2), (1, 3, 3)]),
]
def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder):
@@ -184,70 +186,5 @@ class SquareLinearOperatorBlockDiagTest(
block_diag.LinearOperatorBlockDiag([])
-# This test is for blocks with different batch dimensions.
-# LinearOperatorFullMatrix doesn't broadcast matmul/solve.
-class SquareDiagLinearOperatorBlockDiagTest(
- linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
- """Most tests done in the base class LinearOperatorDerivedClassTest."""
-
- def setUp(self):
- # Increase from 1e-6 to 1e-4
- self._atol[dtypes.float32] = 1e-4
- self._atol[dtypes.complex64] = 1e-4
- self._rtol[dtypes.float32] = 1e-4
- self._rtol[dtypes.complex64] = 1e-4
-
- @property
- def _operator_build_infos(self):
- build_info = linear_operator_test_util.OperatorBuildInfo
- return [
- build_info((3, 7, 7), blocks=[(1, 2, 2), (3, 2, 2), (1, 3, 3)]),
- build_info((2, 1, 6, 6), blocks=[(2, 1, 2, 2), (1, 1, 4, 4)]),
- ]
-
- def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder):
- shape = list(build_info.shape)
- expected_blocks = (
- build_info.__dict__["blocks"] if "blocks" in build_info.__dict__
- else [shape])
- diag_matrices = [
- linear_operator_test_util.random_uniform(
- shape=block_shape[:-1], minval=1., maxval=20., dtype=dtype)
- for block_shape in expected_blocks
- ]
-
- if use_placeholder:
- diag_matrices_ph = [
- array_ops.placeholder(dtype=dtype) for _ in expected_blocks
- ]
- diag_matrices = self.evaluate(diag_matrices)
- # Evaluate here because (i) you cannot feed a tensor, and (ii)
- # values are random and we want the same value used for both mat and
- # feed_dict.
- operator = block_diag.LinearOperatorBlockDiag(
- [linalg.LinearOperatorDiag(m_ph) for m_ph in diag_matrices_ph])
- feed_dict = {m_ph: m for (m_ph, m) in zip(
- diag_matrices_ph, diag_matrices)}
- else:
- operator = block_diag.LinearOperatorBlockDiag(
- [linalg.LinearOperatorDiag(m) for m in diag_matrices])
- feed_dict = None
- # Should be auto-set.
- self.assertTrue(operator.is_square)
-
- # Broadcast the shapes.
- expected_shape = list(build_info.shape)
-
- matrices = linear_operator_util.broadcast_matrix_batch_dims(
- [array_ops.matrix_diag(diag_block) for diag_block in diag_matrices])
-
- block_diag_dense = _block_diag_dense(expected_shape, matrices)
- if not use_placeholder:
- block_diag_dense.set_shape(
- expected_shape[:-2] + [expected_shape[-1], expected_shape[-1]])
-
- return operator, block_diag_dense, feed_dict
-
-
if __name__ == "__main__":
test.main()
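For reference, the broadcast_matrix_batch_dims helper used in the removed test broadcasts only the leading batch dimensions of a list of matrices, leaving the trailing matrix dimensions untouched. A minimal sketch of that behavior, with the module path and example shapes assumed for illustration:

    import tensorflow as tf
    from tensorflow.python.ops.linalg import linear_operator_util  # internal module

    a = tf.ones([2, 1, 3, 3])   # batch shape (2, 1), matrix shape (3, 3)
    b = tf.ones([5, 4, 4])      # batch shape (5,),   matrix shape (4, 4)

    # Batch dims are broadcast to (2, 5); matrix dims are left as-is.
    a_bc, b_bc = linear_operator_util.broadcast_matrix_batch_dims([a, b])
    print(a_bc.shape)  # (2, 5, 3, 3)
    print(b_bc.shape)  # (2, 5, 4, 4)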