aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/factorization/python/ops/factorization_ops_test.py')
-rw-r--r--tensorflow/contrib/factorization/python/ops/factorization_ops_test.py16
1 files changed, 8 insertions, 8 deletions
diff --git a/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py b/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py
index bb5140aeb3..6aa62fb82e 100644
--- a/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py
+++ b/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py
@@ -126,7 +126,7 @@ class WalsModelTest(test.TestCase):
observed *= num_rows / 3. if test_rows else num_cols / 2.
want_weight_sum = unobserved + observed
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
wals_model = factorization_ops.WALSModel(
input_rows=num_rows,
input_cols=num_cols,
@@ -161,7 +161,7 @@ class WalsModelTest(test.TestCase):
def _run_test_process_input(self,
use_factors_weights_cache,
compute_loss=False):
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
self._wals_inputs = self.sparse_input()
sp_feeder = array_ops.sparse_placeholder(dtypes.float32)
num_rows = 5
@@ -330,7 +330,7 @@ class WalsModelTest(test.TestCase):
def _run_test_process_input_transposed(self,
use_factors_weights_cache,
compute_loss=False):
- with ops.Graph().as_default(), self.test_session() as sess:
+ with ops.Graph().as_default(), self.cached_session() as sess:
self._wals_inputs = self.sparse_input()
sp_feeder = array_ops.sparse_placeholder(dtypes.float32)
num_rows = 5
@@ -505,7 +505,7 @@ class WalsModelTest(test.TestCase):
# trigger the more efficient ALS updates.
# Here we test that those two give identical results.
def _run_test_als(self, use_factors_weights_cache):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
self._wals_inputs = self.sparse_input()
col_init = np.random.rand(7, 3)
als_model = factorization_ops.WALSModel(
@@ -583,7 +583,7 @@ class WalsModelTest(test.TestCase):
atol=1e-2)
def _run_test_als_transposed(self, use_factors_weights_cache):
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
self._wals_inputs = self.sparse_input()
col_init = np.random.rand(7, 3)
als_model = factorization_ops.WALSModel(
@@ -673,7 +673,7 @@ class WalsModelTest(test.TestCase):
rows = 15
cols = 11
dims = 3
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
data = np.dot(np.random.rand(rows, 3), np.random.rand(
3, cols)).astype(np.float32) / 3.0
indices = [[i, j] for i in xrange(rows) for j in xrange(cols)]
@@ -703,7 +703,7 @@ class WalsModelTest(test.TestCase):
cols = 11
dims = 3
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
data = np.dot(np.random.rand(rows, 3), np.random.rand(
3, cols)).astype(np.float32) / 3.0
indices = [[i, j] for i in xrange(rows) for j in xrange(cols)]
@@ -736,7 +736,7 @@ class WalsModelTest(test.TestCase):
def keep_index(x):
return not (x[0] + x[1]) % 4
- with ops.Graph().as_default(), self.test_session():
+ with ops.Graph().as_default(), self.cached_session():
row_wts = 0.1 + np.random.rand(rows)
col_wts = 0.1 + np.random.rand(cols)
data = np.dot(np.random.rand(rows, 3), np.random.rand(