aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/linalg/linear_operator_lower_triangular.py')
-rw-r--r--tensorflow/python/ops/linalg/linear_operator_lower_triangular.py8
1 files changed, 3 insertions, 5 deletions
diff --git a/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py b/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py
index fb1eb2fedb..ca6d3f5405 100644
--- a/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py
+++ b/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py
@@ -119,8 +119,7 @@ class LinearOperatorLowerTriangular(linear_operator.LinearOperator):
Args:
tril: Shape `[B1,...,Bb, N, N]` with `b >= 0`, `N >= 0`.
The lower triangular part of `tril` defines this operator. The strictly
- upper triangle is ignored. Allowed dtypes: `float16`, `float32`,
- `float64`.
+ upper triangle is ignored.
is_non_singular: Expect that this operator is non-singular.
This operator is non-singular if and only if its diagonal elements are
all non-zero.
@@ -137,7 +136,6 @@ class LinearOperatorLowerTriangular(linear_operator.LinearOperator):
name: A name for this `LinearOperator`.
Raises:
- TypeError: If `diag.dtype` is not an allowed type.
ValueError: If `is_square` is `False`.
"""
@@ -163,12 +161,12 @@ class LinearOperatorLowerTriangular(linear_operator.LinearOperator):
def _check_tril(self, tril):
"""Static check of the `tril` argument."""
- # TODO(langmore) Add complex types once matrix_triangular_solve works for
- # them.
allowed_dtypes = [
dtypes.float16,
dtypes.float32,
dtypes.float64,
+ dtypes.complex64,
+ dtypes.complex128,
]
dtype = tril.dtype
if dtype not in allowed_dtypes: