aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/linalg/linear_operator_low_rank_update.py')
-rw-r--r--tensorflow/python/ops/linalg/linear_operator_low_rank_update.py31
1 files changed, 9 insertions, 22 deletions
diff --git a/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py b/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py
index 08e5896e10..2b2bf80f27 100644
--- a/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py
+++ b/tensorflow/python/ops/linalg/linear_operator_low_rank_update.py
@@ -18,16 +18,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import check_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.linalg import linear_operator
from tensorflow.python.ops.linalg import linear_operator_diag
from tensorflow.python.ops.linalg import linear_operator_identity
from tensorflow.python.ops.linalg import linear_operator_util
+from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util.tf_export import tf_export
__all__ = [
@@ -153,8 +152,7 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator):
`is_X` matrix property hints, which will trigger the appropriate code path.
Args:
- base_operator: Shape `[B1,...,Bb, M, N]` real `float16`, `float32` or
- `float64` `LinearOperator`. This is `L` above.
+ base_operator: Shape `[B1,...,Bb, M, N]`.
u: Shape `[B1,...,Bb, M, K]` `Tensor` of same `dtype` as `base_operator`.
This is `U` above.
diag_update: Optional shape `[B1,...,Bb, K]` `Tensor` with same `dtype`
@@ -183,23 +181,12 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator):
Raises:
ValueError: If `is_X` flags are set in an inconsistent way.
"""
- # TODO(langmore) support complex types.
- # Complex types are not allowed due to tf.cholesky() requiring float.
- # If complex dtypes are allowed, we update the following
- # 1. is_diag_update_positive should still imply that `diag > 0`, but we need
- # to remind the user that this implies diag is real. This is needed
- # because if diag has non-zero imaginary part, it will not be
- # self-adjoint positive definite.
dtype = base_operator.dtype
- allowed_dtypes = [
- dtypes.float16,
- dtypes.float32,
- dtypes.float64,
- ]
- if dtype not in allowed_dtypes:
- raise TypeError(
- "Argument matrix must have dtype in %s. Found: %s"
- % (allowed_dtypes, dtype))
+
+ if diag_update is not None:
+ if is_diag_update_positive and dtype.is_complex:
+ logging.warn("Note: setting is_diag_update_positive with a complex "
+ "dtype means that diagonal is real and positive.")
if diag_update is None:
if is_diag_update_positive is False:
@@ -271,8 +258,6 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator):
self._set_diag_operators(diag_update, is_diag_update_positive)
self._is_diag_update_positive = is_diag_update_positive
- check_ops.assert_same_float_dtype((base_operator, self.u, self.v,
- self._diag_update))
self._check_shapes()
# Pre-compute the so-called "capacitance" matrix
@@ -407,6 +392,8 @@ class LinearOperatorLowRankUpdate(linear_operator.LinearOperator):
else:
det_c = linalg_ops.matrix_determinant(self._capacitance)
log_abs_det_c = math_ops.log(math_ops.abs(det_c))
+ if self.dtype.is_complex:
+ log_abs_det_c = math_ops.cast(log_abs_det_c, dtype=self.dtype)
return log_abs_det_c + log_abs_det_d + log_abs_det_l