aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/ops/linalg_ops.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/ops/linalg_ops.cc')
-rw-r--r--tensorflow/core/ops/linalg_ops.cc310
1 files changed, 297 insertions, 13 deletions
diff --git a/tensorflow/core/ops/linalg_ops.cc b/tensorflow/core/ops/linalg_ops.cc
index b8496d972d..53e2360d23 100644
--- a/tensorflow/core/ops/linalg_ops.cc
+++ b/tensorflow/core/ops/linalg_ops.cc
@@ -202,7 +202,17 @@ REGISTER_OP("MatrixDeterminant")
TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &out));
c->set_output(0, out);
return Status::OK();
- });
+ })
+ .Doc(R"doc(
+Computes the determinant of one or more square matrices.
+
+The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
+form square matrices. The output is a tensor containing the determinants
+for all input submatrices `[..., :, :]`.
+
+input: Shape is `[..., M, M]`.
+output: Shape is `[...]`.
+)doc");
REGISTER_OP("LogMatrixDeterminant")
.Input("input: T")
@@ -225,33 +235,126 @@ REGISTER_OP("LogMatrixDeterminant")
TF_RETURN_IF_ERROR(c->Subshape(input, 0, -2, &out));
c->set_output(1, out);
return Status::OK();
- });
+ })
+ .Doc(R"doc(
+Computes the sign and the log of the absolute value of the determinant of
+one or more square matrices.
+
+The input is a tensor of shape `[N, M, M]` whose inner-most 2 dimensions
+form square matrices. The outputs are two tensors containing the signs and
+the logs of the absolute values of the determinants for all N input submatrices
+`[..., :, :]` such that the determinant = sign*exp(log_abs_determinant).
+The log_abs_determinant is computed as sum(log(abs(diag(LU)))) where LU
+is the LU decomposition of the input; the determinant of the corresponding
+permutation matrix P contributes only to the sign.
+
+input: Shape is `[N, M, M]`.
+sign: The signs of the determinants of the inputs. Shape is `[N]`.
+log_abs_determinant: The logs of the absolute values of the determinants
+of the N input matrices. Shape is `[N]`.
+)doc");
REGISTER_OP("MatrixInverse")
.Input("input: T")
.Output("output: T")
.Attr("adjoint: bool = False")
.Attr("T: {double, float, complex64, complex128}")
- .SetShapeFn(BatchUnchangedSquareShapeFn);
+ .SetShapeFn(BatchUnchangedSquareShapeFn)
+ .Doc(R"doc(
+Computes the inverse of one or more square invertible matrices or their
+adjoints (conjugate transposes).
+
+The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
+form square matrices. The output is a tensor of the same shape as the input
+containing the inverse for all input submatrices `[..., :, :]`.
+
+The op uses LU decomposition with partial pivoting to compute the inverses.
+
+If a matrix is not invertible there is no guarantee what the op does. It
+may detect the condition and raise an exception or it may simply return a
+garbage result.
+
+input: Shape is `[..., M, M]`.
+output: Shape is `[..., M, M]`.
+
+@compatibility(numpy)
+Equivalent to np.linalg.inv
+@end_compatibility
+)doc");
REGISTER_OP("MatrixExponential")
.Input("input: T")
.Output("output: T")
.Attr("T: {double, float, complex64, complex128}")
- .SetShapeFn(BatchUnchangedSquareShapeFn);
+ .SetShapeFn(BatchUnchangedSquareShapeFn)
+ .Doc(R"doc(
+Computes the matrix exponential of one or more square matrices:
+
+exp(A) = \sum_{n=0}^\infty A^n/n!
+
+The exponential is computed using a combination of the scaling and squaring
+method and the Pade approximation. Details can be found in:
+Nicholas J. Higham, "The scaling and squaring method for the matrix exponential
+revisited," SIAM J. Matrix Anal. Applic., 26:1179-1193, 2005.
+
+The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
+form square matrices. The output is a tensor of the same shape as the input
+containing the exponential for all input submatrices `[..., :, :]`.
+
+input: Shape is `[..., M, M]`.
+output: Shape is `[..., M, M]`.
+
+@compatibility(scipy)
+Equivalent to scipy.linalg.expm
+@end_compatibility
+)doc");
REGISTER_OP("Cholesky")
.Input("input: T")
.Output("output: T")
.Attr("T: {double, float, complex64, complex128}")
- .SetShapeFn(BatchUnchangedSquareShapeFn);
+ .SetShapeFn(BatchUnchangedSquareShapeFn)
+ .Doc(R"doc(
+Computes the Cholesky decomposition of one or more square matrices.
+
+The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
+form square matrices.
+
+The input has to be symmetric and positive definite. Only the lower-triangular
+part of the input will be used for this operation. The upper-triangular part
+will not be read.
+
+The output is a tensor of the same shape as the input
+containing the Cholesky decompositions for all input submatrices `[..., :, :]`.
+
+**Note**: The gradient computation on GPU is faster for large matrices but
+not for large batch dimensions when the submatrices are small. In this
+case it might be faster to use the CPU.
+
+input: Shape is `[..., M, M]`.
+output: Shape is `[..., M, M]`.
+)doc");
REGISTER_OP("CholeskyGrad")
.Input("l: T")
.Input("grad: T")
.Output("output: T")
.Attr("T: {float, double}")
- .SetShapeFn(BatchUnchangedSquareShapeFn);
+ .SetShapeFn(BatchUnchangedSquareShapeFn)
+ .Doc(R"doc(
+Computes the reverse mode backpropagated gradient of the Cholesky algorithm.
+
+For an explanation see "Differentiation of the Cholesky algorithm" by
+Iain Murray http://arxiv.org/abs/1602.07527.
+
+l: Output of batch Cholesky algorithm l = cholesky(A). Shape is `[..., M, M]`.
+ Algorithm depends only on lower triangular part of the innermost matrices of
+ this tensor.
+grad: df/dl where f is some scalar function. Shape is `[..., M, M]`.
+ Algorithm depends only on lower triangular part of the innermost matrices of
+ this tensor.
+output: Symmetrized version of df/dA. Shape is `[..., M, M]`.
+)doc");
REGISTER_OP("SelfAdjointEig")
.Input("input: T")
@@ -271,7 +374,20 @@ REGISTER_OP("SelfAdjointEig")
TF_RETURN_IF_ERROR(c->Concatenate(s, c->Matrix(d_plus_1, d), &s));
c->set_output(0, s);
return Status::OK();
- });
+ })
+ .Doc(R"doc(
+Computes the Eigen Decomposition of a batch of square self-adjoint matrices.
+
+The input is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
+form square matrices, with the same constraints as the single matrix
+SelfAdjointEig.
+
+The result is a [..., M+1, M] matrix with [..., 0,:] containing the
+eigenvalues, and subsequent [...,1:, :] containing the eigenvectors.
+
+input: Shape is `[..., M, M]`.
+output: Shape is `[..., M+1, M]`.
+)doc");
REGISTER_OP("SelfAdjointEigV2")
.Input("input: T")
@@ -279,7 +395,27 @@ REGISTER_OP("SelfAdjointEigV2")
.Output("v: T")
.Attr("compute_v: bool = True")
.Attr("T: {double, float, complex64, complex128}")
- .SetShapeFn(SelfAdjointEigV2ShapeFn);
+ .SetShapeFn(SelfAdjointEigV2ShapeFn)
+ .Doc(R"doc(
+Computes the eigen decomposition of one or more square self-adjoint matrices.
+
+Computes the eigenvalues and (optionally) eigenvectors of each inner matrix in
+`input` such that `input[..., :, :] * v[..., :, i] = e[..., i] * v[..., :, i]`.
+
+```python
+# a is a tensor.
+# e is a tensor of eigenvalues.
+# v is a tensor of eigenvectors.
+e, v = self_adjoint_eig(a)
+e = self_adjoint_eig(a, compute_v=False)
+```
+
+input: `Tensor` input of shape `[..., N, N]`.
+compute_v: If `True` then eigenvectors will be computed and returned in `v`.
+ Otherwise, only the eigenvalues will be computed.
+e: Eigenvalues. Shape is `[..., N]`.
+v: Eigenvectors. Shape is `[..., N, N]`.
+)doc");
REGISTER_OP("MatrixSolve")
.Input("matrix: T")
@@ -289,7 +425,23 @@ REGISTER_OP("MatrixSolve")
.Attr("T: {double, float, complex64, complex128}")
.SetShapeFn([](InferenceContext* c) {
return MatrixSolveShapeFn(c, true /* square (*/);
- });
+ })
+ .Doc(R"doc(
+Solves systems of linear equations.
+
+`matrix` is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions
+form square matrices. `rhs` is a tensor of shape `[..., M, K]`. The `output` is
+a tensor of shape `[..., M, K]`. If `adjoint` is `False` then each output matrix
+satisfies `matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]`.
+If `adjoint` is `True` then each output matrix satisfies
+`adjoint(matrix[..., :, :]) * output[..., :, :] = rhs[..., :, :]`.
+
+matrix: Shape is `[..., M, M]`.
+rhs: Shape is `[..., M, K]`.
+output: Shape is `[..., M, K]`.
+adjoint: Boolean indicating whether to solve with `matrix` or its (block-wise)
+ adjoint.
+)doc");
REGISTER_OP("MatrixTriangularSolve")
.Input("matrix: T")
@@ -300,7 +452,37 @@ REGISTER_OP("MatrixTriangularSolve")
.Attr("T: {double, float, complex64, complex128}")
.SetShapeFn([](InferenceContext* c) {
return MatrixSolveShapeFn(c, true /* square (*/);
- });
+ })
+ .Doc(R"doc(
+Solves systems of linear equations with upper or lower triangular matrices by
+backsubstitution.
+
+`matrix` is a tensor of shape `[..., M, M]` whose inner-most 2 dimensions form
+square matrices. If `lower` is `True` then the strictly upper triangular part
+of each inner-most matrix is assumed to be zero and not accessed.
+If `lower` is `False` then the strictly lower triangular part of each inner-most
+matrix is assumed to be zero and not accessed.
+`rhs` is a tensor of shape `[..., M, K]`.
+
+The output is a tensor of shape `[..., M, K]`. If `adjoint` is
+`False` then the innermost matrices in `output` satisfy matrix equations
+`matrix[..., :, :] * output[..., :, :] = rhs[..., :, :]`.
+If `adjoint` is `True` then the innermost matrices in
+`output` satisfy matrix equations
+`adjoint(matrix[..., :, :]) * output[..., :, :] = rhs[..., :, :]`.
+
+matrix: Shape is `[..., M, M]`.
+rhs: Shape is `[..., M, K]`.
+output: Shape is `[..., M, K]`.
+lower: Boolean indicating whether the innermost matrices in `matrix` are
+ lower or upper triangular.
+adjoint: Boolean indicating whether to solve with `matrix` or its (block-wise)
+ adjoint.
+
+@compatibility(scipy)
+Equivalent to scipy.linalg.solve_triangular
+@end_compatibility
+)doc");
REGISTER_OP("MatrixSolveLs")
.Input("matrix: T")
@@ -313,7 +495,54 @@ REGISTER_OP("MatrixSolveLs")
ShapeHandle l2_regularizer;
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &l2_regularizer));
return MatrixSolveShapeFn(c, false /* square */);
- });
+ })
+ .Doc(R"doc(
+Solves one or more linear least-squares problems.
+
+`matrix` is a tensor of shape `[..., M, N]` whose inner-most 2 dimensions
+form real or complex matrices of size `[M, N]`. `rhs` is a tensor of the same
+type as `matrix` and shape `[..., M, K]`.
+The output is a tensor of shape `[..., N, K]` where each output matrix solves
+each of the equations
+`matrix[..., :, :]` * `output[..., :, :]` = `rhs[..., :, :]`
+in the least squares sense.
+
+We use the following notation for (complex) matrix and right-hand sides
+in the batch:
+
+`matrix`=\\(A \in \mathbb{C}^{m \times n}\\),
+`rhs`=\\(B \in \mathbb{C}^{m \times k}\\),
+`output`=\\(X \in \mathbb{C}^{n \times k}\\),
+`l2_regularizer`=\\(\lambda \in \mathbb{R}\\).
+
+If `fast` is `True`, then the solution is computed by solving the normal
+equations using Cholesky decomposition. Specifically, if \\(m \ge n\\) then
+\\(X = (A^H A + \lambda I)^{-1} A^H B\\), which solves the least-squares
+problem \\(X = \mathrm{argmin}_{Z \in \mathbb{C}^{n \times k} } ||A Z - B||_F^2 +
+\lambda ||Z||_F^2\\). If \\(m \lt n\\) then `output` is computed as
+\\(X = A^H (A A^H + \lambda I)^{-1} B\\), which (for \\(\lambda = 0\\)) is the
+minimum-norm solution to the under-determined linear system, i.e.
+\\(X = \mathrm{argmin}_{Z \in \mathbb{C}^{n \times k} } ||Z||_F^2 \\),
+subject to \\(A Z = B\\). Notice that the fast path is only numerically stable
+when \\(A\\) is numerically full rank and has a condition number
+\\(\mathrm{cond}(A) \lt \frac{1}{\sqrt{\epsilon_{mach} } }\\) or \\(\lambda\\) is
+sufficiently large.
+
+If `fast` is `False` an algorithm based on the numerically robust complete
+orthogonal decomposition is used. This computes the minimum-norm
+least-squares solution, even when \\(A\\) is rank deficient. This path is
+typically 6-7 times slower than the fast path. If `fast` is `False` then
+`l2_regularizer` is ignored.
+
+matrix: Shape is `[..., M, N]`.
+rhs: Shape is `[..., M, K]`.
+output: Shape is `[..., N, K]`.
+l2_regularizer: Scalar tensor.
+
+@compatibility(numpy)
+Equivalent to np.linalg.lstsq
+@end_compatibility
+)doc");
REGISTER_OP("Qr")
.Input("input: T")
@@ -321,7 +550,31 @@ REGISTER_OP("Qr")
.Output("r: T")
.Attr("full_matrices: bool = False")
.Attr("T: {double, float, complex64, complex128}")
- .SetShapeFn(QrShapeFn);
+ .SetShapeFn(QrShapeFn)
+ .Doc(R"doc(
+Computes the QR decompositions of one or more matrices.
+
+Computes the QR decomposition of each inner matrix in `input` such that
+`input[..., :, :] = q[..., :, :] * r[..., :, :]`.
+
+```python
+# a is a tensor.
+# q is a tensor of orthonormal matrices.
+# r is a tensor of upper triangular matrices.
+q, r = qr(a)
+q_full, r_full = qr(a, full_matrices=True)
+```
+
+input: A tensor of shape `[..., M, N]` whose inner-most 2 dimensions
+ form matrices of size `[M, N]`. Let `P` be the minimum of `M` and `N`.
+q: Orthonormal basis for range of `a`. If `full_matrices` is `False` then
+ shape is `[..., M, P]`; if `full_matrices` is `True` then shape is
+ `[..., M, M]`.
+r: Triangular factor. If `full_matrices` is `False` then shape is
+ `[..., P, N]`. If `full_matrices` is `True` then shape is `[..., M, N]`.
+full_matrices: If true, compute full-sized `q` and `r`. If false
+ (the default), compute only the leading `P` columns of `q`.
+)doc");
REGISTER_OP("Svd")
.Input("input: T")
@@ -331,7 +584,38 @@ REGISTER_OP("Svd")
.Attr("compute_uv: bool = True")
.Attr("full_matrices: bool = False")
.Attr("T: {double, float, complex64, complex128}")
- .SetShapeFn(SvdShapeFn);
+ .SetShapeFn(SvdShapeFn)
+ .Doc(R"doc(
+Computes the singular value decompositions of one or more matrices.
+
+Computes the SVD of each inner matrix in `input` such that
+`input[..., :, :] = u[..., :, :] * diag(s[..., :]) * transpose(v[..., :, :])`
+
+```python
+# a is a tensor containing a batch of matrices.
+# s is a tensor of singular values for each matrix.
+# u is the tensor containing the left singular vectors for each matrix.
+# v is the tensor containing the right singular vectors for each matrix.
+s, u, v = svd(a)
+s, _, _ = svd(a, compute_uv=False)
+```
+
+input: A tensor of shape `[..., M, N]` whose inner-most 2 dimensions
+ form matrices of size `[M, N]`. Let `P` be the minimum of `M` and `N`.
+s: Singular values. Shape is `[..., P]`.
+u: Left singular vectors. If `full_matrices` is `False` then shape is
+ `[..., M, P]`; if `full_matrices` is `True` then shape is
+ `[..., M, M]`. Undefined if `compute_uv` is `False`.
+v: Right singular vectors. If `full_matrices` is `False` then shape is
+ `[..., N, P]`. If `full_matrices` is `True` then shape is `[..., N, N]`.
+ Undefined if `compute_uv` is false.
+compute_uv: If true, left and right singular vectors will be
+ computed and returned in `u` and `v`, respectively.
+ If false, `u` and `v` are not set and should never be referenced.
+full_matrices: If true, compute full-sized `u` and `v`. If false
+ (the default), compute only the leading `P` singular vectors.
+ Ignored if `compute_uv` is `False`.
+)doc");
// Deprecated op registrations: