aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/cholesky_op_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/cholesky_op_test.py')
-rw-r--r--tensorflow/python/kernel_tests/cholesky_op_test.py74
1 files changed, 74 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/cholesky_op_test.py b/tensorflow/python/kernel_tests/cholesky_op_test.py
new file mode 100644
index 0000000000..17e8d116be
--- /dev/null
+++ b/tensorflow/python/kernel_tests/cholesky_op_test.py
@@ -0,0 +1,74 @@
+"""Tests for tensorflow.ops.tf.Cholesky."""
+import tensorflow.python.platform
+
+import numpy as np
+import tensorflow as tf
+
+
+class CholeskyOpTest(tf.test.TestCase):
+
+ def _verifyCholesky(self, x):
+ with self.test_session() as sess:
+ # Verify that LL^T == x.
+ if x.ndim == 2:
+ chol = tf.cholesky(x)
+ verification = tf.matmul(chol,
+ chol,
+ transpose_a=False,
+ transpose_b=True)
+ else:
+ chol = tf.batch_cholesky(x)
+ verification = tf.batch_matmul(chol, chol, adj_x=False, adj_y=True)
+ chol_np, verification_np = sess.run([chol, verification])
+ self.assertAllClose(x, verification_np)
+ self.assertShapeEqual(x, chol)
+ # Check that the cholesky is lower triangular, and has positive diagonal
+ # elements.
+ if chol_np.shape[-1] > 0:
+ chol_reshaped = np.reshape(chol_np, (-1, chol_np.shape[-2],
+ chol_np.shape[-1]))
+ for chol_matrix in chol_reshaped:
+ self.assertAllClose(chol_matrix, np.tril(chol_matrix))
+ self.assertTrue((np.diag(chol_matrix) > 0.0).all())
+
+ def testBasic(self):
+ self._verifyCholesky(np.array([[4., -1., 2.], [-1., 6., 0], [2., 0., 5.]]))
+
+ def testBatch(self):
+ simple_array = np.array([[[1., 0.], [0., 5.]]]) # shape (1, 2, 2)
+ self._verifyCholesky(simple_array)
+ self._verifyCholesky(np.vstack((simple_array, simple_array)))
+ odd_sized_array = np.array([[[4., -1., 2.], [-1., 6., 0], [2., 0., 5.]]])
+ self._verifyCholesky(np.vstack((odd_sized_array, odd_sized_array)))
+
+ # Generate random positive-definite matrices.
+ matrices = np.random.rand(10, 5, 5)
+ for i in xrange(10):
+ matrices[i] = np.dot(matrices[i].T, matrices[i])
+ self._verifyCholesky(matrices)
+
+ def testNonSquareMatrix(self):
+ with self.assertRaises(ValueError):
+ tf.cholesky(np.array([[1., 2., 3.], [3., 4., 5.]]))
+
+ def testWrongDimensions(self):
+ tensor3 = tf.constant([1., 2.])
+ with self.assertRaises(ValueError):
+ tf.cholesky(tensor3)
+
+ def testNotInvertible(self):
+ # The input should be invertible.
+ with self.test_session():
+ with self.assertRaisesOpError("LLT decomposition was not successful. The "
+ "input might not be valid."):
+ # All rows of the matrix below add to zero
+ self._verifyCholesky(np.array([[1., -1., 0.], [-1., 1., -1.], [0., -1.,
+ 1.]]))
+
+ def testEmpty(self):
+ self._verifyCholesky(np.empty([0, 2, 2]))
+ self._verifyCholesky(np.empty([2, 0, 0]))
+
+
+if __name__ == "__main__":
+ tf.test.main()