diff options
Diffstat (limited to 'tensorflow/python/ops/linalg/linear_operator_low_rank_update.py')
-rw-r--r-- | tensorflow/python/ops/linalg/linear_operator_low_rank_update.py | 31 |
1 files changed, 9 insertions, 22 deletions
diff --git a/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py b/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py index 08e5896e10..2b2bf80f27 100644 --- a/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py +++ b/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py @@ -18,16 +18,15 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function -from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops -from tensorflow.python.ops import check_ops from tensorflow.python.ops import linalg_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops.linalg import linear_operator from tensorflow.python.ops.linalg import linear_operator_diag from tensorflow.python.ops.linalg import linear_operator_identity from tensorflow.python.ops.linalg import linear_operator_util +from tensorflow.python.platform import tf_logging as logging from tensorflow.python.util.tf_export import tf_export __all__ = [ @@ -153,8 +152,7 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator): `is_X` matrix property hints, which will trigger the appropriate code path. Args: - base_operator: Shape `[B1,...,Bb, M, N]` real `float16`, `float32` or - `float64` `LinearOperator`. This is `L` above. + base_operator: Shape `[B1,...,Bb, M, N]`. u: Shape `[B1,...,Bb, M, K]` `Tensor` of same `dtype` as `base_operator`. This is `U` above. diag_update: Optional shape `[B1,...,Bb, K]` `Tensor` with same `dtype` @@ -183,23 +181,12 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator): Raises: ValueError: If `is_X` flags are set in an inconsistent way. """ - # TODO(langmore) support complex types. - # Complex types are not allowed due to tf.cholesky() requiring float. - # If complex dtypes are allowed, we update the following - # 1. is_diag_update_positive should still imply that `diag > 0`, but we need - # to remind the user that this implies diag is real. This is needed - # because if diag has non-zero imaginary part, it will not be - # self-adjoint positive definite. dtype = base_operator.dtype - allowed_dtypes = [ - dtypes.float16, - dtypes.float32, - dtypes.float64, - ] - if dtype not in allowed_dtypes: - raise TypeError( - "Argument matrix must have dtype in %s. Found: %s" - % (allowed_dtypes, dtype)) + + if diag_update is not None: + if is_diag_update_positive and dtype.is_complex: + logging.warn("Note: setting is_diag_update_positive with a complex " + "dtype means that diagonal is real and positive.") if diag_update is None: if is_diag_update_positive is False: @@ -271,8 +258,6 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator): self._set_diag_operators(diag_update, is_diag_update_positive) self._is_diag_update_positive = is_diag_update_positive - check_ops.assert_same_float_dtype((base_operator, self.u, self.v, - self._diag_update)) self._check_shapes() # Pre-compute the so-called "capacitance" matrix @@ -407,6 +392,8 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator): else: det_c = linalg_ops.matrix_determinant(self._capacitance) log_abs_det_c = math_ops.log(math_ops.abs(det_c)) + if self.dtype.is_complex: + log_abs_det_c = math_ops.cast(log_abs_det_c, dtype=self.dtype) return log_abs_det_c + log_abs_det_d + log_abs_det_l |