diff options
Diffstat (limited to 'tensorflow/contrib/factorization/python/ops/factorization_ops_test.py')
-rw-r--r-- | tensorflow/contrib/factorization/python/ops/factorization_ops_test.py | 16 |
1 files changed, 8 insertions, 8 deletions
diff --git a/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py b/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py index bb5140aeb3..6aa62fb82e 100644 --- a/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py +++ b/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py @@ -126,7 +126,7 @@ class WalsModelTest(test.TestCase): observed *= num_rows / 3. if test_rows else num_cols / 2. want_weight_sum = unobserved + observed - with ops.Graph().as_default(), self.test_session() as sess: + with ops.Graph().as_default(), self.cached_session() as sess: wals_model = factorization_ops.WALSModel( input_rows=num_rows, input_cols=num_cols, @@ -161,7 +161,7 @@ class WalsModelTest(test.TestCase): def _run_test_process_input(self, use_factors_weights_cache, compute_loss=False): - with ops.Graph().as_default(), self.test_session() as sess: + with ops.Graph().as_default(), self.cached_session() as sess: self._wals_inputs = self.sparse_input() sp_feeder = array_ops.sparse_placeholder(dtypes.float32) num_rows = 5 @@ -330,7 +330,7 @@ class WalsModelTest(test.TestCase): def _run_test_process_input_transposed(self, use_factors_weights_cache, compute_loss=False): - with ops.Graph().as_default(), self.test_session() as sess: + with ops.Graph().as_default(), self.cached_session() as sess: self._wals_inputs = self.sparse_input() sp_feeder = array_ops.sparse_placeholder(dtypes.float32) num_rows = 5 @@ -505,7 +505,7 @@ class WalsModelTest(test.TestCase): # trigger the more efficient ALS updates. # Here we test that those two give identical results. def _run_test_als(self, use_factors_weights_cache): - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): self._wals_inputs = self.sparse_input() col_init = np.random.rand(7, 3) als_model = factorization_ops.WALSModel( @@ -583,7 +583,7 @@ class WalsModelTest(test.TestCase): atol=1e-2) def _run_test_als_transposed(self, use_factors_weights_cache): - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): self._wals_inputs = self.sparse_input() col_init = np.random.rand(7, 3) als_model = factorization_ops.WALSModel( @@ -673,7 +673,7 @@ class WalsModelTest(test.TestCase): rows = 15 cols = 11 dims = 3 - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): data = np.dot(np.random.rand(rows, 3), np.random.rand( 3, cols)).astype(np.float32) / 3.0 indices = [[i, j] for i in xrange(rows) for j in xrange(cols)] @@ -703,7 +703,7 @@ class WalsModelTest(test.TestCase): cols = 11 dims = 3 - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): data = np.dot(np.random.rand(rows, 3), np.random.rand( 3, cols)).astype(np.float32) / 3.0 indices = [[i, j] for i in xrange(rows) for j in xrange(cols)] @@ -736,7 +736,7 @@ class WalsModelTest(test.TestCase): def keep_index(x): return not (x[0] + x[1]) % 4 - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): row_wts = 0.1 + np.random.rand(rows) col_wts = 0.1 + np.random.rand(cols) data = np.dot(np.random.rand(rows, 3), np.random.rand( |