diff options
Diffstat (limited to 'tensorflow/python/ops/linalg/linear_operator_lower_triangular.py')
-rw-r--r-- | tensorflow/python/ops/linalg/linear_operator_lower_triangular.py | 8 |
1 files changed, 3 insertions, 5 deletions
diff --git a/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py b/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py index fb1eb2fedb..ca6d3f5405 100644 --- a/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py +++ b/tensorflow/python/ops/linalg/linear_operator_lower_triangular.py @@ -119,8 +119,7 @@ class LinearOperatorLowerTriangular(linear_operator.LinearOperator): Args: tril: Shape `[B1,...,Bb, N, N]` with `b >= 0`, `N >= 0`. The lower triangular part of `tril` defines this operator. The strictly - upper triangle is ignored. Allowed dtypes: `float16`, `float32`, - `float64`. + upper triangle is ignored. is_non_singular: Expect that this operator is non-singular. This operator is non-singular if and only if its diagonal elements are all non-zero. @@ -137,7 +136,6 @@ class LinearOperatorLowerTriangular(linear_operator.LinearOperator): name: A name for this `LinearOperator`. Raises: - TypeError: If `diag.dtype` is not an allowed type. ValueError: If `is_square` is `False`. """ @@ -163,12 +161,12 @@ class LinearOperatorLowerTriangular(linear_operator.LinearOperator): def _check_tril(self, tril): """Static check of the `tril` argument.""" - # TODO(langmore) Add complex types once matrix_triangular_solve works for - # them. allowed_dtypes = [ dtypes.float16, dtypes.float32, dtypes.float64, + dtypes.complex64, + dtypes.complex128, ] dtype = tril.dtype if dtype not in allowed_dtypes: |