diff options
Diffstat (limited to 'tensorflow/python/ops/linalg/linear_operator_diag.py')
-rw-r--r-- | tensorflow/python/ops/linalg/linear_operator_diag.py | 5 |
1 files changed, 4 insertions, 1 deletions
diff --git a/tensorflow/python/ops/linalg/linear_operator_diag.py b/tensorflow/python/ops/linalg/linear_operator_diag.py index 5beaea65a5..ed53decc00 100644 --- a/tensorflow/python/ops/linalg/linear_operator_diag.py +++ b/tensorflow/python/ops/linalg/linear_operator_diag.py @@ -231,8 +231,11 @@ class LinearOperatorDiag(linear_operator.LinearOperator): return math_ops.reduce_prod(self._diag, reduction_indices=[-1]) def _log_abs_determinant(self): - return math_ops.reduce_sum( + log_det = math_ops.reduce_sum( math_ops.log(math_ops.abs(self._diag)), reduction_indices=[-1]) + if self.dtype.is_complex: + log_det = math_ops.cast(log_det, dtype=self.dtype) + return log_det def _solve(self, rhs, adjoint=False, adjoint_arg=False): diag_term = math_ops.conj(self._diag) if adjoint else self._diag |