aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <nobody@tensorflow.org>2016-01-21 08:31:20 -0800
committerGravatar Vijay Vasudevan <vrv@google.com>2016-01-21 17:59:20 -0800
commit5a7e26dddf0d18dad9c391787f3a9332fabf6c68 (patch)
tree9fd21a1858fcf466ef6c0c601d80895d581ae94f
parent014ef47a22fa9aea972854c5f99232d46e55c3a5 (diff)
Fix and re-enable tests for matrix_triangular_solve op.
Change: 112688071
-rw-r--r--tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py126
1 files changed, 60 insertions, 66 deletions
diff --git a/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py b/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py
index 549b61d7da..933934ef52 100644
--- a/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py
@@ -29,95 +29,89 @@ import tensorflow as tf
class MatrixTriangularSolveOpTest(tf.test.TestCase):
- def _verifySolve(self, x, y, lower=True):
+ def _verifySolve(self, x, y, lower=True, batch_dims=None):
for np_type in [np.float32, np.float64]:
a = x.astype(np_type)
b = y.astype(np_type)
+ # For numpy.solve we have to explicitly zero out the strictly
+ # upper or lower triangle.
+ if lower and a.size > 0:
+ a_np = np.tril(a)
+ elif a.size > 0:
+ a_np = np.triu(a)
+ else:
+ a_np = a
+ if batch_dims is not None:
+ a = np.tile(a, batch_dims + [1, 1])
+ a_np = np.tile(a_np, batch_dims + [1, 1])
+ b = np.tile(b, batch_dims + [1, 1])
with self.test_session():
if a.ndim == 2:
- tf_ans = tf.matrix_triangular_solve(a, b, lower=lower)
+ tf_ans = tf.matrix_triangular_solve(a, b, lower=lower).eval()
else:
- tf_ans = tf.batch_matrix_triangular_solve(a, b, lower=lower)
- out = tf_ans.eval()
- if lower:
- np_ans = np.linalg.solve(np.tril(a), b)
- else:
- np_ans = np.linalg.solve(np.triu(a), b)
- self.assertEqual(np_ans.shape, out.shape)
- self.assertAllClose(np_ans, out)
-
- def DISABLEDtestBasicLower(self):
- # 2x2 matrices, 2x1 right-hand side.
- matrix0 = np.array([[1., 2.], [3., 4.]])
+ tf_ans = tf.batch_matrix_triangular_solve(a, b, lower=lower).eval()
+ np_ans = np.linalg.solve(a_np, b)
+ self.assertEqual(np_ans.shape, tf_ans.shape)
+ self.assertAllClose(np_ans, tf_ans)
+
+ def testSolve(self):
+ # 2x2 matrices, single right-hand side.
+ matrix = np.array([[1., 2.], [3., 4.]])
rhs0 = np.array([[1.], [1.]])
- self._verifySolve(matrix0, rhs0)
-
- # 2x2 matrices, 2x3 right-hand sides.
- matrix1 = np.array([[1., 2.], [3., 4.]])
- matrix2 = np.array([[1., 3.], [3., 5.]])
+ self._verifySolve(matrix, rhs0, lower=True)
+ self._verifySolve(matrix, rhs0, lower=False)
+ # 2x2 matrices, 3 right-hand sides.
rhs1 = np.array([[1., 0., 1.], [0., 1., 1.]])
- rhs2 = np.array([[1., 1., 1.], [2., 2., 2.]])
- self._verifySolve(matrix1, rhs1)
- self._verifySolve(matrix2, rhs2)
- # A multidimensional batch of 2x2 matrices and 2x3 right-hand sides.
- matrix_batch = np.concatenate([np.expand_dims(matrix1, 0), np.expand_dims(
- matrix2, 0)])
- matrix_batch = np.tile(matrix_batch, [2, 3, 1, 1])
- rhs_batch = np.concatenate([np.expand_dims(rhs1, 0), np.expand_dims(rhs2, 0)
- ])
- rhs_batch = np.tile(rhs_batch, [2, 3, 1, 1])
- self._verifySolve(matrix_batch, rhs_batch)
+ self._verifySolve(matrix, rhs1, lower=True)
+ self._verifySolve(matrix, rhs1, lower=False)
- def DISABLEDtestBasicUpper(self):
- # 2x2 matrices, 2x1 right-hand side.
- matrix0 = np.array([[1., 2.], [3., 4.]])
- rhs0 = np.array([[1.], [1.]])
- self._verifySolve(matrix0, rhs0, lower=False)
-
- # 2x2 matrices, 2x3 right-hand sides.
- matrix1 = np.array([[1., 2.], [3., 4.]])
- matrix2 = np.array([[1., 3.], [3., 5.]])
- rhs1 = np.array([[1., 0., 1.], [0., 1., 1.]])
- rhs2 = np.array([[1., 1., 1.], [2., 2., 2.]])
- self._verifySolve(matrix1, rhs1, lower=False)
- self._verifySolve(matrix2, rhs2, lower=False)
- # A multidimensional batch of 2x2 matrices and 2x3 right-hand sides.
- matrix_batch = np.concatenate([np.expand_dims(matrix1, 0), np.expand_dims(
- matrix2, 0)])
- matrix_batch = np.tile(matrix_batch, [2, 3, 1, 1])
- rhs_batch = np.concatenate([np.expand_dims(rhs1, 0), np.expand_dims(rhs2, 0)
- ])
- rhs_batch = np.tile(rhs_batch, [2, 3, 1, 1])
- self._verifySolve(matrix_batch, rhs_batch, lower=False)
+ def testSolveBatch(self):
+ matrix = np.array([[1., 2.], [3., 4.]])
+ rhs = np.array([[1., 0., 1.], [0., 1., 1.]])
+ # Batch of 2x3x2x2 matrices, 2x3x2x3 right-hand sides.
+ self._verifySolve(matrix, rhs, lower=True, batch_dims=[2, 3])
+ # Batch of 3x2x2x2 matrices, 3x2x2x3 right-hand sides.
+ self._verifySolve(matrix, rhs, lower=False, batch_dims=[3, 2])
def testNonSquareMatrix(self):
- # When the solve of a non-square matrix is attempted we should return
- # an error
+ # A non-square matrix should cause an error.
+ matrix = np.array([[1., 2., 3.], [3., 4., 5.]])
with self.test_session():
with self.assertRaises(ValueError):
- matrix = tf.constant([[1., 2., 3.], [3., 4., 5.]])
- tf.matrix_triangular_solve(matrix, matrix)
+ self._verifySolve(matrix, matrix)
+ with self.assertRaises(ValueError):
+ self._verifySolve(matrix, matrix, batch_dims=[2, 3])
def testWrongDimensions(self):
- # The matrix and rhs should have the same number of rows as the
+ # The matrix should have the same number of rows as the
# right-hand sides.
+ matrix = np.array([[1., 0.], [0., 1.]])
+ rhs = np.array([[1., 0.]])
with self.test_session():
- matrix = tf.constant([[1., 0.], [0., 1.]])
- rhs = tf.constant([[1., 0.]])
with self.assertRaises(ValueError):
- tf.matrix_triangular_solve(matrix, rhs)
+ self._verifySolve(matrix, rhs)
+ with self.assertRaises(ValueError):
+ self._verifySolve(matrix, rhs, batch_dims=[2, 3])
def testNotInvertible(self):
# The input should be invertible.
+ # The matrix is singular because it has a zero on the diagonal.
+ singular_matrix = np.array([[1., 0., -1.], [-1., 0., 1.], [0., -1., 1.]])
with self.test_session():
with self.assertRaisesOpError("Input matrix is not invertible."):
- # The matrix has a zero on the diagonal.
- matrix = tf.constant([[1., 0., -1.], [-1., 0., 1.], [0., -1., 1.]])
- tf.matrix_triangular_solve(matrix, matrix).eval()
-
- def DISABLEDtestEmpty(self):
- self._verifySolve(np.empty([0, 2, 2]), np.empty([0, 2, 2]))
- self._verifySolve(np.empty([2, 0, 0]), np.empty([2, 0, 0]))
+ self._verifySolve(singular_matrix, singular_matrix)
+ with self.assertRaisesOpError("Input matrix is not invertible."):
+ self._verifySolve(singular_matrix, singular_matrix, batch_dims=[2, 3])
+
+ def testEmpty(self):
+ self._verifySolve(np.empty([0, 2, 2]), np.empty([0, 2, 2]), lower=True)
+ self._verifySolve(np.empty([2, 0, 0]), np.empty([2, 0, 0]), lower=True)
+ self._verifySolve(np.empty([2, 0, 0]), np.empty([2, 0, 0]), lower=False)
+ self._verifySolve(
+ np.empty([2, 0, 0]),
+ np.empty([2, 0, 0]),
+ lower=True,
+ batch_dims=[3, 2])
if __name__ == "__main__":