aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-09-25 09:08:55 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-25 09:13:12 -0700
commit588787ff7572208285cb471c76f4f8c83ad9d7ec (patch)
tree9837fcc446c38659757cadf00c2bedf4984eb730 /tensorflow/python/kernel_tests
parent9f300c2712340345570cf388c1a47fd771508ed8 (diff)
Use self.cached_session instead of self.test_session in linear_operator_circulant_test.
Also: * Instead of overwriting self.test_session(), overwrite self._constrain_devices_and_set_default() to remap the kernel operations (this way self.cached_session(), self.test_session() and self.session() are all correct). * Make linear_operator_test_util use self.session(graph=...) instead of self.test_session(graph=...) (semantically equivalent). PiperOrigin-RevId: 214448118
Diffstat (limited to 'tensorflow/python/kernel_tests')
-rw-r--r--tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py73
1 files changed, 37 insertions, 36 deletions
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py
index 7261d4bb3b..f1e151ebd8 100644
--- a/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py
@@ -37,8 +37,10 @@ class LinearOperatorCirculantBaseTest(object):
"""Common class for circulant tests."""
@contextlib.contextmanager
- def test_session(self, *args, **kwargs):
- with test.TestCase.test_session(self, *args, **kwargs) as sess:
+ def _constrain_devices_and_set_default(self, sess, use_gpu, force_gpu):
+ """We overwrite the FFT operation mapping for testing."""
+ with test.TestCase._constrain_devices_and_set_default(
+ self, sess, use_gpu, force_gpu) as sess:
with spectral_ops_test_util.fft_kernel_label_map():
yield sess
@@ -110,8 +112,7 @@ class LinearOperatorCirculantTestSelfAdjointOperator(
lin_op_spectrum = spectrum
if use_placeholder:
- lin_op_spectrum = array_ops.placeholder_with_default(
- spectrum, shape=None)
+ lin_op_spectrum = array_ops.placeholder_with_default(spectrum, shape=None)
operator = linalg.LinearOperatorCirculant(
lin_op_spectrum, is_self_adjoint=True, input_output_dtype=dtype)
@@ -121,7 +122,7 @@ class LinearOperatorCirculantTestSelfAdjointOperator(
return operator, mat
def test_simple_hermitian_spectrum_gives_operator_with_zero_imag_part(self):
- with self.test_session():
+ with self.cached_session():
spectrum = math_ops.cast([1., 1j, -1j], dtypes.complex64)
operator = linalg.LinearOperatorCirculant(
spectrum, input_output_dtype=dtypes.complex64)
@@ -171,8 +172,7 @@ class LinearOperatorCirculantTestHermitianSpectrum(
lin_op_spectrum = spectrum
if use_placeholder:
- lin_op_spectrum = array_ops.placeholder_with_default(
- spectrum, shape=None)
+ lin_op_spectrum = array_ops.placeholder_with_default(spectrum, shape=None)
operator = linalg.LinearOperatorCirculant(
lin_op_spectrum, input_output_dtype=dtype)
@@ -182,7 +182,7 @@ class LinearOperatorCirculantTestHermitianSpectrum(
return operator, mat
def test_simple_hermitian_spectrum_gives_operator_with_zero_imag_part(self):
- with self.test_session():
+ with self.cached_session():
spectrum = math_ops.cast([1., 1j, -1j], dtypes.complex64)
operator = linalg.LinearOperatorCirculant(
spectrum, input_output_dtype=dtypes.complex64)
@@ -217,8 +217,7 @@ class LinearOperatorCirculantTestNonHermitianSpectrum(
lin_op_spectrum = spectrum
if use_placeholder:
- lin_op_spectrum = array_ops.placeholder_with_default(
- spectrum, shape=None)
+ lin_op_spectrum = array_ops.placeholder_with_default(spectrum, shape=None)
operator = linalg.LinearOperatorCirculant(
lin_op_spectrum, input_output_dtype=dtype)
@@ -228,7 +227,7 @@ class LinearOperatorCirculantTestNonHermitianSpectrum(
return operator, mat
def test_simple_hermitian_spectrum_gives_operator_with_zero_imag_part(self):
- with self.test_session():
+ with self.cached_session():
spectrum = math_ops.cast([1., 1j, -1j], dtypes.complex64)
operator = linalg.LinearOperatorCirculant(
spectrum, input_output_dtype=dtypes.complex64)
@@ -238,7 +237,7 @@ class LinearOperatorCirculantTestNonHermitianSpectrum(
np.testing.assert_allclose(0, imag_matrix.eval(), rtol=0, atol=eps * 3)
def test_simple_positive_real_spectrum_gives_self_adjoint_pos_def_oper(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
spectrum = math_ops.cast([6., 4, 2], dtypes.complex64)
operator = linalg.LinearOperatorCirculant(
spectrum, input_output_dtype=dtypes.complex64)
@@ -250,7 +249,7 @@ class LinearOperatorCirculantTestNonHermitianSpectrum(
operator.assert_self_adjoint().run() # Should not fail
def test_defining_operator_using_real_convolution_kernel(self):
- with self.test_session():
+ with self.cached_session():
convolution_kernel = [1., 2., 1.]
spectrum = math_ops.fft(
math_ops.cast(convolution_kernel, dtypes.complex64))
@@ -266,7 +265,7 @@ class LinearOperatorCirculantTestNonHermitianSpectrum(
np.testing.assert_allclose(0, np.imag(matrix), atol=1e-6)
def test_hermitian_spectrum_gives_operator_with_zero_imag_part(self):
- with self.test_session():
+ with self.cached_session():
# Make spectrum the FFT of a real convolution kernel h. This ensures that
# spectrum is Hermitian.
h = linear_operator_test_util.random_normal(shape=(3, 4))
@@ -281,7 +280,7 @@ class LinearOperatorCirculantTestNonHermitianSpectrum(
def test_convolution_kernel_same_as_first_row_of_to_dense(self):
spectrum = [[3., 2., 1.], [2., 1.5, 1.]]
- with self.test_session():
+ with self.cached_session():
operator = linalg.LinearOperatorCirculant(spectrum)
h = operator.convolution_kernel()
c = operator.to_dense()
@@ -293,27 +292,27 @@ class LinearOperatorCirculantTestNonHermitianSpectrum(
def test_assert_non_singular_fails_for_singular_operator(self):
spectrum = math_ops.cast([0, 4, 2j + 2], dtypes.complex64)
operator = linalg.LinearOperatorCirculant(spectrum)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("Singular operator"):
operator.assert_non_singular().run()
def test_assert_non_singular_does_not_fail_for_non_singular_operator(self):
spectrum = math_ops.cast([-3j, 4, 2j + 2], dtypes.complex64)
operator = linalg.LinearOperatorCirculant(spectrum)
- with self.test_session():
+ with self.cached_session():
operator.assert_non_singular().run() # Should not fail
def test_assert_positive_definite_fails_for_non_positive_definite(self):
spectrum = math_ops.cast([6., 4, 2j], dtypes.complex64)
operator = linalg.LinearOperatorCirculant(spectrum)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("Not positive definite"):
operator.assert_positive_definite().run()
def test_assert_positive_definite_does_not_fail_when_pos_def(self):
spectrum = math_ops.cast([6., 4, 2j + 2], dtypes.complex64)
operator = linalg.LinearOperatorCirculant(spectrum)
- with self.test_session():
+ with self.cached_session():
operator.assert_positive_definite().run() # Should not fail
def test_real_spectrum_and_not_self_adjoint_hint_raises(self):
@@ -331,8 +330,10 @@ class LinearOperatorCirculant2DBaseTest(object):
"""Common class for 2D circulant tests."""
@contextlib.contextmanager
- def test_session(self, *args, **kwargs):
- with test.TestCase.test_session(self, *args, **kwargs) as sess:
+ def _constrain_devices_and_set_default(self, sess, use_gpu, force_gpu):
+ """We overwrite the FFT operation mapping for testing."""
+ with test.TestCase._constrain_devices_and_set_default(
+ self, sess, use_gpu, force_gpu) as sess:
with spectral_ops_test_util.fft_kernel_label_map():
yield sess
@@ -446,8 +447,7 @@ class LinearOperatorCirculant2DTestHermitianSpectrum(
lin_op_spectrum = spectrum
if use_placeholder:
- lin_op_spectrum = array_ops.placeholder_with_default(
- spectrum, shape=None)
+ lin_op_spectrum = array_ops.placeholder_with_default(spectrum, shape=None)
operator = linalg.LinearOperatorCirculant2D(
lin_op_spectrum, input_output_dtype=dtype)
@@ -482,8 +482,7 @@ class LinearOperatorCirculant2DTestNonHermitianSpectrum(
lin_op_spectrum = spectrum
if use_placeholder:
- lin_op_spectrum = array_ops.placeholder_with_default(
- spectrum, shape=None)
+ lin_op_spectrum = array_ops.placeholder_with_default(spectrum, shape=None)
operator = linalg.LinearOperatorCirculant2D(
lin_op_spectrum, input_output_dtype=dtype)
@@ -493,7 +492,7 @@ class LinearOperatorCirculant2DTestNonHermitianSpectrum(
return operator, mat
def test_real_hermitian_spectrum_gives_real_symmetric_operator(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# This is a real and hermitian spectrum.
spectrum = [[1., 2., 2.], [3., 4., 4.], [3., 4., 4.]]
operator = linalg.LinearOperatorCirculant(spectrum)
@@ -510,7 +509,7 @@ class LinearOperatorCirculant2DTestNonHermitianSpectrum(
self.assertAllClose(matrix, matrix_transpose, atol=0)
def test_real_spectrum_gives_self_adjoint_operator(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# This is a real and hermitian spectrum.
spectrum = linear_operator_test_util.random_normal(
shape=(3, 3), dtype=dtypes.float32)
@@ -526,27 +525,27 @@ class LinearOperatorCirculant2DTestNonHermitianSpectrum(
def test_assert_non_singular_fails_for_singular_operator(self):
spectrum = math_ops.cast([[0, 4], [2j + 2, 3.]], dtypes.complex64)
operator = linalg.LinearOperatorCirculant2D(spectrum)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("Singular operator"):
operator.assert_non_singular().run()
def test_assert_non_singular_does_not_fail_for_non_singular_operator(self):
spectrum = math_ops.cast([[-3j, 4], [2j + 2, 3.]], dtypes.complex64)
operator = linalg.LinearOperatorCirculant2D(spectrum)
- with self.test_session():
+ with self.cached_session():
operator.assert_non_singular().run() # Should not fail
def test_assert_positive_definite_fails_for_non_positive_definite(self):
spectrum = math_ops.cast([[6., 4], [2j, 3.]], dtypes.complex64)
operator = linalg.LinearOperatorCirculant2D(spectrum)
- with self.test_session():
+ with self.cached_session():
with self.assertRaisesOpError("Not positive definite"):
operator.assert_positive_definite().run()
def test_assert_positive_definite_does_not_fail_when_pos_def(self):
spectrum = math_ops.cast([[6., 4], [2j + 2, 3.]], dtypes.complex64)
operator = linalg.LinearOperatorCirculant2D(spectrum)
- with self.test_session():
+ with self.cached_session():
operator.assert_positive_definite().run() # Should not fail
def test_real_spectrum_and_not_self_adjoint_hint_raises(self):
@@ -574,13 +573,15 @@ class LinearOperatorCirculant3DTest(test.TestCase):
"""Simple test of the 3D case. See also the 1D and 2D tests."""
@contextlib.contextmanager
- def test_session(self, *args, **kwargs):
- with test.TestCase.test_session(self, *args, **kwargs) as sess:
+ def _constrain_devices_and_set_default(self, sess, use_gpu, force_gpu):
+ """We overwrite the FFT operation mapping for testing."""
+ with test.TestCase._constrain_devices_and_set_default(
+ self, sess, use_gpu, force_gpu) as sess:
with spectral_ops_test_util.fft_kernel_label_map():
yield sess
def test_real_spectrum_gives_self_adjoint_operator(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# This is a real and hermitian spectrum.
spectrum = linear_operator_test_util.random_normal(
shape=(2, 2, 3, 5), dtype=dtypes.float32)
@@ -597,7 +598,7 @@ class LinearOperatorCirculant3DTest(test.TestCase):
self.assertAllClose(matrix, matrix_h)
def test_defining_operator_using_real_convolution_kernel(self):
- with self.test_session():
+ with self.cached_session():
convolution_kernel = linear_operator_test_util.random_normal(
shape=(2, 2, 3, 5), dtype=dtypes.float32)
# Convolution kernel is real ==> spectrum is Hermitian.
@@ -615,7 +616,7 @@ class LinearOperatorCirculant3DTest(test.TestCase):
np.testing.assert_allclose(0, np.imag(matrix), atol=1e-6)
def test_defining_spd_operator_by_taking_real_part(self):
- with self.test_session() as sess:
+ with self.cached_session() as sess:
# S is real and positive.
s = linear_operator_test_util.random_uniform(
shape=(10, 2, 3, 4), dtype=dtypes.float32, minval=1., maxval=2.)