aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-05-31 04:23:30 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-05-31 04:27:09 -0700
commit473a590c9cd26cdde1e77117778e3fd50a36d7df (patch)
tree61215c75a17e27998fd3d88611096ef93164fa1d
parent2d1860859a812437d5c20fa3bf75e6e989fbbb87 (diff)
Allow complex valued input for Cholesky decomposition.
PiperOrigin-RevId: 157572536
-rw-r--r--tensorflow/core/kernels/cholesky_op.cc13
-rw-r--r--tensorflow/core/kernels/matrix_diag_op.cc4
-rw-r--r--tensorflow/core/kernels/matrix_diag_op_gpu.cu.cc2
-rw-r--r--tensorflow/core/kernels/matrix_set_diag_op.cc4
-rw-r--r--tensorflow/core/kernels/matrix_set_diag_op_gpu.cu.cc2
-rw-r--r--tensorflow/core/kernels/matrix_triangular_solve_op.cc16
-rw-r--r--tensorflow/core/ops/linalg_ops.cc17
-rw-r--r--tensorflow/python/kernel_tests/BUILD2
-rw-r--r--tensorflow/python/kernel_tests/cholesky_op_test.py212
-rw-r--r--tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py52
-rw-r--r--tensorflow/python/ops/linalg_grad.py19
11 files changed, 305 insertions, 38 deletions
diff --git a/tensorflow/core/kernels/cholesky_op.cc b/tensorflow/core/kernels/cholesky_op.cc
index 10595faf4b..5c7102f6f6 100644
--- a/tensorflow/core/kernels/cholesky_op.cc
+++ b/tensorflow/core/kernels/cholesky_op.cc
@@ -14,8 +14,7 @@ limitations under the License.
==============================================================================*/
// See docs in ../ops/linalg_ops.cc.
-// TODO(konstantinos): Enable complex inputs. This will require additional tests
-// and OP_REQUIRES.
+
#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif // GOOGLE_CUDA
@@ -85,8 +84,10 @@ namespace functor {
typename TTypes<T, 3>::Tensor output); \
extern template struct MatrixBandPart<GPUDevice, T>;
-TF_CALL_float(DECLARE_GPU_SPEC);
-TF_CALL_double(DECLARE_GPU_SPEC);
+TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
+TF_CALL_complex64(DECLARE_GPU_SPEC);
+TF_CALL_complex128(DECLARE_GPU_SPEC);
+
} // namespace functor
template <class Scalar>
@@ -171,11 +172,15 @@ class CholeskyOpGpu : public AsyncOpKernel {
REGISTER_LINALG_OP_GPU("Cholesky", (CholeskyOpGpu<float>), float);
REGISTER_LINALG_OP_GPU("Cholesky", (CholeskyOpGpu<double>), double);
+REGISTER_LINALG_OP_GPU("Cholesky", (CholeskyOpGpu<complex64>), complex64);
+REGISTER_LINALG_OP_GPU("Cholesky", (CholeskyOpGpu<complex128>), complex128);
#endif // GOOGLE_CUDA
REGISTER_LINALG_OP("Cholesky", (CholeskyOp<float>), float);
REGISTER_LINALG_OP("Cholesky", (CholeskyOp<double>), double);
+REGISTER_LINALG_OP("Cholesky", (CholeskyOp<complex64>), complex64);
+REGISTER_LINALG_OP("Cholesky", (CholeskyOp<complex128>), complex128);
REGISTER_LINALG_OP("BatchCholesky", (CholeskyOp<float>), float);
REGISTER_LINALG_OP("BatchCholesky", (CholeskyOp<double>), double);
diff --git a/tensorflow/core/kernels/matrix_diag_op.cc b/tensorflow/core/kernels/matrix_diag_op.cc
index 58b1b3a5e1..bc193357ad 100644
--- a/tensorflow/core/kernels/matrix_diag_op.cc
+++ b/tensorflow/core/kernels/matrix_diag_op.cc
@@ -187,6 +187,8 @@ namespace functor {
extern template struct MatrixDiagPart<GPUDevice, T>;
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
+TF_CALL_complex64(DECLARE_GPU_SPEC);
+TF_CALL_complex128(DECLARE_GPU_SPEC);
} // namespace functor
@@ -199,6 +201,8 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
Name("MatrixDiagPart").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
MatrixDiagPartOp<GPUDevice, type>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_MATRIX_DIAG_GPU);
+TF_CALL_complex64(REGISTER_MATRIX_DIAG_GPU);
+TF_CALL_complex128(REGISTER_MATRIX_DIAG_GPU);
#undef REGISTER_MATRIX_DIAG_GPU
// Registration of the deprecated kernel.
diff --git a/tensorflow/core/kernels/matrix_diag_op_gpu.cu.cc b/tensorflow/core/kernels/matrix_diag_op_gpu.cu.cc
index 8d3bbb048e..14774d5404 100644
--- a/tensorflow/core/kernels/matrix_diag_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/matrix_diag_op_gpu.cu.cc
@@ -31,6 +31,8 @@ typedef Eigen::GpuDevice GPUDevice;
template struct functor::MatrixDiagPart<GPUDevice, T>;
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPEC);
+TF_CALL_complex64(DEFINE_GPU_SPEC);
+TF_CALL_complex128(DEFINE_GPU_SPEC);
} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/matrix_set_diag_op.cc b/tensorflow/core/kernels/matrix_set_diag_op.cc
index bc5e6dba38..cbb2b68b7f 100644
--- a/tensorflow/core/kernels/matrix_set_diag_op.cc
+++ b/tensorflow/core/kernels/matrix_set_diag_op.cc
@@ -147,6 +147,8 @@ namespace functor {
extern template struct MatrixSetDiag<GPUDevice, T>;
TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
+TF_CALL_complex64(DECLARE_GPU_SPEC);
+TF_CALL_complex128(DECLARE_GPU_SPEC);
} // namespace functor
@@ -156,6 +158,8 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC);
Name("MatrixSetDiag").Device(DEVICE_GPU).TypeConstraint<type>("T"), \
MatrixSetDiagOp<GPUDevice, type>);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_MATRIX_SET_DIAG_GPU);
+TF_CALL_complex64(REGISTER_MATRIX_SET_DIAG_GPU);
+TF_CALL_complex128(REGISTER_MATRIX_SET_DIAG_GPU);
#undef REGISTER_MATRIX_SET_DIAG_GPU
// Registration of the deprecated kernel.
diff --git a/tensorflow/core/kernels/matrix_set_diag_op_gpu.cu.cc b/tensorflow/core/kernels/matrix_set_diag_op_gpu.cu.cc
index ba3e475ee8..bd097ff328 100644
--- a/tensorflow/core/kernels/matrix_set_diag_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/matrix_set_diag_op_gpu.cu.cc
@@ -29,6 +29,8 @@ typedef Eigen::GpuDevice GPUDevice;
template struct functor::MatrixSetDiag<GPUDevice, T>;
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPEC);
+TF_CALL_complex64(DEFINE_GPU_SPEC);
+TF_CALL_complex128(DEFINE_GPU_SPEC);
} // end namespace tensorflow
diff --git a/tensorflow/core/kernels/matrix_triangular_solve_op.cc b/tensorflow/core/kernels/matrix_triangular_solve_op.cc
index ebc2333080..953f37fa02 100644
--- a/tensorflow/core/kernels/matrix_triangular_solve_op.cc
+++ b/tensorflow/core/kernels/matrix_triangular_solve_op.cc
@@ -97,8 +97,9 @@ class MatrixTriangularSolveOp : public LinearAlgebraOp<Scalar> {
// an empty set of equation as the empty matrix.
return;
}
- const Scalar min_abs_pivot = matrix.diagonal().cwiseAbs().minCoeff();
- OP_REQUIRES(context, min_abs_pivot > Scalar(0),
+ using RealScalar = typename Base::RealScalar;
+ const RealScalar min_abs_pivot = matrix.diagonal().cwiseAbs().minCoeff();
+ OP_REQUIRES(context, min_abs_pivot > RealScalar(0),
errors::InvalidArgument("Input matrix is not invertible."));
if (lower_) {
auto triangle = matrix.template triangularView<Eigen::Lower>();
@@ -128,6 +129,10 @@ REGISTER_LINALG_OP_CPU("MatrixTriangularSolve",
(MatrixTriangularSolveOp<float>), float);
REGISTER_LINALG_OP_CPU("MatrixTriangularSolve",
(MatrixTriangularSolveOp<double>), double);
+REGISTER_LINALG_OP_CPU("MatrixTriangularSolve",
+ (MatrixTriangularSolveOp<complex64>), complex64);
+REGISTER_LINALG_OP_CPU("MatrixTriangularSolve",
+ (MatrixTriangularSolveOp<complex128>), complex128);
REGISTER_LINALG_OP_CPU("BatchMatrixTriangularSolve",
(MatrixTriangularSolveOp<float>), float);
REGISTER_LINALG_OP_CPU("BatchMatrixTriangularSolve",
@@ -215,7 +220,8 @@ class MatrixTriangularSolveOpGPU : public LinearAlgebraOp<Scalar> {
upper_lower_matrix = perftools::gputools::blas::UpperLower::kLower;
}
if (adjoint_) {
- transpose_matrix = perftools::gputools::blas::Transpose::kTranspose;
+ transpose_matrix =
+ perftools::gputools::blas::Transpose::kConjugateTranspose;
} else {
transpose_matrix = perftools::gputools::blas::Transpose::kNoTranspose;
}
@@ -249,6 +255,10 @@ REGISTER_LINALG_OP_GPU("MatrixTriangularSolve",
(MatrixTriangularSolveOpGPU<float>), float);
REGISTER_LINALG_OP_GPU("MatrixTriangularSolve",
(MatrixTriangularSolveOpGPU<double>), double);
+REGISTER_LINALG_OP_GPU("MatrixTriangularSolve",
+ (MatrixTriangularSolveOpGPU<complex64>), complex64);
+REGISTER_LINALG_OP_GPU("MatrixTriangularSolve",
+ (MatrixTriangularSolveOpGPU<complex128>), complex128);
REGISTER_LINALG_OP_GPU("BatchMatrixTriangularSolve",
(MatrixTriangularSolveOpGPU<float>), float);
REGISTER_LINALG_OP_GPU("BatchMatrixTriangularSolve",
diff --git a/tensorflow/core/ops/linalg_ops.cc b/tensorflow/core/ops/linalg_ops.cc
index 678ebb10ca..6e1f2dc052 100644
--- a/tensorflow/core/ops/linalg_ops.cc
+++ b/tensorflow/core/ops/linalg_ops.cc
@@ -245,16 +245,25 @@ Equivalent to np.linalg.inv
REGISTER_OP("Cholesky")
.Input("input: T")
.Output("output: T")
- .Attr("T: {double, float}")
+ .Attr("T: {double, float, complex64, complex128}")
.SetShapeFn(BatchUnchangedSquareShapeFn)
.Doc(R"doc(
Computes the Cholesky decomposition of one or more square matrices.
The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
-form square matrices, with the same constraints as the single matrix Cholesky
-decomposition above. The output is a tensor of the same shape as the input
+form square matrices.
+
+The input has to be symmetric and positive definite. Only the lower-triangular
+part of the input will be used for this operation. The upper-triangular part
+will not be read.
+
+The output is a tensor of the same shape as the input
containing the Cholesky decompositions for all input submatrices `[..., :, :]`.
+**Note**: The gradient computation on GPU is faster for large matrices but
+not for large batch dimensions when the submatrices are small. In this
+case it might be faster to use the CPU.
+
input: Shape is `[..., M, M]`.
output: Shape is `[..., M, M]`.
)doc");
@@ -373,7 +382,7 @@ REGISTER_OP("MatrixTriangularSolve")
.Output("output: T")
.Attr("lower: bool = True")
.Attr("adjoint: bool = False")
- .Attr("T: {double, float}")
+ .Attr("T: {double, float, complex64, complex128}")
.SetShapeFn([](InferenceContext* c) {
return MatrixSolveShapeFn(c, true /* square (*/);
})
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index f9da3beb04..2b9b10112e 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -120,7 +120,7 @@ tf_py_test(
],
)
-tf_py_test(
+cuda_py_test(
name = "cholesky_op_test",
size = "small",
srcs = ["cholesky_op_test.py"],
diff --git a/tensorflow/python/kernel_tests/cholesky_op_test.py b/tensorflow/python/kernel_tests/cholesky_op_test.py
index d95200ec92..9a1c918b15 100644
--- a/tensorflow/python/kernel_tests/cholesky_op_test.py
+++ b/tensorflow/python/kernel_tests/cholesky_op_test.py
@@ -21,16 +21,70 @@ from __future__ import print_function
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
+from tensorflow.python.client import session
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes as dtypes_lib
+from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import gradient_checker
+from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging
+# Different gradient implementations for benchmark purposes
+def SpecializedGrad(l, grad):
+ return gen_linalg_ops.cholesky_grad(l, grad)
+
+
+def _GradWithInverseL(l, l_inverse, grad):
+ middle = math_ops.matmul(l, grad, adjoint_a=True)
+ middle = array_ops.matrix_set_diag(middle,
+ 0.5 * array_ops.matrix_diag_part(middle))
+ middle = array_ops.matrix_band_part(middle, -1, 0)
+ grad_a = math_ops.matmul(
+ math_ops.matmul(l_inverse, middle, adjoint_a=True), l_inverse)
+ grad_a += math_ops.conj(array_ops.matrix_transpose(grad_a))
+ return grad_a * 0.5
+
+
+def TriAngSolveCompositeGrad(l, grad):
+ # Gradient is l^{-H} @ ((l^{H} @ grad) * (tril(ones)-1/2*eye)) @ l^{-1}
+
+ # Compute ((l^{H} @ grad) * (tril(ones)-1/2*eye)) = middle
+ middle = math_ops.matmul(l, grad, adjoint_a=True)
+ middle = array_ops.matrix_set_diag(middle,
+ 0.5 * array_ops.matrix_diag_part(middle))
+ middle = array_ops.matrix_band_part(middle, -1, 0)
+
+ # Compute l^{-H} @ middle = z
+ l_inverse_middle = linalg_ops.matrix_triangular_solve(l, middle, adjoint=True)
+
+ # We need to compute z @ l^{-1}. With matrix_triangular_solve we
+ # actually compute l^{-H} @ z^{H} = grad. Since we later add grad^{H}
+ # we can ommit the conjugate transpose here.
+ z_h = math_ops.conj(array_ops.matrix_transpose(l_inverse_middle))
+ grad_a = linalg_ops.matrix_triangular_solve(l, z_h, adjoint=True)
+ grad_a += math_ops.conj(array_ops.matrix_transpose(grad_a))
+ return grad_a * 0.5
+
+
+def MatrixInverseCompositeGrad(l, grad):
+ l_inverse = linalg_ops.matrix_inverse(l)
+ return _GradWithInverseL(l, l_inverse, grad)
+
+
+def TriAngInvCompositeGrad(l, grad):
+ num_rows = array_ops.shape(l)[-1]
+ batch_shape = array_ops.shape(l)[:-2]
+ l_inverse = linalg_ops.matrix_triangular_solve(
+ l, linalg_ops.eye(num_rows, batch_shape=batch_shape, dtype=l.dtype))
+ return _GradWithInverseL(l, l_inverse, grad)
+
+
class CholeskyOpTest(test.TestCase):
def _verifyCholeskyBase(self, sess, x, chol, verification):
@@ -54,9 +108,14 @@ class CholeskyOpTest(test.TestCase):
self._verifyCholeskyBase(sess, x, chol, verification)
def testBasic(self):
+ data = np.array([[4., -1., 2.], [-1., 6., 0], [2., 0., 5.]])
for dtype in (np.float32, np.float64):
- self._verifyCholesky(
- np.array([[4., -1., 2.], [-1., 6., 0], [2., 0., 5.]]).astype(dtype))
+ self._verifyCholesky(data.astype(dtype))
+ for dtype in (np.complex64, np.complex128):
+ complex_data = np.tril(1j * data, -1).astype(dtype)
+ complex_data += np.triu(-1j * data, 1).astype(dtype)
+ complex_data += data
+ self._verifyCholesky(complex_data)
def testBatch(self):
simple_array = np.array([[[1., 0.], [0., 5.]]]) # shape (1, 2, 2)
@@ -65,12 +124,18 @@ class CholeskyOpTest(test.TestCase):
odd_sized_array = np.array([[[4., -1., 2.], [-1., 6., 0], [2., 0., 5.]]])
self._verifyCholesky(np.vstack((odd_sized_array, odd_sized_array)))
- # Generate random positive-definite matrices.
+ # Generate random positive-definite matrices.
matrices = np.random.rand(10, 5, 5)
for i in xrange(10):
matrices[i] = np.dot(matrices[i].T, matrices[i])
self._verifyCholesky(matrices)
+ # Generate random complex valued positive-definite matrices.
+ matrices = np.random.rand(10, 5, 5) + 1j * np.random.rand(10, 5, 5)
+ for i in xrange(10):
+ matrices[i] = np.dot(matrices[i].T.conj(), matrices[i])
+ self._verifyCholesky(matrices)
+
def testNonSquareMatrix(self):
with self.assertRaises(ValueError):
linalg_ops.cholesky(np.array([[1., 2., 3.], [3., 4., 5.]]))
@@ -110,7 +175,14 @@ class CholeskyGradTest(test.TestCase):
def testSmallMatrices(self):
np.random.seed(0)
shapes = self.getShapes([1, 2, 10])
- self.runFiniteDifferences(shapes)
+ self.runFiniteDifferences(
+ shapes, dtypes=(dtypes_lib.float32, dtypes_lib.float64))
+
+ def testSmallMatricesComplex(self):
+ np.random.seed(0)
+ shapes = self.getShapes([1, 2, 10])
+ self.runFiniteDifferences(
+ shapes, dtypes=(dtypes_lib.complex64, dtypes_lib.complex128))
def testOneBlockMatrices(self):
np.random.seed(0)
@@ -132,25 +204,61 @@ class CholeskyGradTest(test.TestCase):
self.runFiniteDifferences(
shapes, dtypes=(dtypes_lib.float64,), scalarTest=True)
+ def testTwoBlockMatrixComplexFloat(self):
+ np.random.seed(0)
+ shapes = self.getShapes([2 * self._backprop_block_size + 1])
+ self.runFiniteDifferences(
+ shapes, dtypes=(dtypes_lib.complex64,), scalarTest=True)
+
+ def testTwoBlockMatrixComplexDouble(self):
+ np.random.seed(0)
+ shapes = self.getShapes([2 * self._backprop_block_size + 1])
+ self.runFiniteDifferences(
+ shapes, dtypes=(dtypes_lib.complex128,), scalarTest=True)
+
+ def testAgainstSpecialized(self):
+ np.random.seed(0)
+ data = np.random.randn(33, 33).astype(np.float32)
+ data = np.matmul(data, data.T)
+ grad_data = np.random.randn(*data.shape).astype(np.float32)
+
+ with ops.Graph().as_default(), self.test_session(use_gpu=False) as s:
+ x = constant_op.constant(data, dtypes_lib.float32)
+ chol = linalg_ops.cholesky(x)
+ composite_grad = gradients_impl.gradients(chol, x, grad_data)[0]
+ specialized_grad = SpecializedGrad(chol, grad_data)
+ reference, actual = s.run([specialized_grad, composite_grad])
+ self.assertAllClose(reference, actual)
+
def runFiniteDifferences(self,
shapes,
- dtypes=(dtypes_lib.float32, dtypes_lib.float64),
+ dtypes=(dtypes_lib.float32, dtypes_lib.float64,
+ dtypes_lib.complex64, dtypes_lib.complex128),
scalarTest=False):
- with self.test_session(use_gpu=False):
+ with self.test_session(use_gpu=True):
for shape in shapes:
for batch in False, True:
for dtype in dtypes:
if not scalarTest:
- x = constant_op.constant(
- np.random.randn(shape[0], shape[1]), dtype)
- tensor = math_ops.matmul(x, array_ops.transpose(x)) / shape[0]
+ data = np.random.randn(shape[0], shape[1])
+ if dtype.is_complex:
+ data = data.astype(np.complex64)
+ data += 1j * np.random.randn(shape[0], shape[1])
+ x = constant_op.constant(data, dtype)
+ tensor = math_ops.matmul(
+ x, math_ops.conj(array_ops.transpose(x))) / shape[0]
else:
# This is designed to be a faster test for larger matrices.
- x = constant_op.constant(np.random.randn(), dtype)
+ data = np.random.randn()
+ if dtype.is_complex:
+ data = np.complex64(data)
+ data += 1j * np.random.randn()
+ x = constant_op.constant(data, dtype)
R = constant_op.constant(
np.random.randn(shape[0], shape[1]), dtype)
e = math_ops.multiply(R, x)
- tensor = math_ops.matmul(e, array_ops.transpose(e)) / shape[0]
+ tensor = math_ops.matmul(
+ e, math_ops.conj(array_ops.transpose(e))) / shape[0]
# Inner-most matrices in tensor are positive definite.
if batch:
@@ -159,15 +267,87 @@ class CholeskyGradTest(test.TestCase):
y = linalg_ops.cholesky(tensor)
if scalarTest:
y = math_ops.reduce_mean(y)
- error = gradient_checker.compute_gradient_error(x,
- x._shape_as_list(),
- y,
- y._shape_as_list())
+ error = gradient_checker.compute_gradient_error(
+ x, x._shape_as_list(), y, y._shape_as_list())
tf_logging.info("error = %f", error)
if dtype == dtypes_lib.float64:
self.assertLess(error, 1e-5)
+ elif dtype == dtypes_lib.complex128:
+ self.assertLess(error, 5e-5)
else:
- self.assertLess(error, 3e-3)
+ self.assertLess(error, 5e-3)
+
+
+class CholeskyBenchmark(test.Benchmark):
+
+ sizes = [
+ (4, 4), (16, 16), (256, 256), (1024, 1024), (2048, 2048),
+ (513, 2, 2), (513, 8, 8), (4, 513, 2, 2)
+ ]
+
+ def _GenerateData(self, size):
+ batch_shape = size[:-2]
+ size = size[-2:]
+ assert size[0] == size[1]
+ n = size[0]
+ data = np.ones(size).astype(np.float32) / (2.0 * n) + np.diag(
+ np.ones(n).astype(np.float32))
+ return np.tile(data, batch_shape + (1, 1))
+
+ def benchmarkCholeskyOp(self):
+ for size in self.sizes:
+ data = self._GenerateData(size)
+
+ with ops.Graph().as_default(), \
+ session.Session() as sess, \
+ ops.device("/cpu:0"):
+ l = linalg_ops.cholesky(data)
+ self.run_op_benchmark(
+ sess, l,
+ min_iters=25,
+ name="cholesky_cpu_{size}".format(size=size))
+
+ if test.is_gpu_available(True):
+ with ops.Graph().as_default(), \
+ session.Session() as sess, \
+ ops.device("/gpu:0"):
+ l = linalg_ops.cholesky(data)
+ self.run_op_benchmark(
+ sess, l,
+ min_iters=25,
+ name="cholesky_gpu_{size}".format(size=size))
+
+ def benchmarkGradVariants(self):
+ def _BenchmarkGrad(grad_fn, name, device):
+ for size in self.sizes:
+ data = self._GenerateData(size)
+ l = np.linalg.cholesky(data)
+ grad_data = np.random.randn(*data.shape).astype(np.float32)
+ with ops.Graph().as_default(), \
+ session.Session() as sess, \
+ ops.device(device):
+ grad = grad_fn(l, grad_data)
+ self.run_op_benchmark(
+ sess, grad,
+ min_iters=25,
+ name="{name}_{dev}_{size}".format(
+ name=name, dev=grad.device, size=size))
+
+ if test.is_gpu_available(True):
+ _BenchmarkGrad(
+ MatrixInverseCompositeGrad, "composite_matrix_inverse", "/gpu:0")
+ _BenchmarkGrad(
+ TriAngInvCompositeGrad, "composite_tri_ang_inverse", "/gpu:0")
+ _BenchmarkGrad(
+ TriAngSolveCompositeGrad, "composite_triangular_solve", "/gpu:0")
+
+ _BenchmarkGrad(
+ MatrixInverseCompositeGrad, "composite_matrix_inverse", "/cpu:0")
+ _BenchmarkGrad(
+ TriAngInvCompositeGrad, "composite_tri_ang_inverse", "/cpu:0")
+ _BenchmarkGrad(
+ TriAngSolveCompositeGrad, "composite_triangular_solve", "/cpu:0")
+ _BenchmarkGrad(SpecializedGrad, "specialized", "/cpu:0")
if __name__ == "__main__":
diff --git a/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py b/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py
index 262c197480..33288392c0 100644
--- a/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py
+++ b/tensorflow/python/kernel_tests/matrix_triangular_solve_op_test.py
@@ -28,7 +28,7 @@ from tensorflow.python.platform import test
class MatrixTriangularSolveOpTest(test.TestCase):
- def _verifySolveAllWays(self, x, y, batch_dims=None):
+ def _verifySolveAllWays(self, x, y, dtypes, batch_dims=None):
for lower in True, False:
for adjoint in True, False:
for use_placeholder in True, False:
@@ -38,7 +38,14 @@ class MatrixTriangularSolveOpTest(test.TestCase):
lower=lower,
adjoint=adjoint,
batch_dims=batch_dims,
- use_placeholder=use_placeholder)
+ use_placeholder=use_placeholder,
+ dtypes=dtypes)
+
+ def _verifySolveAllWaysReal(self, x, y, batch_dims=None):
+ self._verifySolveAllWays(x, y, (np.float32, np.float64), batch_dims)
+
+ def _verifySolveAllWaysComplex(self, x, y, batch_dims=None):
+ self._verifySolveAllWays(x, y, (np.complex64, np.complex128), batch_dims)
def _verifySolve(self,
x,
@@ -46,8 +53,9 @@ class MatrixTriangularSolveOpTest(test.TestCase):
lower=True,
adjoint=False,
batch_dims=None,
- use_placeholder=False):
- for np_type in [np.float32, np.float64]:
+ use_placeholder=False,
+ dtypes=(np.float32, np.float64)):
+ for np_type in dtypes:
a = x.astype(np_type)
b = y.astype(np_type)
# For numpy.solve we have to explicitly zero out the strictly
@@ -89,22 +97,48 @@ class MatrixTriangularSolveOpTest(test.TestCase):
# 1x1 matrix, single rhs.
matrix = np.array([[0.1]])
rhs0 = np.array([[1.]])
- self._verifySolveAllWays(matrix, rhs0)
+ self._verifySolveAllWaysReal(matrix, rhs0)
# 2x2 matrices, single right-hand side.
matrix = np.array([[1., 2.], [3., 4.]])
rhs0 = np.array([[1.], [1.]])
- self._verifySolveAllWays(matrix, rhs0)
+ self._verifySolveAllWaysReal(matrix, rhs0)
# 2x2 matrices, 3 right-hand sides.
rhs1 = np.array([[1., 0., 1.], [0., 1., 1.]])
- self._verifySolveAllWays(matrix, rhs1)
+ self._verifySolveAllWaysReal(matrix, rhs1)
+
+ def testSolveComplex(self):
+ # 1x1 matrix, single rhs.
+ matrix = np.array([[0.1 + 1j * 0.1]])
+ rhs0 = np.array([[1. + 1j]])
+ self._verifySolveAllWaysComplex(matrix, rhs0)
+ # 2x2 matrices, single right-hand side.
+ matrix = np.array([[1., 2.], [3., 4.]]).astype(np.complex64)
+ matrix += 1j * matrix
+ rhs0 = np.array([[1.], [1.]]).astype(np.complex64)
+ rhs0 += 1j * rhs0
+ self._verifySolveAllWaysComplex(matrix, rhs0)
+ # 2x2 matrices, 3 right-hand sides.
+ rhs1 = np.array([[1., 0., 1.], [0., 1., 1.]]).astype(np.complex64)
+ rhs1 += 1j * rhs1
+ self._verifySolveAllWaysComplex(matrix, rhs1)
def testSolveBatch(self):
matrix = np.array([[1., 2.], [3., 4.]])
rhs = np.array([[1., 0., 1.], [0., 1., 1.]])
# Batch of 2x3x2x2 matrices, 2x3x2x3 right-hand sides.
- self._verifySolveAllWays(matrix, rhs, batch_dims=[2, 3])
+ self._verifySolveAllWaysReal(matrix, rhs, batch_dims=[2, 3])
+ # Batch of 3x2x2x2 matrices, 3x2x2x3 right-hand sides.
+ self._verifySolveAllWaysReal(matrix, rhs, batch_dims=[3, 2])
+
+ def testSolveBatchComplex(self):
+ matrix = np.array([[1., 2.], [3., 4.]]).astype(np.complex64)
+ matrix += 1j * matrix
+ rhs = np.array([[1., 0., 1.], [0., 1., 1.]]).astype(np.complex64)
+ rhs += 1j * rhs
+ # Batch of 2x3x2x2 matrices, 2x3x2x3 right-hand sides.
+ self._verifySolveAllWaysComplex(matrix, rhs, batch_dims=[2, 3])
# Batch of 3x2x2x2 matrices, 3x2x2x3 right-hand sides.
- self._verifySolveAllWays(matrix, rhs, batch_dims=[3, 2])
+ self._verifySolveAllWaysComplex(matrix, rhs, batch_dims=[3, 2])
def testNonSquareMatrix(self):
# A non-square matrix should cause an error.
diff --git a/tensorflow/python/ops/linalg_grad.py b/tensorflow/python/ops/linalg_grad.py
index edd4e96aa4..b479b6ac60 100644
--- a/tensorflow/python/ops/linalg_grad.py
+++ b/tensorflow/python/ops/linalg_grad.py
@@ -57,7 +57,24 @@ def _MatrixDeterminantGrad(op, grad):
@ops.RegisterGradient("Cholesky")
def _CholeskyGrad(op, grad):
"""Gradient for Cholesky."""
- return linalg_ops.cholesky_grad(op.outputs[0], grad)
+
+ # Gradient is l^{-H} @ ((l^{H} @ grad) * (tril(ones)-1/2*eye)) @ l^{-1}
+ l = op.outputs[0]
+ num_rows = array_ops.shape(l)[-1]
+ batch_shape = array_ops.shape(l)[:-2]
+ l_inverse = linalg_ops.matrix_triangular_solve(
+ l, linalg_ops.eye(num_rows, batch_shape=batch_shape, dtype=l.dtype))
+
+ middle = math_ops.matmul(l, grad, adjoint_a=True)
+ middle = array_ops.matrix_set_diag(middle,
+ 0.5 * array_ops.matrix_diag_part(middle))
+ middle = array_ops.matrix_band_part(middle, -1, 0)
+
+ grad_a = math_ops.matmul(
+ math_ops.matmul(l_inverse, middle, adjoint_a=True), l_inverse)
+
+ grad_a += math_ops.conj(array_ops.matrix_transpose(grad_a))
+ return grad_a * 0.5
@ops.RegisterGradient("MatrixSolve")