aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/ops/linalg/linear_operator_diag.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/ops/linalg/linear_operator_diag.py')
-rw-r--r--tensorflow/python/ops/linalg/linear_operator_diag.py5
1 files changed, 4 insertions, 1 deletions
diff --git a/tensorflow/python/ops/linalg/linear_operator_diag.py b/tensorflow/python/ops/linalg/linear_operator_diag.py
index 5beaea65a5..ed53decc00 100644
--- a/tensorflow/python/ops/linalg/linear_operator_diag.py
+++ b/tensorflow/python/ops/linalg/linear_operator_diag.py
@@ -231,8 +231,11 @@ class LinearOperatorDiag(linear_operator.LinearOperator):
return math_ops.reduce_prod(self._diag, reduction_indices=[-1])
def _log_abs_determinant(self):
- return math_ops.reduce_sum(
+ log_det = math_ops.reduce_sum(
math_ops.log(math_ops.abs(self._diag)), reduction_indices=[-1])
+ if self.dtype.is_complex:
+ log_det = math_ops.cast(log_det, dtype=self.dtype)
+ return log_det
def _solve(self, rhs, adjoint=False, adjoint_arg=False):
diag_term = math_ops.conj(self._diag) if adjoint else self._diag