aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/factorization
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-21 19:53:43 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-21 20:00:41 -0700
commit47c0bda0e7f736a9328aaf76aba7c8006e24556f (patch)
treead2a6ab71adddc0d07c7f306c270122937b6a5b0 /tensorflow/contrib/factorization
parent1ab795b54274a26a92690f36eff65674fb500f91 (diff)
Move from deprecated self.test_session() to self.cached_session().
self.test_session() has been deprecated in 9962eb5e84b15e309410071b06c2ed2d6148ed44 as its name confuses readers of the test. Moving to cached_session() instead which is more explicit about: * the fact that the session may be reused. * the session is not closed even when doing a "with self.test_session()" statement. PiperOrigin-RevId: 209703607
Diffstat (limited to 'tensorflow/contrib/factorization')
-rw-r--r--tensorflow/contrib/factorization/python/kernel_tests/clustering_ops_test.py16
-rw-r--r--tensorflow/contrib/factorization/python/kernel_tests/masked_matmul_ops_test.py4
-rw-r--r--tensorflow/contrib/factorization/python/kernel_tests/wals_solver_ops_test.py4
3 files changed, 12 insertions, 12 deletions
diff --git a/tensorflow/contrib/factorization/python/kernel_tests/clustering_ops_test.py b/tensorflow/contrib/factorization/python/kernel_tests/clustering_ops_test.py
index 1322f7ce5f..db47073fcc 100644
--- a/tensorflow/contrib/factorization/python/kernel_tests/clustering_ops_test.py
+++ b/tensorflow/contrib/factorization/python/kernel_tests/clustering_ops_test.py
@@ -41,7 +41,7 @@ class KmeansPlusPlusInitializationTest(test.TestCase):
[-1., -1.]]).astype(np.float32)
def runTestWithSeed(self, seed):
- with self.test_session():
+ with self.cached_session():
sampled_points = clustering_ops.kmeans_plus_plus_initialization(
self._points, 3, seed, (seed % 5) - 1)
self.assertAllClose(
@@ -58,7 +58,7 @@ class KmeansPlusPlusInitializationTest(test.TestCase):
class KMC2InitializationTest(test.TestCase):
def runTestWithSeed(self, seed):
- with self.test_session():
+ with self.cached_session():
distances = np.zeros(1000).astype(np.float32)
distances[6] = 10e7
distances[4] = 10e3
@@ -82,7 +82,7 @@ class KMC2InitializationLargeTest(test.TestCase):
self._distances[1000] = 50.0
def testBasic(self):
- with self.test_session():
+ with self.cached_session():
counts = {}
seed = 0
for i in range(50):
@@ -102,7 +102,7 @@ class KMC2InitializationCornercaseTest(test.TestCase):
self._distances = np.zeros(10)
def runTestWithSeed(self, seed):
- with self.test_session():
+ with self.cached_session():
sampled_point = clustering_ops.kmc2_chain_initialization(
self._distances, seed)
self.assertEquals(sampled_point.eval(), 0)
@@ -128,14 +128,14 @@ class NearestCentersTest(test.TestCase):
[1., 1.]]).astype(np.float32)
def testNearest1(self):
- with self.test_session():
+ with self.cached_session():
[indices, distances] = clustering_ops.nearest_neighbors(self._points,
self._centers, 1)
self.assertAllClose(indices.eval(), [[0], [0], [1], [4]])
self.assertAllClose(distances.eval(), [[0.], [5.], [1.], [0.]])
def testNearest2(self):
- with self.test_session():
+ with self.cached_session():
[indices, distances] = clustering_ops.nearest_neighbors(self._points,
self._centers, 2)
self.assertAllClose(indices.eval(), [[0, 1], [0, 1], [1, 0], [4, 3]])
@@ -180,7 +180,7 @@ class NearestCentersLargeTest(test.TestCase):
expected_nearest_neighbor_squared_distances))
def testNearest1(self):
- with self.test_session():
+ with self.cached_session():
[indices, distances] = clustering_ops.nearest_neighbors(self._points,
self._centers, 1)
self.assertAllClose(indices.eval(),
@@ -190,7 +190,7 @@ class NearestCentersLargeTest(test.TestCase):
self._expected_nearest_neighbor_squared_distances[:, [0]])
def testNearest5(self):
- with self.test_session():
+ with self.cached_session():
[indices, distances] = clustering_ops.nearest_neighbors(self._points,
self._centers, 5)
self.assertAllClose(indices.eval(),
diff --git a/tensorflow/contrib/factorization/python/kernel_tests/masked_matmul_ops_test.py b/tensorflow/contrib/factorization/python/kernel_tests/masked_matmul_ops_test.py
index 3a909e2373..dd115735d0 100644
--- a/tensorflow/contrib/factorization/python/kernel_tests/masked_matmul_ops_test.py
+++ b/tensorflow/contrib/factorization/python/kernel_tests/masked_matmul_ops_test.py
@@ -58,7 +58,7 @@ class MaskedProductOpsTest(test.TestCase):
self._mask_ind, self._mask_shape = MakeMask()
def _runTestMaskedProduct(self, transpose_a, transpose_b):
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
a = self._a if not transpose_a else array_ops.transpose(self._a)
b = self._b if not transpose_b else array_ops.transpose(self._b)
@@ -78,7 +78,7 @@ class MaskedProductOpsTest(test.TestCase):
AssertClose(result, true_result)
def _runTestEmptyMaskedProduct(self):
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
empty_mask = constant_op.constant(0, shape=[0, 2], dtype=dtypes.int64)
values = gen_factorization_ops.masked_matmul(
self._a, self._b, empty_mask, False, False)
diff --git a/tensorflow/contrib/factorization/python/kernel_tests/wals_solver_ops_test.py b/tensorflow/contrib/factorization/python/kernel_tests/wals_solver_ops_test.py
index 6c2f1d4608..8a16e22663 100644
--- a/tensorflow/contrib/factorization/python/kernel_tests/wals_solver_ops_test.py
+++ b/tensorflow/contrib/factorization/python/kernel_tests/wals_solver_ops_test.py
@@ -50,7 +50,7 @@ class WalsSolverOpsTest(test.TestCase):
def testWalsSolverLhs(self):
sparse_block = SparseBlock3x3()
- with self.test_session():
+ with self.cached_session():
[lhs_tensor,
rhs_matrix] = gen_factorization_ops.wals_compute_partial_lhs_and_rhs(
self._column_factors, self._column_weights, self._unobserved_weights,
@@ -82,7 +82,7 @@ class WalsSolverOpsTest(test.TestCase):
def testWalsSolverLhsEntryWeights(self):
sparse_block = SparseBlock3x3()
- with self.test_session():
+ with self.cached_session():
[lhs_tensor,
rhs_matrix] = gen_factorization_ops.wals_compute_partial_lhs_and_rhs(
self._column_factors, [], self._unobserved_weights,