# Copyright 2018 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Tests for tensorflow.ops.tf.MatrixTriangularSolve.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import itertools import numpy as np from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes from tensorflow.python.ops import array_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.platform import test def MakePlaceholder(x): return array_ops.placeholder(dtypes.as_dtype(x.dtype), shape=x.shape) class MatrixTriangularSolveOpTest(xla_test.XLATestCase): # MatrixTriangularSolve defined for float64, float32, complex64, complex128 # (https://www.tensorflow.org/api_docs/python/tf/matrix_triangular_solve) @property def float_types(self): return set(super(MatrixTriangularSolveOpTest, self).float_types).intersection( (np.float64, np.float32, np.complex64, np.complex128)) def _VerifyTriangularSolveBase(self, sess, placeholder_a, placeholder_ca, placeholder_b, a, clean_a, b, verification, atol): feed_dict = {placeholder_a: a, placeholder_ca: clean_a, placeholder_b: b} verification_np = sess.run(verification, feed_dict) self.assertAllClose(b, verification_np, atol=atol) def _VerifyTriangularSolve(self, a, b, lower, adjoint, atol): clean_a = np.tril(a) if lower else np.triu(a) with self.cached_session() as sess: placeholder_a = MakePlaceholder(a) placeholder_ca = MakePlaceholder(clean_a) placeholder_b = MakePlaceholder(b) with self.test_scope(): x = linalg_ops.matrix_triangular_solve( placeholder_a, placeholder_b, lower=lower, adjoint=adjoint) verification = math_ops.matmul(placeholder_ca, x, adjoint_a=adjoint) self._VerifyTriangularSolveBase(sess, placeholder_a, placeholder_ca, placeholder_b, a, clean_a, b, verification, atol) def _VerifyTriangularSolveCombo(self, a, b, atol=1e-4): transp = lambda x: np.swapaxes(x, -1, -2) for lower, adjoint in itertools.product([True, False], repeat=2): self._VerifyTriangularSolve( a if lower else transp(a), b, lower, adjoint, atol) def testBasic(self): rng = np.random.RandomState(0) a = np.tril(rng.randn(5, 5)) b = rng.randn(5, 7) for dtype in self.float_types: self._VerifyTriangularSolveCombo(a.astype(dtype), b.astype(dtype)) def testBasicNotActuallyTriangular(self): rng = np.random.RandomState(0) a = rng.randn(5, 5) # the `a` matrix is not lower-triangular b = rng.randn(5, 7) for dtype in self.float_types: self._VerifyTriangularSolveCombo(a.astype(dtype), b.astype(dtype)) def testBasicComplexDtypes(self): rng = np.random.RandomState(0) a = np.tril(rng.randn(5, 5) + rng.randn(5, 5) * 1j) b = rng.randn(5, 7) + rng.randn(5, 7) * 1j for dtype in self.complex_types: self._VerifyTriangularSolveCombo(a.astype(dtype), b.astype(dtype)) def testBatch(self): rng = np.random.RandomState(0) shapes = [((4, 3, 3), (4, 3, 5)), ((1, 2, 2), (1, 2, 1)), ((1, 1, 1), (1, 1, 2)), ((2, 3, 4, 4), (2, 3, 4, 1))] tuples = itertools.product(self.float_types, shapes) for dtype, (a_shape, b_shape) in tuples: n = a_shape[-1] a = np.tril(rng.rand(*a_shape) - 0.5) / (2.0 * n) + np.eye(n) b = rng.randn(*b_shape) self._VerifyTriangularSolveCombo( a.astype(dtype), b.astype(dtype), atol=1e-3) def testLarge(self): n = 1024 rng = np.random.RandomState(0) a = np.tril(rng.rand(n, n) - 0.5) / (2.0 * n) + np.eye(n) b = rng.randn(n, n) self._VerifyTriangularSolve( a.astype(np.float32), b.astype(np.float32), True, False, 1e-4) def testNonSquareCoefficientMatrix(self): rng = np.random.RandomState(0) for dtype in self.float_types: a = rng.randn(3, 4).astype(dtype) b = rng.randn(4, 4).astype(dtype) with self.assertRaises(ValueError): linalg_ops.matrix_triangular_solve(a, b) with self.assertRaises(ValueError): linalg_ops.matrix_triangular_solve(a, b) def testWrongDimensions(self): randn = np.random.RandomState(0).randn for dtype in self.float_types: lhs = constant_op.constant(randn(3, 3), dtype=dtype) rhs = constant_op.constant(randn(4, 3), dtype=dtype) with self.assertRaises(ValueError): linalg_ops.matrix_triangular_solve(lhs, rhs) with self.assertRaises(ValueError): linalg_ops.matrix_triangular_solve(lhs, rhs) if __name__ == "__main__": test.main()