diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-09-25 09:08:55 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-25 09:13:12 -0700 |
commit | 588787ff7572208285cb471c76f4f8c83ad9d7ec (patch) | |
tree | 9837fcc446c38659757cadf00c2bedf4984eb730 /tensorflow/python/kernel_tests | |
parent | 9f300c2712340345570cf388c1a47fd771508ed8 (diff) |
Use self.cached_session instead of self.test_session in linear_operator_circulant_test.
Also:
* Instead of overwriting self.test_session(), overwrite self._constrain_devices_and_set_default() to remap the kernel operations (this way self.cached_session(), self.test_session() and self.session() are all correct).
* Make linear_operator_test_util use self.session(graph=...) instead of self.test_session(graph=...) (semantically equivalent).
PiperOrigin-RevId: 214448118
Diffstat (limited to 'tensorflow/python/kernel_tests')
-rw-r--r-- | tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py | 73 |
1 files changed, 37 insertions, 36 deletions
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py index 7261d4bb3b..f1e151ebd8 100644 --- a/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py +++ b/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py @@ -37,8 +37,10 @@ class LinearOperatorCirculantBaseTest(object): """Common class for circulant tests.""" @contextlib.contextmanager - def test_session(self, *args, **kwargs): - with test.TestCase.test_session(self, *args, **kwargs) as sess: + def _constrain_devices_and_set_default(self, sess, use_gpu, force_gpu): + """We overwrite the FFT operation mapping for testing.""" + with test.TestCase._constrain_devices_and_set_default( + self, sess, use_gpu, force_gpu) as sess: with spectral_ops_test_util.fft_kernel_label_map(): yield sess @@ -110,8 +112,7 @@ class LinearOperatorCirculantTestSelfAdjointOperator( lin_op_spectrum = spectrum if use_placeholder: - lin_op_spectrum = array_ops.placeholder_with_default( - spectrum, shape=None) + lin_op_spectrum = array_ops.placeholder_with_default(spectrum, shape=None) operator = linalg.LinearOperatorCirculant( lin_op_spectrum, is_self_adjoint=True, input_output_dtype=dtype) @@ -121,7 +122,7 @@ class LinearOperatorCirculantTestSelfAdjointOperator( return operator, mat def test_simple_hermitian_spectrum_gives_operator_with_zero_imag_part(self): - with self.test_session(): + with self.cached_session(): spectrum = math_ops.cast([1., 1j, -1j], dtypes.complex64) operator = linalg.LinearOperatorCirculant( spectrum, input_output_dtype=dtypes.complex64) @@ -171,8 +172,7 @@ class LinearOperatorCirculantTestHermitianSpectrum( lin_op_spectrum = spectrum if use_placeholder: - lin_op_spectrum = array_ops.placeholder_with_default( - spectrum, shape=None) + lin_op_spectrum = array_ops.placeholder_with_default(spectrum, shape=None) operator = linalg.LinearOperatorCirculant( lin_op_spectrum, input_output_dtype=dtype) @@ -182,7 +182,7 @@ class LinearOperatorCirculantTestHermitianSpectrum( return operator, mat def test_simple_hermitian_spectrum_gives_operator_with_zero_imag_part(self): - with self.test_session(): + with self.cached_session(): spectrum = math_ops.cast([1., 1j, -1j], dtypes.complex64) operator = linalg.LinearOperatorCirculant( spectrum, input_output_dtype=dtypes.complex64) @@ -217,8 +217,7 @@ class LinearOperatorCirculantTestNonHermitianSpectrum( lin_op_spectrum = spectrum if use_placeholder: - lin_op_spectrum = array_ops.placeholder_with_default( - spectrum, shape=None) + lin_op_spectrum = array_ops.placeholder_with_default(spectrum, shape=None) operator = linalg.LinearOperatorCirculant( lin_op_spectrum, input_output_dtype=dtype) @@ -228,7 +227,7 @@ class LinearOperatorCirculantTestNonHermitianSpectrum( return operator, mat def test_simple_hermitian_spectrum_gives_operator_with_zero_imag_part(self): - with self.test_session(): + with self.cached_session(): spectrum = math_ops.cast([1., 1j, -1j], dtypes.complex64) operator = linalg.LinearOperatorCirculant( spectrum, input_output_dtype=dtypes.complex64) @@ -238,7 +237,7 @@ class LinearOperatorCirculantTestNonHermitianSpectrum( np.testing.assert_allclose(0, imag_matrix.eval(), rtol=0, atol=eps * 3) def test_simple_positive_real_spectrum_gives_self_adjoint_pos_def_oper(self): - with self.test_session() as sess: + with self.cached_session() as sess: spectrum = math_ops.cast([6., 4, 2], dtypes.complex64) operator = linalg.LinearOperatorCirculant( spectrum, input_output_dtype=dtypes.complex64) @@ -250,7 +249,7 @@ class LinearOperatorCirculantTestNonHermitianSpectrum( operator.assert_self_adjoint().run() # Should not fail def test_defining_operator_using_real_convolution_kernel(self): - with self.test_session(): + with self.cached_session(): convolution_kernel = [1., 2., 1.] spectrum = math_ops.fft( math_ops.cast(convolution_kernel, dtypes.complex64)) @@ -266,7 +265,7 @@ class LinearOperatorCirculantTestNonHermitianSpectrum( np.testing.assert_allclose(0, np.imag(matrix), atol=1e-6) def test_hermitian_spectrum_gives_operator_with_zero_imag_part(self): - with self.test_session(): + with self.cached_session(): # Make spectrum the FFT of a real convolution kernel h. This ensures that # spectrum is Hermitian. h = linear_operator_test_util.random_normal(shape=(3, 4)) @@ -281,7 +280,7 @@ class LinearOperatorCirculantTestNonHermitianSpectrum( def test_convolution_kernel_same_as_first_row_of_to_dense(self): spectrum = [[3., 2., 1.], [2., 1.5, 1.]] - with self.test_session(): + with self.cached_session(): operator = linalg.LinearOperatorCirculant(spectrum) h = operator.convolution_kernel() c = operator.to_dense() @@ -293,27 +292,27 @@ class LinearOperatorCirculantTestNonHermitianSpectrum( def test_assert_non_singular_fails_for_singular_operator(self): spectrum = math_ops.cast([0, 4, 2j + 2], dtypes.complex64) operator = linalg.LinearOperatorCirculant(spectrum) - with self.test_session(): + with self.cached_session(): with self.assertRaisesOpError("Singular operator"): operator.assert_non_singular().run() def test_assert_non_singular_does_not_fail_for_non_singular_operator(self): spectrum = math_ops.cast([-3j, 4, 2j + 2], dtypes.complex64) operator = linalg.LinearOperatorCirculant(spectrum) - with self.test_session(): + with self.cached_session(): operator.assert_non_singular().run() # Should not fail def test_assert_positive_definite_fails_for_non_positive_definite(self): spectrum = math_ops.cast([6., 4, 2j], dtypes.complex64) operator = linalg.LinearOperatorCirculant(spectrum) - with self.test_session(): + with self.cached_session(): with self.assertRaisesOpError("Not positive definite"): operator.assert_positive_definite().run() def test_assert_positive_definite_does_not_fail_when_pos_def(self): spectrum = math_ops.cast([6., 4, 2j + 2], dtypes.complex64) operator = linalg.LinearOperatorCirculant(spectrum) - with self.test_session(): + with self.cached_session(): operator.assert_positive_definite().run() # Should not fail def test_real_spectrum_and_not_self_adjoint_hint_raises(self): @@ -331,8 +330,10 @@ class LinearOperatorCirculant2DBaseTest(object): """Common class for 2D circulant tests.""" @contextlib.contextmanager - def test_session(self, *args, **kwargs): - with test.TestCase.test_session(self, *args, **kwargs) as sess: + def _constrain_devices_and_set_default(self, sess, use_gpu, force_gpu): + """We overwrite the FFT operation mapping for testing.""" + with test.TestCase._constrain_devices_and_set_default( + self, sess, use_gpu, force_gpu) as sess: with spectral_ops_test_util.fft_kernel_label_map(): yield sess @@ -446,8 +447,7 @@ class LinearOperatorCirculant2DTestHermitianSpectrum( lin_op_spectrum = spectrum if use_placeholder: - lin_op_spectrum = array_ops.placeholder_with_default( - spectrum, shape=None) + lin_op_spectrum = array_ops.placeholder_with_default(spectrum, shape=None) operator = linalg.LinearOperatorCirculant2D( lin_op_spectrum, input_output_dtype=dtype) @@ -482,8 +482,7 @@ class LinearOperatorCirculant2DTestNonHermitianSpectrum( lin_op_spectrum = spectrum if use_placeholder: - lin_op_spectrum = array_ops.placeholder_with_default( - spectrum, shape=None) + lin_op_spectrum = array_ops.placeholder_with_default(spectrum, shape=None) operator = linalg.LinearOperatorCirculant2D( lin_op_spectrum, input_output_dtype=dtype) @@ -493,7 +492,7 @@ class LinearOperatorCirculant2DTestNonHermitianSpectrum( return operator, mat def test_real_hermitian_spectrum_gives_real_symmetric_operator(self): - with self.test_session() as sess: + with self.cached_session() as sess: # This is a real and hermitian spectrum. spectrum = [[1., 2., 2.], [3., 4., 4.], [3., 4., 4.]] operator = linalg.LinearOperatorCirculant(spectrum) @@ -510,7 +509,7 @@ class LinearOperatorCirculant2DTestNonHermitianSpectrum( self.assertAllClose(matrix, matrix_transpose, atol=0) def test_real_spectrum_gives_self_adjoint_operator(self): - with self.test_session() as sess: + with self.cached_session() as sess: # This is a real and hermitian spectrum. spectrum = linear_operator_test_util.random_normal( shape=(3, 3), dtype=dtypes.float32) @@ -526,27 +525,27 @@ class LinearOperatorCirculant2DTestNonHermitianSpectrum( def test_assert_non_singular_fails_for_singular_operator(self): spectrum = math_ops.cast([[0, 4], [2j + 2, 3.]], dtypes.complex64) operator = linalg.LinearOperatorCirculant2D(spectrum) - with self.test_session(): + with self.cached_session(): with self.assertRaisesOpError("Singular operator"): operator.assert_non_singular().run() def test_assert_non_singular_does_not_fail_for_non_singular_operator(self): spectrum = math_ops.cast([[-3j, 4], [2j + 2, 3.]], dtypes.complex64) operator = linalg.LinearOperatorCirculant2D(spectrum) - with self.test_session(): + with self.cached_session(): operator.assert_non_singular().run() # Should not fail def test_assert_positive_definite_fails_for_non_positive_definite(self): spectrum = math_ops.cast([[6., 4], [2j, 3.]], dtypes.complex64) operator = linalg.LinearOperatorCirculant2D(spectrum) - with self.test_session(): + with self.cached_session(): with self.assertRaisesOpError("Not positive definite"): operator.assert_positive_definite().run() def test_assert_positive_definite_does_not_fail_when_pos_def(self): spectrum = math_ops.cast([[6., 4], [2j + 2, 3.]], dtypes.complex64) operator = linalg.LinearOperatorCirculant2D(spectrum) - with self.test_session(): + with self.cached_session(): operator.assert_positive_definite().run() # Should not fail def test_real_spectrum_and_not_self_adjoint_hint_raises(self): @@ -574,13 +573,15 @@ class LinearOperatorCirculant3DTest(test.TestCase): """Simple test of the 3D case. See also the 1D and 2D tests.""" @contextlib.contextmanager - def test_session(self, *args, **kwargs): - with test.TestCase.test_session(self, *args, **kwargs) as sess: + def _constrain_devices_and_set_default(self, sess, use_gpu, force_gpu): + """We overwrite the FFT operation mapping for testing.""" + with test.TestCase._constrain_devices_and_set_default( + self, sess, use_gpu, force_gpu) as sess: with spectral_ops_test_util.fft_kernel_label_map(): yield sess def test_real_spectrum_gives_self_adjoint_operator(self): - with self.test_session() as sess: + with self.cached_session() as sess: # This is a real and hermitian spectrum. spectrum = linear_operator_test_util.random_normal( shape=(2, 2, 3, 5), dtype=dtypes.float32) @@ -597,7 +598,7 @@ class LinearOperatorCirculant3DTest(test.TestCase): self.assertAllClose(matrix, matrix_h) def test_defining_operator_using_real_convolution_kernel(self): - with self.test_session(): + with self.cached_session(): convolution_kernel = linear_operator_test_util.random_normal( shape=(2, 2, 3, 5), dtype=dtypes.float32) # Convolution kernel is real ==> spectrum is Hermitian. @@ -615,7 +616,7 @@ class LinearOperatorCirculant3DTest(test.TestCase): np.testing.assert_allclose(0, np.imag(matrix), atol=1e-6) def test_defining_spd_operator_by_taking_real_part(self): - with self.test_session() as sess: + with self.cached_session() as sess: # S is real and positive. s = linear_operator_test_util.random_uniform( shape=(10, 2, 3, 4), dtype=dtypes.float32, minval=1., maxval=2.) |