aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/linalg/linear_operator_lower_triangular_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/linalg/linear_operator_lower_triangular_test.py')
-rw-r--r--tensorflow/python/kernel_tests/linalg/linear_operator_lower_triangular_test.py21
1 files changed, 8 insertions, 13 deletions
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_lower_triangular_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_lower_triangular_test.py
index a57d2f085e..167c6cacd1 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_lower_triangular_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_lower_triangular_test.py
@@ -38,28 +38,23 @@ class LinearOperatorLowerTriangularTest(
# matrix_triangular_solve.
return [dtypes.float32, dtypes.float64]
- def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder):
+ def _operator_and_matrix(self, build_info, dtype, use_placeholder):
shape = list(build_info.shape)
# Upper triangle will be nonzero, but ignored.
# Use a diagonal that ensures this matrix is well conditioned.
tril = linear_operator_test_util.random_tril_matrix(
shape, dtype=dtype, force_well_conditioned=True, remove_upper=False)
+ lin_op_tril = tril
+
if use_placeholder:
- tril_ph = array_ops.placeholder(dtype=dtype)
- # Evaluate the tril here because (i) you cannot feed a tensor, and (ii)
- # tril is random and we want the same value used for both mat and
- # feed_dict.
- tril = tril.eval()
- operator = linalg.LinearOperatorLowerTriangular(tril_ph)
- feed_dict = {tril_ph: tril}
- else:
- operator = linalg.LinearOperatorLowerTriangular(tril)
- feed_dict = None
+ lin_op_tril = array_ops.placeholder_with_default(lin_op_tril, shape=None)
+
+ operator = linalg.LinearOperatorLowerTriangular(lin_op_tril)
- mat = array_ops.matrix_band_part(tril, -1, 0)
+ matrix = array_ops.matrix_band_part(tril, -1, 0)
- return operator, mat, feed_dict
+ return operator, matrix
def test_assert_non_singular(self):
# Singlular matrix with one positive eigenvalue and one zero eigenvalue.