aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-27 20:05:29 -0700
committerGravatar Gunhan Gulsoy <gunan@google.com>2018-06-28 21:37:43 -0700
commit7d9621e4626b4ec3553c16fe803c21c25dd06c30 (patch)
treeadc759f04b2b54605d01e8328bef5c2b2f80a610
parent60795c45bb112b12accc02882442d8826fbcce38 (diff)
Update LinearOperator tests to use array_ops.placeholder_with_default.
PiperOrigin-RevId: 202412660
-rw-r--r--tensorflow/python/kernel_tests/linalg/BUILD21
-rw-r--r--tensorflow/python/kernel_tests/linalg/linear_operator_block_diag_test.py35
-rw-r--r--tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py110
-rw-r--r--tensorflow/python/kernel_tests/linalg/linear_operator_composition_test.py68
-rw-r--r--tensorflow/python/kernel_tests/linalg/linear_operator_diag_test.py22
-rw-r--r--tensorflow/python/kernel_tests/linalg/linear_operator_full_matrix_test.py73
-rw-r--r--tensorflow/python/kernel_tests/linalg/linear_operator_identity_test.py31
-rw-r--r--tensorflow/python/kernel_tests/linalg/linear_operator_kronecker_test.py31
-rw-r--r--tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py69
-rw-r--r--tensorflow/python/kernel_tests/linalg/linear_operator_lower_triangular_test.py21
-rw-r--r--tensorflow/python/ops/linalg/linear_operator_test_util.py41
11 files changed, 207 insertions, 315 deletions
diff --git a/tensorflow/python/kernel_tests/linalg/BUILD b/tensorflow/python/kernel_tests/linalg/BUILD
index 0123adc2c3..69d3aa4017 100644
--- a/tensorflow/python/kernel_tests/linalg/BUILD
+++ b/tensorflow/python/kernel_tests/linalg/BUILD
@@ -107,6 +107,10 @@ cuda_py_test(
"//tensorflow/python:random_ops",
],
shard_count = 5,
+ tags = [
+ "noasan",
+ "optonly",
+ ],
)
cuda_py_test(
@@ -124,7 +128,10 @@ cuda_py_test(
"//tensorflow/python:random_ops",
],
shard_count = 5,
- tags = ["optonly"], # Test is flaky without optimization.
+ tags = [
+ "noasan",
+ "optonly",
+ ],
)
cuda_py_test(
@@ -141,6 +148,10 @@ cuda_py_test(
"//tensorflow/python:platform_test",
],
shard_count = 5,
+ tags = [
+ "noasan",
+ "optonly",
+ ],
)
cuda_py_test(
@@ -178,6 +189,10 @@ cuda_py_test(
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
+ tags = [
+ "noasan",
+ "optonly",
+ ],
)
cuda_py_test(
@@ -214,4 +229,8 @@ cuda_py_test(
"//tensorflow/python:platform_test",
],
shard_count = 5,
+ tags = [
+ "noasan",
+ "optonly",
+ ],
)
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_block_diag_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_block_diag_test.py
index 2b80f01b73..3ede2aceaa 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_block_diag_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_block_diag_test.py
@@ -80,7 +80,7 @@ class SquareLinearOperatorBlockDiagTest(
build_info((2, 1, 5, 5), blocks=[(2, 1, 2, 2), (1, 3, 3)]),
]
- def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder):
+ def _operator_and_matrix(self, build_info, dtype, use_placeholder):
shape = list(build_info.shape)
expected_blocks = (
build_info.__dict__["blocks"] if "blocks" in build_info.__dict__
@@ -91,26 +91,19 @@ class SquareLinearOperatorBlockDiagTest(
for block_shape in expected_blocks
]
+ lin_op_matrices = matrices
+
if use_placeholder:
- matrices_ph = [
- array_ops.placeholder(dtype=dtype) for _ in expected_blocks
- ]
- # Evaluate here because (i) you cannot feed a tensor, and (ii)
- # values are random and we want the same value used for both mat and
- # feed_dict.
- matrices = self.evaluate(matrices)
- operator = block_diag.LinearOperatorBlockDiag(
- [linalg.LinearOperatorFullMatrix(
- m_ph, is_square=True) for m_ph in matrices_ph],
- is_square=True)
- feed_dict = {m_ph: m for (m_ph, m) in zip(matrices_ph, matrices)}
- else:
- operator = block_diag.LinearOperatorBlockDiag(
- [linalg.LinearOperatorFullMatrix(
- m, is_square=True) for m in matrices])
- feed_dict = None
- # Should be auto-set.
- self.assertTrue(operator.is_square)
+ lin_op_matrices = [
+ array_ops.placeholder_with_default(
+ matrix, shape=None) for matrix in matrices]
+
+ operator = block_diag.LinearOperatorBlockDiag(
+ [linalg.LinearOperatorFullMatrix(
+ l, is_square=True) for l in lin_op_matrices])
+
+ # Should be auto-set.
+ self.assertTrue(operator.is_square)
# Broadcast the shapes.
expected_shape = list(build_info.shape)
@@ -123,7 +116,7 @@ class SquareLinearOperatorBlockDiagTest(
block_diag_dense.set_shape(
expected_shape[:-2] + [expected_shape[-1], expected_shape[-1]])
- return operator, block_diag_dense, feed_dict
+ return operator, block_diag_dense
def test_is_x_flags(self):
# Matrix with two positive eigenvalues, 1, and 1.
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py
index 5713d16969..7261d4bb3b 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py
@@ -95,7 +95,7 @@ class LinearOperatorCirculantTestSelfAdjointOperator(
# real, the matrix will not be real.
return [dtypes.complex64]
- def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder):
+ def _operator_and_matrix(self, build_info, dtype, use_placeholder):
shape = build_info.shape
# For this test class, we are creating real spectrums.
# We also want the spectrum to have eigenvalues bounded away from zero.
@@ -107,22 +107,18 @@ class LinearOperatorCirculantTestSelfAdjointOperator(
# zero, so the operator will still be self-adjoint.
spectrum = math_ops.cast(spectrum, dtype)
+ lin_op_spectrum = spectrum
+
if use_placeholder:
- spectrum_ph = array_ops.placeholder(dtypes.complex64)
- # Evaluate here because (i) you cannot feed a tensor, and (ii)
- # it is random and we want the same value used for both mat and feed_dict.
- spectrum = spectrum.eval()
- operator = linalg.LinearOperatorCirculant(
- spectrum_ph, is_self_adjoint=True, input_output_dtype=dtype)
- feed_dict = {spectrum_ph: spectrum}
- else:
- operator = linalg.LinearOperatorCirculant(
- spectrum, is_self_adjoint=True, input_output_dtype=dtype)
- feed_dict = None
+ lin_op_spectrum = array_ops.placeholder_with_default(
+ spectrum, shape=None)
+
+ operator = linalg.LinearOperatorCirculant(
+ lin_op_spectrum, is_self_adjoint=True, input_output_dtype=dtype)
mat = self._spectrum_to_circulant_1d(spectrum, shape, dtype=dtype)
- return operator, mat, feed_dict
+ return operator, mat
def test_simple_hermitian_spectrum_gives_operator_with_zero_imag_part(self):
with self.test_session():
@@ -149,7 +145,7 @@ class LinearOperatorCirculantTestHermitianSpectrum(
def _dtypes_to_test(self):
return [dtypes.float32, dtypes.complex64]
- def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder):
+ def _operator_and_matrix(self, build_info, dtype, use_placeholder):
shape = build_info.shape
# For this test class, we are creating Hermitian spectrums.
# We also want the spectrum to have eigenvalues bounded away from zero.
@@ -172,22 +168,18 @@ class LinearOperatorCirculantTestHermitianSpectrum(
spectrum = math_ops.fft(h_c)
+ lin_op_spectrum = spectrum
+
if use_placeholder:
- spectrum_ph = array_ops.placeholder(dtypes.complex64)
- # Evaluate here because (i) you cannot feed a tensor, and (ii)
- # it is random and we want the same value used for both mat and feed_dict.
- spectrum = spectrum.eval()
- operator = linalg.LinearOperatorCirculant(
- spectrum_ph, input_output_dtype=dtype)
- feed_dict = {spectrum_ph: spectrum}
- else:
- operator = linalg.LinearOperatorCirculant(
- spectrum, input_output_dtype=dtype)
- feed_dict = None
+ lin_op_spectrum = array_ops.placeholder_with_default(
+ spectrum, shape=None)
+
+ operator = linalg.LinearOperatorCirculant(
+ lin_op_spectrum, input_output_dtype=dtype)
mat = self._spectrum_to_circulant_1d(spectrum, shape, dtype=dtype)
- return operator, mat, feed_dict
+ return operator, mat
def test_simple_hermitian_spectrum_gives_operator_with_zero_imag_part(self):
with self.test_session():
@@ -213,7 +205,7 @@ class LinearOperatorCirculantTestNonHermitianSpectrum(
def _dtypes_to_test(self):
return [dtypes.complex64]
- def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder):
+ def _operator_and_matrix(self, build_info, dtype, use_placeholder):
shape = build_info.shape
# Will be well conditioned enough to get accurate solves.
spectrum = linear_operator_test_util.random_sign_uniform(
@@ -222,22 +214,18 @@ class LinearOperatorCirculantTestNonHermitianSpectrum(
minval=1.,
maxval=2.)
+ lin_op_spectrum = spectrum
+
if use_placeholder:
- spectrum_ph = array_ops.placeholder(dtypes.complex64)
- # Evaluate here because (i) you cannot feed a tensor, and (ii)
- # it is random and we want the same value used for both mat and feed_dict.
- spectrum = spectrum.eval()
- operator = linalg.LinearOperatorCirculant(
- spectrum_ph, input_output_dtype=dtype)
- feed_dict = {spectrum_ph: spectrum}
- else:
- operator = linalg.LinearOperatorCirculant(
- spectrum, input_output_dtype=dtype)
- feed_dict = None
+ lin_op_spectrum = array_ops.placeholder_with_default(
+ spectrum, shape=None)
+
+ operator = linalg.LinearOperatorCirculant(
+ lin_op_spectrum, input_output_dtype=dtype)
mat = self._spectrum_to_circulant_1d(spectrum, shape, dtype=dtype)
- return operator, mat, feed_dict
+ return operator, mat
def test_simple_hermitian_spectrum_gives_operator_with_zero_imag_part(self):
with self.test_session():
@@ -432,7 +420,7 @@ class LinearOperatorCirculant2DTestHermitianSpectrum(
def _dtypes_to_test(self):
return [dtypes.float32, dtypes.complex64]
- def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder):
+ def _operator_and_matrix(self, build_info, dtype, use_placeholder):
shape = build_info.shape
# For this test class, we are creating Hermitian spectrums.
# We also want the spectrum to have eigenvalues bounded away from zero.
@@ -455,22 +443,18 @@ class LinearOperatorCirculant2DTestHermitianSpectrum(
spectrum = math_ops.fft2d(h_c)
+ lin_op_spectrum = spectrum
+
if use_placeholder:
- spectrum_ph = array_ops.placeholder(dtypes.complex64)
- # Evaluate here because (i) you cannot feed a tensor, and (ii)
- # it is random and we want the same value used for both mat and feed_dict.
- spectrum = spectrum.eval()
- operator = linalg.LinearOperatorCirculant2D(
- spectrum_ph, input_output_dtype=dtype)
- feed_dict = {spectrum_ph: spectrum}
- else:
- operator = linalg.LinearOperatorCirculant2D(
- spectrum, input_output_dtype=dtype)
- feed_dict = None
+ lin_op_spectrum = array_ops.placeholder_with_default(
+ spectrum, shape=None)
+
+ operator = linalg.LinearOperatorCirculant2D(
+ lin_op_spectrum, input_output_dtype=dtype)
mat = self._spectrum_to_circulant_2d(spectrum, shape, dtype=dtype)
- return operator, mat, feed_dict
+ return operator, mat
class LinearOperatorCirculant2DTestNonHermitianSpectrum(
@@ -486,7 +470,7 @@ class LinearOperatorCirculant2DTestNonHermitianSpectrum(
def _dtypes_to_test(self):
return [dtypes.complex64]
- def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder):
+ def _operator_and_matrix(self, build_info, dtype, use_placeholder):
shape = build_info.shape
# Will be well conditioned enough to get accurate solves.
spectrum = linear_operator_test_util.random_sign_uniform(
@@ -495,22 +479,18 @@ class LinearOperatorCirculant2DTestNonHermitianSpectrum(
minval=1.,
maxval=2.)
+ lin_op_spectrum = spectrum
+
if use_placeholder:
- spectrum_ph = array_ops.placeholder(dtypes.complex64)
- # Evaluate here because (i) you cannot feed a tensor, and (ii)
- # it is random and we want the same value used for both mat and feed_dict.
- spectrum = spectrum.eval()
- operator = linalg.LinearOperatorCirculant2D(
- spectrum_ph, input_output_dtype=dtype)
- feed_dict = {spectrum_ph: spectrum}
- else:
- operator = linalg.LinearOperatorCirculant2D(
- spectrum, input_output_dtype=dtype)
- feed_dict = None
+ lin_op_spectrum = array_ops.placeholder_with_default(
+ spectrum, shape=None)
+
+ operator = linalg.LinearOperatorCirculant2D(
+ lin_op_spectrum, input_output_dtype=dtype)
mat = self._spectrum_to_circulant_2d(spectrum, shape, dtype=dtype)
- return operator, mat, feed_dict
+ return operator, mat
def test_real_hermitian_spectrum_gives_real_symmetric_operator(self):
with self.test_session() as sess:
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_composition_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_composition_test.py
index f96b9ccdaa..612a50bcec 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_composition_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_composition_test.py
@@ -44,7 +44,7 @@ class SquareLinearOperatorCompositionTest(
self._rtol[dtypes.float32] = 1e-4
self._rtol[dtypes.complex64] = 1e-4
- def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder):
+ def _operator_and_matrix(self, build_info, dtype, use_placeholder):
sess = ops.get_default_session()
shape = list(build_info.shape)
@@ -56,33 +56,23 @@ class SquareLinearOperatorCompositionTest(
for _ in range(num_operators)
]
+ lin_op_matrices = matrices
+
if use_placeholder:
- matrices_ph = [
- array_ops.placeholder(dtype=dtype) for _ in range(num_operators)
- ]
- # Evaluate here because (i) you cannot feed a tensor, and (ii)
- # values are random and we want the same value used for both mat and
- # feed_dict.
- matrices = sess.run(matrices)
- operator = linalg.LinearOperatorComposition(
- [linalg.LinearOperatorFullMatrix(m_ph) for m_ph in matrices_ph],
- is_square=True)
- feed_dict = {m_ph: m for (m_ph, m) in zip(matrices_ph, matrices)}
- else:
- operator = linalg.LinearOperatorComposition(
- [linalg.LinearOperatorFullMatrix(m) for m in matrices])
- feed_dict = None
- # Should be auto-set.
- self.assertTrue(operator.is_square)
-
- # Convert back to Tensor. Needed if use_placeholder, since then we have
- # already evaluated each matrix to a numpy array.
+ lin_op_matrices = [
+ array_ops.placeholder_with_default(
+ matrix, shape=None) for matrix in matrices]
+
+ operator = linalg.LinearOperatorComposition(
+ [linalg.LinearOperatorFullMatrix(l) for l in lin_op_matrices],
+ is_square=True)
+
matmul_order_list = list(reversed(matrices))
- mat = ops.convert_to_tensor(matmul_order_list[0])
+ mat = matmul_order_list[0]
for other_mat in matmul_order_list[1:]:
mat = math_ops.matmul(other_mat, mat)
- return operator, mat, feed_dict
+ return operator, mat
def test_is_x_flags(self):
# Matrix with two positive eigenvalues, 1, and 1.
@@ -148,7 +138,7 @@ class NonSquareLinearOperatorCompositionTest(
self._rtol[dtypes.float32] = 1e-4
self._rtol[dtypes.complex64] = 1e-4
- def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder):
+ def _operator_and_matrix(self, build_info, dtype, use_placeholder):
sess = ops.get_default_session()
shape = list(build_info.shape)
@@ -170,30 +160,22 @@ class NonSquareLinearOperatorCompositionTest(
shape_2, dtype=dtype)
]
+ lin_op_matrices = matrices
+
if use_placeholder:
- matrices_ph = [
- array_ops.placeholder(dtype=dtype) for _ in range(num_operators)
- ]
- # Evaluate here because (i) you cannot feed a tensor, and (ii)
- # values are random and we want the same value used for both mat and
- # feed_dict.
- matrices = sess.run(matrices)
- operator = linalg.LinearOperatorComposition(
- [linalg.LinearOperatorFullMatrix(m_ph) for m_ph in matrices_ph])
- feed_dict = {m_ph: m for (m_ph, m) in zip(matrices_ph, matrices)}
- else:
- operator = linalg.LinearOperatorComposition(
- [linalg.LinearOperatorFullMatrix(m) for m in matrices])
- feed_dict = None
-
- # Convert back to Tensor. Needed if use_placeholder, since then we have
- # already evaluated each matrix to a numpy array.
+ lin_op_matrices = [
+ array_ops.placeholder_with_default(
+ matrix, shape=None) for matrix in matrices]
+
+ operator = linalg.LinearOperatorComposition(
+ [linalg.LinearOperatorFullMatrix(l) for l in lin_op_matrices])
+
matmul_order_list = list(reversed(matrices))
- mat = ops.convert_to_tensor(matmul_order_list[0])
+ mat = matmul_order_list[0]
for other_mat in matmul_order_list[1:]:
mat = math_ops.matmul(other_mat, mat)
- return operator, mat, feed_dict
+ return operator, mat
def test_static_shapes(self):
operators = [
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_diag_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_diag_test.py
index 0a0e31c716..83cc8c483f 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_diag_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_diag_test.py
@@ -34,25 +34,21 @@ class LinearOperatorDiagTest(
linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
"""Most tests done in the base class LinearOperatorDerivedClassTest."""
- def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder):
+ def _operator_and_matrix(self, build_info, dtype, use_placeholder):
shape = list(build_info.shape)
diag = linear_operator_test_util.random_sign_uniform(
shape[:-1], minval=1., maxval=2., dtype=dtype)
+
+ lin_op_diag = diag
+
if use_placeholder:
- diag_ph = array_ops.placeholder(dtype=dtype)
- # Evaluate the diag here because (i) you cannot feed a tensor, and (ii)
- # diag is random and we want the same value used for both mat and
- # feed_dict.
- diag = diag.eval()
- operator = linalg.LinearOperatorDiag(diag_ph)
- feed_dict = {diag_ph: diag}
- else:
- operator = linalg.LinearOperatorDiag(diag)
- feed_dict = None
+ lin_op_diag = array_ops.placeholder_with_default(diag, shape=None)
+
+ operator = linalg.LinearOperatorDiag(lin_op_diag)
- mat = array_ops.matrix_diag(diag)
+ matrix = array_ops.matrix_diag(diag)
- return operator, mat, feed_dict
+ return operator, matrix
def test_assert_positive_definite_raises_for_zero_eigenvalue(self):
# Matrix with one positive eigenvalue and one zero eigenvalue.
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_full_matrix_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_full_matrix_test.py
index b3da623b5e..1a40a29ec6 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_full_matrix_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_full_matrix_test.py
@@ -20,7 +20,6 @@ from __future__ import print_function
import numpy as np
from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
@@ -36,30 +35,20 @@ class SquareLinearOperatorFullMatrixTest(
linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
"""Most tests done in the base class LinearOperatorDerivedClassTest."""
- def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder):
+ def _operator_and_matrix(self, build_info, dtype, use_placeholder):
shape = list(build_info.shape)
matrix = linear_operator_test_util.random_positive_definite_matrix(
shape, dtype)
+ lin_op_matrix = matrix
+
if use_placeholder:
- matrix_ph = array_ops.placeholder(dtype=dtype)
- # Evaluate here because (i) you cannot feed a tensor, and (ii)
- # values are random and we want the same value used for both mat and
- # feed_dict.
- matrix = matrix.eval()
- operator = linalg.LinearOperatorFullMatrix(matrix_ph, is_square=True)
- feed_dict = {matrix_ph: matrix}
- else:
- # is_square should be auto-detected here.
- operator = linalg.LinearOperatorFullMatrix(matrix)
- feed_dict = None
+ lin_op_matrix = array_ops.placeholder_with_default(matrix, shape=None)
- # Convert back to Tensor. Needed if use_placeholder, since then we have
- # already evaluated matrix to a numpy array.
- mat = ops.convert_to_tensor(matrix)
+ operator = linalg.LinearOperatorFullMatrix(lin_op_matrix, is_square=True)
- return operator, mat, feed_dict
+ return operator, matrix
def test_is_x_flags(self):
# Matrix with two positive eigenvalues.
@@ -136,32 +125,20 @@ class SquareLinearOperatorFullMatrixSymmetricPositiveDefiniteTest(
def _dtypes_to_test(self):
return [dtypes.float32, dtypes.float64]
- def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder):
+ def _operator_and_matrix(self, build_info, dtype, use_placeholder):
shape = list(build_info.shape)
matrix = linear_operator_test_util.random_positive_definite_matrix(
shape, dtype, force_well_conditioned=True)
+ lin_op_matrix = matrix
+
if use_placeholder:
- matrix_ph = array_ops.placeholder(dtype=dtype)
- # Evaluate here because (i) you cannot feed a tensor, and (ii)
- # values are random and we want the same value used for both mat and
- # feed_dict.
- matrix = matrix.eval()
- # is_square is auto-set because of self_adjoint/pd.
- operator = linalg.LinearOperatorFullMatrix(
- matrix_ph, is_self_adjoint=True, is_positive_definite=True)
- feed_dict = {matrix_ph: matrix}
- else:
- operator = linalg.LinearOperatorFullMatrix(
- matrix, is_self_adjoint=True, is_positive_definite=True)
- feed_dict = None
-
- # Convert back to Tensor. Needed if use_placeholder, since then we have
- # already evaluated matrix to a numpy array.
- mat = ops.convert_to_tensor(matrix)
-
- return operator, mat, feed_dict
+ lin_op_matrix = array_ops.placeholder_with_default(matrix, shape=None)
+
+ operator = linalg.LinearOperatorFullMatrix(lin_op_matrix, is_square=True)
+
+ return operator, matrix
def test_is_x_flags(self):
# Matrix with two positive eigenvalues.
@@ -210,26 +187,18 @@ class NonSquareLinearOperatorFullMatrixTest(
linear_operator_test_util.NonSquareLinearOperatorDerivedClassTest):
"""Most tests done in the base class LinearOperatorDerivedClassTest."""
- def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder):
+ def _operator_and_matrix(self, build_info, dtype, use_placeholder):
shape = list(build_info.shape)
matrix = linear_operator_test_util.random_normal(shape, dtype=dtype)
+
+ lin_op_matrix = matrix
+
if use_placeholder:
- matrix_ph = array_ops.placeholder(dtype=dtype)
- # Evaluate here because (i) you cannot feed a tensor, and (ii)
- # values are random and we want the same value used for both mat and
- # feed_dict.
- matrix = matrix.eval()
- operator = linalg.LinearOperatorFullMatrix(matrix_ph)
- feed_dict = {matrix_ph: matrix}
- else:
- operator = linalg.LinearOperatorFullMatrix(matrix)
- feed_dict = None
+ lin_op_matrix = array_ops.placeholder_with_default(matrix, shape=None)
- # Convert back to Tensor. Needed if use_placeholder, since then we have
- # already evaluated matrix to a numpy array.
- mat = ops.convert_to_tensor(matrix)
+ operator = linalg.LinearOperatorFullMatrix(lin_op_matrix, is_square=True)
- return operator, mat, feed_dict
+ return operator, matrix
def test_is_x_flags(self):
matrix = [[3., 2., 1.], [1., 1., 1.]]
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_identity_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_identity_test.py
index 59f63f949e..35dcf4417c 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_identity_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_identity_test.py
@@ -43,7 +43,7 @@ class LinearOperatorIdentityTest(
# 16bit.
return [dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.complex128]
- def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder):
+ def _operator_and_matrix(self, build_info, dtype, use_placeholder):
shape = list(build_info.shape)
assert shape[-1] == shape[-2]
@@ -54,13 +54,7 @@ class LinearOperatorIdentityTest(
num_rows, batch_shape=batch_shape, dtype=dtype)
mat = linalg_ops.eye(num_rows, batch_shape=batch_shape, dtype=dtype)
- # Nothing to feed since LinearOperatorIdentity takes no Tensor args.
- if use_placeholder:
- feed_dict = {}
- else:
- feed_dict = None
-
- return operator, mat, feed_dict
+ return operator, mat
def test_assert_positive_definite(self):
with self.test_session():
@@ -261,7 +255,7 @@ class LinearOperatorScaledIdentityTest(
# 16bit.
return [dtypes.float32, dtypes.float64, dtypes.complex64, dtypes.complex128]
- def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder):
+ def _operator_and_matrix(self, build_info, dtype, use_placeholder):
shape = list(build_info.shape)
assert shape[-1] == shape[-2]
@@ -274,24 +268,23 @@ class LinearOperatorScaledIdentityTest(
multiplier = linear_operator_test_util.random_sign_uniform(
shape=batch_shape, minval=1., maxval=2., dtype=dtype)
- operator = linalg_lib.LinearOperatorScaledIdentity(num_rows, multiplier)
# Nothing to feed since LinearOperatorScaledIdentity takes no Tensor args.
+ lin_op_multiplier = multiplier
+
if use_placeholder:
- multiplier_ph = array_ops.placeholder(dtype=dtype)
- multiplier = multiplier.eval()
- operator = linalg_lib.LinearOperatorScaledIdentity(
- num_rows, multiplier_ph)
- feed_dict = {multiplier_ph: multiplier}
- else:
- feed_dict = None
+ lin_op_multiplier = array_ops.placeholder_with_default(
+ multiplier, shape=None)
+
+ operator = linalg_lib.LinearOperatorScaledIdentity(
+ num_rows, lin_op_multiplier)
multiplier_matrix = array_ops.expand_dims(
array_ops.expand_dims(multiplier, -1), -1)
- mat = multiplier_matrix * linalg_ops.eye(
+ matrix = multiplier_matrix * linalg_ops.eye(
num_rows, batch_shape=batch_shape, dtype=dtype)
- return operator, mat, feed_dict
+ return operator, matrix
def test_assert_positive_definite_does_not_raise_when_positive(self):
with self.test_session():
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_kronecker_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_kronecker_test.py
index 784c730bbc..e26b946151 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_kronecker_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_kronecker_test.py
@@ -101,7 +101,7 @@ class SquareLinearOperatorKroneckerTest(
def _tests_to_skip(self):
return ["det", "solve", "solve_with_broadcast"]
- def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder):
+ def _operator_and_matrix(self, build_info, dtype, use_placeholder):
shape = list(build_info.shape)
expected_factors = build_info.__dict__["factors"]
matrices = [
@@ -110,26 +110,15 @@ class SquareLinearOperatorKroneckerTest(
for block_shape in expected_factors
]
+ lin_op_matrices = matrices
+
if use_placeholder:
- matrices_ph = [
- array_ops.placeholder(dtype=dtype) for _ in expected_factors
- ]
- # Evaluate here because (i) you cannot feed a tensor, and (ii)
- # values are random and we want the same value used for both mat and
- # feed_dict.
- matrices = self.evaluate(matrices)
- operator = kronecker.LinearOperatorKronecker(
- [linalg.LinearOperatorFullMatrix(
- m_ph, is_square=True) for m_ph in matrices_ph],
- is_square=True)
- feed_dict = {m_ph: m for (m_ph, m) in zip(matrices_ph, matrices)}
- else:
- operator = kronecker.LinearOperatorKronecker(
- [linalg.LinearOperatorFullMatrix(
- m, is_square=True) for m in matrices])
- feed_dict = None
- # Should be auto-set.
- self.assertTrue(operator.is_square)
+ lin_op_matrices = [
+ array_ops.placeholder_with_default(m, shape=None) for m in matrices]
+
+ operator = kronecker.LinearOperatorKronecker(
+ [linalg.LinearOperatorFullMatrix(
+ l, is_square=True) for l in lin_op_matrices])
matrices = linear_operator_util.broadcast_matrix_batch_dims(matrices)
@@ -138,7 +127,7 @@ class SquareLinearOperatorKroneckerTest(
if not use_placeholder:
kronecker_dense.set_shape(shape)
- return operator, kronecker_dense, feed_dict
+ return operator, kronecker_dense
def test_is_x_flags(self):
# Matrix with two positive eigenvalues, 1, and 1.
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py
index 8095f6419e..34b35a4ffb 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_low_rank_update_test.py
@@ -68,7 +68,7 @@ class BaseLinearOperatorLowRankUpdatetest(object):
build_info((3, 4, 4)),
build_info((2, 1, 4, 4))]
- def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder):
+ def _operator_and_matrix(self, build_info, dtype, use_placeholder):
# Recall A = L + UDV^H
shape = list(build_info.shape)
diag_shape = shape[:-1]
@@ -80,17 +80,17 @@ class BaseLinearOperatorLowRankUpdatetest(object):
# operator, with condition number as high as 1e4.
base_diag = linear_operator_test_util.random_uniform(
diag_shape, minval=1e-4, maxval=1., dtype=dtype)
- base_diag_ph = array_ops.placeholder(dtype=dtype)
+ lin_op_base_diag = base_diag
# U
u = linear_operator_test_util.random_normal_correlated_columns(
u_perturbation_shape, dtype=dtype)
- u_ph = array_ops.placeholder(dtype=dtype)
+ lin_op_u = u
# V
v = linear_operator_test_util.random_normal_correlated_columns(
u_perturbation_shape, dtype=dtype)
- v_ph = array_ops.placeholder(dtype=dtype)
+ lin_op_v = v
# D
if self._is_diag_update_positive:
@@ -99,42 +99,25 @@ class BaseLinearOperatorLowRankUpdatetest(object):
else:
diag_update = linear_operator_test_util.random_normal(
diag_update_shape, stddev=1e-4, dtype=dtype)
- diag_update_ph = array_ops.placeholder(dtype=dtype)
+ lin_op_diag_update = diag_update
if use_placeholder:
- # Evaluate here because (i) you cannot feed a tensor, and (ii)
- # values are random and we want the same value used for both mat and
- # feed_dict.
- base_diag = base_diag.eval()
- u = u.eval()
- v = v.eval()
- diag_update = diag_update.eval()
-
- # In all cases, set base_operator to be positive definite.
- base_operator = linalg.LinearOperatorDiag(
- base_diag_ph, is_positive_definite=True)
-
- operator = linalg.LinearOperatorLowRankUpdate(
- base_operator,
- u=u_ph,
- v=v_ph if self._use_v else None,
- diag_update=diag_update_ph if self._use_diag_update else None,
- is_diag_update_positive=self._is_diag_update_positive)
- feed_dict = {
- base_diag_ph: base_diag,
- u_ph: u,
- v_ph: v,
- diag_update_ph: diag_update}
- else:
- base_operator = linalg.LinearOperatorDiag(
- base_diag, is_positive_definite=True)
- operator = linalg.LinearOperatorLowRankUpdate(
- base_operator,
- u,
- v=v if self._use_v else None,
- diag_update=diag_update if self._use_diag_update else None,
- is_diag_update_positive=self._is_diag_update_positive)
- feed_dict = None
+ lin_op_base_diag = array_ops.placeholder_with_default(
+ base_diag, shape=None)
+ lin_op_u = array_ops.placeholder_with_default(u, shape=None)
+ lin_op_v = array_ops.placeholder_with_default(v, shape=None)
+ lin_op_diag_update = array_ops.placeholder_with_default(
+ diag_update, shape=None)
+
+ base_operator = linalg.LinearOperatorDiag(
+ lin_op_base_diag, is_positive_definite=True)
+
+ operator = linalg.LinearOperatorLowRankUpdate(
+ base_operator,
+ lin_op_u,
+ v=lin_op_v if self._use_v else None,
+ diag_update=lin_op_diag_update if self._use_diag_update else None,
+ is_diag_update_positive=self._is_diag_update_positive)
# The matrix representing L
base_diag_mat = array_ops.matrix_diag(base_diag)
@@ -146,28 +129,28 @@ class BaseLinearOperatorLowRankUpdatetest(object):
if self._use_v and self._use_diag_update:
# In this case, we have L + UDV^H and it isn't symmetric.
expect_use_cholesky = False
- mat = base_diag_mat + math_ops.matmul(
+ matrix = base_diag_mat + math_ops.matmul(
u, math_ops.matmul(diag_update_mat, v, adjoint_b=True))
elif self._use_v:
# In this case, we have L + UDV^H and it isn't symmetric.
expect_use_cholesky = False
- mat = base_diag_mat + math_ops.matmul(u, v, adjoint_b=True)
+ matrix = base_diag_mat + math_ops.matmul(u, v, adjoint_b=True)
elif self._use_diag_update:
# In this case, we have L + UDU^H, which is PD if D > 0, since L > 0.
expect_use_cholesky = self._is_diag_update_positive
- mat = base_diag_mat + math_ops.matmul(
+ matrix = base_diag_mat + math_ops.matmul(
u, math_ops.matmul(diag_update_mat, u, adjoint_b=True))
else:
# In this case, we have L + UU^H, which is PD since L > 0.
expect_use_cholesky = True
- mat = base_diag_mat + math_ops.matmul(u, u, adjoint_b=True)
+ matrix = base_diag_mat + math_ops.matmul(u, u, adjoint_b=True)
if expect_use_cholesky:
self.assertTrue(operator._use_cholesky)
else:
self.assertFalse(operator._use_cholesky)
- return operator, mat, feed_dict
+ return operator, matrix
class LinearOperatorLowRankUpdatetestWithDiagUseCholesky(
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_lower_triangular_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_lower_triangular_test.py
index a57d2f085e..167c6cacd1 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_lower_triangular_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_lower_triangular_test.py
@@ -38,28 +38,23 @@ class LinearOperatorLowerTriangularTest(
# matrix_triangular_solve.
return [dtypes.float32, dtypes.float64]
- def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder):
+ def _operator_and_matrix(self, build_info, dtype, use_placeholder):
shape = list(build_info.shape)
# Upper triangle will be nonzero, but ignored.
# Use a diagonal that ensures this matrix is well conditioned.
tril = linear_operator_test_util.random_tril_matrix(
shape, dtype=dtype, force_well_conditioned=True, remove_upper=False)
+ lin_op_tril = tril
+
if use_placeholder:
- tril_ph = array_ops.placeholder(dtype=dtype)
- # Evaluate the tril here because (i) you cannot feed a tensor, and (ii)
- # tril is random and we want the same value used for both mat and
- # feed_dict.
- tril = tril.eval()
- operator = linalg.LinearOperatorLowerTriangular(tril_ph)
- feed_dict = {tril_ph: tril}
- else:
- operator = linalg.LinearOperatorLowerTriangular(tril)
- feed_dict = None
+ lin_op_tril = array_ops.placeholder_with_default(lin_op_tril, shape=None)
+
+ operator = linalg.LinearOperatorLowerTriangular(lin_op_tril)
- mat = array_ops.matrix_band_part(tril, -1, 0)
+ matrix = array_ops.matrix_band_part(tril, -1, 0)
- return operator, mat, feed_dict
+ return operator, matrix
def test_assert_non_singular(self):
# Singlular matrix with one positive eigenvalue and one zero eigenvalue.
diff --git a/tensorflow/python/ops/linalg/linear_operator_test_util.py b/tensorflow/python/ops/linalg/linear_operator_test_util.py
index 1b5bb9470c..78c85db557 100644
--- a/tensorflow/python/ops/linalg/linear_operator_test_util.py
+++ b/tensorflow/python/ops/linalg/linear_operator_test_util.py
@@ -102,7 +102,7 @@ class LinearOperatorDerivedClassTest(test.TestCase):
raise NotImplementedError("operator_build_infos has not been implemented.")
@abc.abstractmethod
- def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder):
+ def _operator_and_matrix(self, build_info, dtype, use_placeholder):
"""Build a batch matrix and an Operator that should have similar behavior.
Every operator acts like a (batch) matrix. This method returns both
@@ -118,9 +118,6 @@ class LinearOperatorDerivedClassTest(test.TestCase):
Returns:
operator: `LinearOperator` subclass instance.
mat: `Tensor` representing operator.
- feed_dict: Dictionary.
- If placholder is True, this must contains everything needed to be fed
- to sess.run calls at runtime to make the operator work.
"""
# Create a matrix as a numpy array with desired shape/dtype.
# Create a LinearOperator that should have the same behavior as the matrix.
@@ -189,12 +186,12 @@ class LinearOperatorDerivedClassTest(test.TestCase):
for dtype in self._dtypes_to_test:
with self.test_session(graph=ops.Graph()) as sess:
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
- operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
+ operator, mat = self._operator_and_matrix(
build_info, dtype, use_placeholder=use_placeholder)
op_dense = operator.to_dense()
if not use_placeholder:
self.assertAllEqual(build_info.shape, op_dense.get_shape())
- op_dense_v, mat_v = sess.run([op_dense, mat], feed_dict=feed_dict)
+ op_dense_v, mat_v = sess.run([op_dense, mat])
self.assertAC(op_dense_v, mat_v)
def test_det(self):
@@ -204,14 +201,13 @@ class LinearOperatorDerivedClassTest(test.TestCase):
for dtype in self._dtypes_to_test:
with self.test_session(graph=ops.Graph()) as sess:
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
- operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
+ operator, mat = self._operator_and_matrix(
build_info, dtype, use_placeholder=use_placeholder)
op_det = operator.determinant()
if not use_placeholder:
self.assertAllEqual(build_info.shape[:-2], op_det.get_shape())
op_det_v, mat_det_v = sess.run(
- [op_det, linalg_ops.matrix_determinant(mat)],
- feed_dict=feed_dict)
+ [op_det, linalg_ops.matrix_determinant(mat)])
self.assertAC(op_det_v, mat_det_v)
def test_log_abs_det(self):
@@ -221,7 +217,7 @@ class LinearOperatorDerivedClassTest(test.TestCase):
for dtype in self._dtypes_to_test:
with self.test_session(graph=ops.Graph()) as sess:
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
- operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
+ operator, mat = self._operator_and_matrix(
build_info, dtype, use_placeholder=use_placeholder)
op_log_abs_det = operator.log_abs_determinant()
_, mat_log_abs_det = linalg.slogdet(mat)
@@ -229,7 +225,7 @@ class LinearOperatorDerivedClassTest(test.TestCase):
self.assertAllEqual(
build_info.shape[:-2], op_log_abs_det.get_shape())
op_log_abs_det_v, mat_log_abs_det_v = sess.run(
- [op_log_abs_det, mat_log_abs_det], feed_dict=feed_dict)
+ [op_log_abs_det, mat_log_abs_det])
self.assertAC(op_log_abs_det_v, mat_log_abs_det_v)
def _test_matmul(self, with_batch):
@@ -246,7 +242,7 @@ class LinearOperatorDerivedClassTest(test.TestCase):
for adjoint_arg in self._adjoint_arg_options:
with self.test_session(graph=ops.Graph()) as sess:
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
- operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
+ operator, mat = self._operator_and_matrix(
build_info, dtype, use_placeholder=use_placeholder)
x = self._make_x(
operator, adjoint=adjoint, with_batch=with_batch)
@@ -264,7 +260,7 @@ class LinearOperatorDerivedClassTest(test.TestCase):
self.assertAllEqual(op_matmul.get_shape(),
mat_matmul.get_shape())
op_matmul_v, mat_matmul_v = sess.run(
- [op_matmul, mat_matmul], feed_dict=feed_dict)
+ [op_matmul, mat_matmul])
self.assertAC(op_matmul_v, mat_matmul_v)
def test_matmul(self):
@@ -289,7 +285,7 @@ class LinearOperatorDerivedClassTest(test.TestCase):
for adjoint_arg in self._adjoint_arg_options:
with self.test_session(graph=ops.Graph()) as sess:
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
- operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
+ operator, mat = self._operator_and_matrix(
build_info, dtype, use_placeholder=use_placeholder)
rhs = self._make_rhs(
operator, adjoint=adjoint, with_batch=with_batch)
@@ -307,8 +303,7 @@ class LinearOperatorDerivedClassTest(test.TestCase):
if not use_placeholder:
self.assertAllEqual(op_solve.get_shape(),
mat_solve.get_shape())
- op_solve_v, mat_solve_v = sess.run(
- [op_solve, mat_solve], feed_dict=feed_dict)
+ op_solve_v, mat_solve_v = sess.run([op_solve, mat_solve])
self.assertAC(op_solve_v, mat_solve_v)
def test_solve(self):
@@ -326,14 +321,13 @@ class LinearOperatorDerivedClassTest(test.TestCase):
for dtype in self._dtypes_to_test:
with self.test_session(graph=ops.Graph()) as sess:
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
- operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
+ operator, mat = self._operator_and_matrix(
build_info, dtype, use_placeholder=use_placeholder)
op_trace = operator.trace()
mat_trace = math_ops.trace(mat)
if not use_placeholder:
self.assertAllEqual(op_trace.get_shape(), mat_trace.get_shape())
- op_trace_v, mat_trace_v = sess.run(
- [op_trace, mat_trace], feed_dict=feed_dict)
+ op_trace_v, mat_trace_v = sess.run([op_trace, mat_trace])
self.assertAC(op_trace_v, mat_trace_v)
def test_add_to_tensor(self):
@@ -343,15 +337,14 @@ class LinearOperatorDerivedClassTest(test.TestCase):
for dtype in self._dtypes_to_test:
with self.test_session(graph=ops.Graph()) as sess:
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
- operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
+ operator, mat = self._operator_and_matrix(
build_info, dtype, use_placeholder=use_placeholder)
op_plus_2mat = operator.add_to_tensor(2 * mat)
if not use_placeholder:
self.assertAllEqual(build_info.shape, op_plus_2mat.get_shape())
- op_plus_2mat_v, mat_v = sess.run(
- [op_plus_2mat, mat], feed_dict=feed_dict)
+ op_plus_2mat_v, mat_v = sess.run([op_plus_2mat, mat])
self.assertAC(op_plus_2mat_v, 3 * mat_v)
@@ -362,7 +355,7 @@ class LinearOperatorDerivedClassTest(test.TestCase):
for dtype in self._dtypes_to_test:
with self.test_session(graph=ops.Graph()) as sess:
sess.graph.seed = random_seed.DEFAULT_GRAPH_SEED
- operator, mat, feed_dict = self._operator_and_mat_and_feed_dict(
+ operator, mat = self._operator_and_matrix(
build_info, dtype, use_placeholder=use_placeholder)
op_diag_part = operator.diag_part()
mat_diag_part = array_ops.matrix_diag_part(mat)
@@ -372,7 +365,7 @@ class LinearOperatorDerivedClassTest(test.TestCase):
op_diag_part.get_shape())
op_diag_part_, mat_diag_part_ = sess.run(
- [op_diag_part, mat_diag_part], feed_dict=feed_dict)
+ [op_diag_part, mat_diag_part])
self.assertAC(op_diag_part_, mat_diag_part_)