diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py | 29 |
1 files changed, 18 insertions, 11 deletions
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py index 34b35a4ffb..0e38dbd48d 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py @@ -49,12 +49,6 @@ class BaseLinearOperatorLowRankUpdatetest(object): _use_v = None @property - def _dtypes_to_test(self): - # TODO(langmore) Test complex types once cholesky works with them. - # See comment in LinearOperatorLowRankUpdate.__init__. - return [dtypes.float32, dtypes.float64] - - @property def _operator_build_infos(self): build_info = linear_operator_test_util.OperatorBuildInfo # Previously we had a (2, 10, 10) shape at the end. We did this to test the @@ -68,6 +62,15 @@ class BaseLinearOperatorLowRankUpdatetest(object): build_info((3, 4, 4)), build_info((2, 1, 4, 4))] + def _gen_positive_diag(self, dtype, diag_shape): + if dtype.is_complex: + diag = linear_operator_test_util.random_uniform( + diag_shape, minval=1e-4, maxval=1., dtype=dtypes.float32) + return math_ops.cast(diag, dtype=dtype) + + return linear_operator_test_util.random_uniform( + diag_shape, minval=1e-4, maxval=1., dtype=dtype) + def _operator_and_matrix(self, build_info, dtype, use_placeholder): # Recall A = L + UDV^H shape = list(build_info.shape) @@ -78,8 +81,7 @@ class BaseLinearOperatorLowRankUpdatetest(object): # base_operator L will be a symmetric positive definite diagonal linear # operator, with condition number as high as 1e4. - base_diag = linear_operator_test_util.random_uniform( - diag_shape, minval=1e-4, maxval=1., dtype=dtype) + base_diag = self._gen_positive_diag(dtype, diag_shape) lin_op_base_diag = base_diag # U @@ -94,8 +96,7 @@ class BaseLinearOperatorLowRankUpdatetest(object): # D if self._is_diag_update_positive: - diag_update = linear_operator_test_util.random_uniform( - diag_update_shape, minval=1e-4, maxval=1., dtype=dtype) + diag_update = self._gen_positive_diag(dtype, diag_update_shape) else: diag_update = linear_operator_test_util.random_normal( diag_update_shape, stddev=1e-4, dtype=dtype) @@ -110,7 +111,9 @@ class BaseLinearOperatorLowRankUpdatetest(object): diag_update, shape=None) base_operator = linalg.LinearOperatorDiag( - lin_op_base_diag, is_positive_definite=True) + lin_op_base_diag, + is_positive_definite=True, + is_self_adjoint=True) operator = linalg.LinearOperatorLowRankUpdate( base_operator, @@ -169,6 +172,7 @@ class LinearOperatorLowRankUpdatetestWithDiagUseCholesky( self._rtol[dtypes.float32] = 1e-5 self._atol[dtypes.float64] = 1e-10 self._rtol[dtypes.float64] = 1e-10 + self._rtol[dtypes.complex64] = 1e-4 class LinearOperatorLowRankUpdatetestWithDiagCannotUseCholesky( @@ -188,6 +192,7 @@ class LinearOperatorLowRankUpdatetestWithDiagCannotUseCholesky( self._rtol[dtypes.float32] = 1e-4 self._atol[dtypes.float64] = 1e-9 self._rtol[dtypes.float64] = 1e-9 + self._rtol[dtypes.complex64] = 1e-4 class LinearOperatorLowRankUpdatetestNoDiagUseCholesky( @@ -206,6 +211,7 @@ class LinearOperatorLowRankUpdatetestNoDiagUseCholesky( self._rtol[dtypes.float32] = 1e-5 self._atol[dtypes.float64] = 1e-10 self._rtol[dtypes.float64] = 1e-10 + self._rtol[dtypes.complex64] = 1e-4 class LinearOperatorLowRankUpdatetestNoDiagCannotUseCholesky( @@ -225,6 +231,7 @@ class LinearOperatorLowRankUpdatetestNoDiagCannotUseCholesky( self._rtol[dtypes.float32] = 1e-4 self._atol[dtypes.float64] = 1e-9 self._rtol[dtypes.float64] = 1e-9 + self._rtol[dtypes.complex64] = 1e-4 class LinearOperatorLowRankUpdatetestWithDiagNotSquare( |