From 5a7e26dddf0d18dad9c391787f3a9332fabf6c68 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Thu, 21 Jan 2016 08:31:20 -0800 Subject: Fix and re-enable tests for matrix_triangular_solve op. Change: 112688071 --- .../matrix_triangular_solve_op_test.py | 126 ++++++++++----------- 1 file changed, 60 insertions(+), 66 deletions(-) diff --git a/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py b/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py index 549b61d7da..933934ef52 100644 --- a/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py +++ b/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py @@ -29,95 +29,89 @@ import tensorflow as tf class MatrixTriangularSolveOpTest(tf.test.TestCase): - def _verifySolve(self, x, y, lower=True): + def _verifySolve(self, x, y, lower=True, batch_dims=None): for np_type in [np.float32, np.float64]: a = x.astype(np_type) b = y.astype(np_type) + # For numpy.solve we have to explicitly zero out the strictly + # upper or lower triangle. + if lower and a.size > 0: + a_np = np.tril(a) + elif a.size > 0: + a_np = np.triu(a) + else: + a_np = a + if batch_dims is not None: + a = np.tile(a, batch_dims + [1, 1]) + a_np = np.tile(a_np, batch_dims + [1, 1]) + b = np.tile(b, batch_dims + [1, 1]) with self.test_session(): if a.ndim == 2: - tf_ans = tf.matrix_triangular_solve(a, b, lower=lower) + tf_ans = tf.matrix_triangular_solve(a, b, lower=lower).eval() else: - tf_ans = tf.batch_matrix_triangular_solve(a, b, lower=lower) - out = tf_ans.eval() - if lower: - np_ans = np.linalg.solve(np.tril(a), b) - else: - np_ans = np.linalg.solve(np.triu(a), b) - self.assertEqual(np_ans.shape, out.shape) - self.assertAllClose(np_ans, out) - - def DISABLEDtestBasicLower(self): - # 2x2 matrices, 2x1 right-hand side. - matrix0 = np.array([[1., 2.], [3., 4.]]) + tf_ans = tf.batch_matrix_triangular_solve(a, b, lower=lower).eval() + np_ans = np.linalg.solve(a_np, b) + self.assertEqual(np_ans.shape, tf_ans.shape) + self.assertAllClose(np_ans, tf_ans) + + def testSolve(self): + # 2x2 matrices, single right-hand side. + matrix = np.array([[1., 2.], [3., 4.]]) rhs0 = np.array([[1.], [1.]]) - self._verifySolve(matrix0, rhs0) - - # 2x2 matrices, 2x3 right-hand sides. - matrix1 = np.array([[1., 2.], [3., 4.]]) - matrix2 = np.array([[1., 3.], [3., 5.]]) + self._verifySolve(matrix, rhs0, lower=True) + self._verifySolve(matrix, rhs0, lower=False) + # 2x2 matrices, 3 right-hand sides. rhs1 = np.array([[1., 0., 1.], [0., 1., 1.]]) - rhs2 = np.array([[1., 1., 1.], [2., 2., 2.]]) - self._verifySolve(matrix1, rhs1) - self._verifySolve(matrix2, rhs2) - # A multidimensional batch of 2x2 matrices and 2x3 right-hand sides. - matrix_batch = np.concatenate([np.expand_dims(matrix1, 0), np.expand_dims( - matrix2, 0)]) - matrix_batch = np.tile(matrix_batch, [2, 3, 1, 1]) - rhs_batch = np.concatenate([np.expand_dims(rhs1, 0), np.expand_dims(rhs2, 0) - ]) - rhs_batch = np.tile(rhs_batch, [2, 3, 1, 1]) - self._verifySolve(matrix_batch, rhs_batch) + self._verifySolve(matrix, rhs1, lower=True) + self._verifySolve(matrix, rhs1, lower=False) - def DISABLEDtestBasicUpper(self): - # 2x2 matrices, 2x1 right-hand side. - matrix0 = np.array([[1., 2.], [3., 4.]]) - rhs0 = np.array([[1.], [1.]]) - self._verifySolve(matrix0, rhs0, lower=False) - - # 2x2 matrices, 2x3 right-hand sides. - matrix1 = np.array([[1., 2.], [3., 4.]]) - matrix2 = np.array([[1., 3.], [3., 5.]]) - rhs1 = np.array([[1., 0., 1.], [0., 1., 1.]]) - rhs2 = np.array([[1., 1., 1.], [2., 2., 2.]]) - self._verifySolve(matrix1, rhs1, lower=False) - self._verifySolve(matrix2, rhs2, lower=False) - # A multidimensional batch of 2x2 matrices and 2x3 right-hand sides. - matrix_batch = np.concatenate([np.expand_dims(matrix1, 0), np.expand_dims( - matrix2, 0)]) - matrix_batch = np.tile(matrix_batch, [2, 3, 1, 1]) - rhs_batch = np.concatenate([np.expand_dims(rhs1, 0), np.expand_dims(rhs2, 0) - ]) - rhs_batch = np.tile(rhs_batch, [2, 3, 1, 1]) - self._verifySolve(matrix_batch, rhs_batch, lower=False) + def testSolveBatch(self): + matrix = np.array([[1., 2.], [3., 4.]]) + rhs = np.array([[1., 0., 1.], [0., 1., 1.]]) + # Batch of 2x3x2x2 matrices, 2x3x2x3 right-hand sides. + self._verifySolve(matrix, rhs, lower=True, batch_dims=[2, 3]) + # Batch of 3x2x2x2 matrices, 3x2x2x3 right-hand sides. + self._verifySolve(matrix, rhs, lower=False, batch_dims=[3, 2]) def testNonSquareMatrix(self): - # When the solve of a non-square matrix is attempted we should return - # an error + # A non-square matrix should cause an error. + matrix = np.array([[1., 2., 3.], [3., 4., 5.]]) with self.test_session(): with self.assertRaises(ValueError): - matrix = tf.constant([[1., 2., 3.], [3., 4., 5.]]) - tf.matrix_triangular_solve(matrix, matrix) + self._verifySolve(matrix, matrix) + with self.assertRaises(ValueError): + self._verifySolve(matrix, matrix, batch_dims=[2, 3]) def testWrongDimensions(self): - # The matrix and rhs should have the same number of rows as the + # The matrix should have the same number of rows as the # right-hand sides. + matrix = np.array([[1., 0.], [0., 1.]]) + rhs = np.array([[1., 0.]]) with self.test_session(): - matrix = tf.constant([[1., 0.], [0., 1.]]) - rhs = tf.constant([[1., 0.]]) with self.assertRaises(ValueError): - tf.matrix_triangular_solve(matrix, rhs) + self._verifySolve(matrix, rhs) + with self.assertRaises(ValueError): + self._verifySolve(matrix, rhs, batch_dims=[2, 3]) def testNotInvertible(self): # The input should be invertible. + # The matrix is singular because it has a zero on the diagonal. + singular_matrix = np.array([[1., 0., -1.], [-1., 0., 1.], [0., -1., 1.]]) with self.test_session(): with self.assertRaisesOpError("Input matrix is not invertible."): - # The matrix has a zero on the diagonal. - matrix = tf.constant([[1., 0., -1.], [-1., 0., 1.], [0., -1., 1.]]) - tf.matrix_triangular_solve(matrix, matrix).eval() - - def DISABLEDtestEmpty(self): - self._verifySolve(np.empty([0, 2, 2]), np.empty([0, 2, 2])) - self._verifySolve(np.empty([2, 0, 0]), np.empty([2, 0, 0])) + self._verifySolve(singular_matrix, singular_matrix) + with self.assertRaisesOpError("Input matrix is not invertible."): + self._verifySolve(singular_matrix, singular_matrix, batch_dims=[2, 3]) + + def testEmpty(self): + self._verifySolve(np.empty([0, 2, 2]), np.empty([0, 2, 2]), lower=True) + self._verifySolve(np.empty([2, 0, 0]), np.empty([2, 0, 0]), lower=True) + self._verifySolve(np.empty([2, 0, 0]), np.empty([2, 0, 0]), lower=False) + self._verifySolve( + np.empty([2, 0, 0]), + np.empty([2, 0, 0]), + lower=True, + batch_dims=[3, 2]) if __name__ == "__main__": -- cgit v1.2.3