diff options
author | 2018-09-10 15:43:51 -0700 | |
---|---|---|
committer | 2018-09-10 15:49:05 -0700 | |
commit | e32029541ae270a021b266fcc3929b2528f8dff1 (patch) | |
tree | 40abaa2485e86b41b10d317af3969754e5cdb789 /tensorflow/contrib/factorization | |
parent | 6951e0646d7dc8931b6cbe4388dcc3921249d462 (diff) |
Move from deprecated self.test_session() to self.cached_session().
self.test_session() has been deprecated in 9962eb5e84b15e309410071b06c2ed2d6148ed44 as its name confuses readers of the test. Moving to cached_session() instead which is more explicit about:
* the fact that the session may be reused.
* the session is not closed even when doing a "with self.test_session()" statement.
PiperOrigin-RevId: 212348850
Diffstat (limited to 'tensorflow/contrib/factorization')
4 files changed, 16 insertions, 16 deletions
diff --git a/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py b/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py index bb5140aeb3..6aa62fb82e 100644 --- a/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py +++ b/tensorflow/contrib/factorization/python/ops/factorization_ops_test.py @@ -126,7 +126,7 @@ class WalsModelTest(test.TestCase): observed *= num_rows / 3. if test_rows else num_cols / 2. want_weight_sum = unobserved + observed - with ops.Graph().as_default(), self.test_session() as sess: + with ops.Graph().as_default(), self.cached_session() as sess: wals_model = factorization_ops.WALSModel( input_rows=num_rows, input_cols=num_cols, @@ -161,7 +161,7 @@ class WalsModelTest(test.TestCase): def _run_test_process_input(self, use_factors_weights_cache, compute_loss=False): - with ops.Graph().as_default(), self.test_session() as sess: + with ops.Graph().as_default(), self.cached_session() as sess: self._wals_inputs = self.sparse_input() sp_feeder = array_ops.sparse_placeholder(dtypes.float32) num_rows = 5 @@ -330,7 +330,7 @@ class WalsModelTest(test.TestCase): def _run_test_process_input_transposed(self, use_factors_weights_cache, compute_loss=False): - with ops.Graph().as_default(), self.test_session() as sess: + with ops.Graph().as_default(), self.cached_session() as sess: self._wals_inputs = self.sparse_input() sp_feeder = array_ops.sparse_placeholder(dtypes.float32) num_rows = 5 @@ -505,7 +505,7 @@ class WalsModelTest(test.TestCase): # trigger the more efficient ALS updates. # Here we test that those two give identical results. def _run_test_als(self, use_factors_weights_cache): - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): self._wals_inputs = self.sparse_input() col_init = np.random.rand(7, 3) als_model = factorization_ops.WALSModel( @@ -583,7 +583,7 @@ class WalsModelTest(test.TestCase): atol=1e-2) def _run_test_als_transposed(self, use_factors_weights_cache): - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): self._wals_inputs = self.sparse_input() col_init = np.random.rand(7, 3) als_model = factorization_ops.WALSModel( @@ -673,7 +673,7 @@ class WalsModelTest(test.TestCase): rows = 15 cols = 11 dims = 3 - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): data = np.dot(np.random.rand(rows, 3), np.random.rand( 3, cols)).astype(np.float32) / 3.0 indices = [[i, j] for i in xrange(rows) for j in xrange(cols)] @@ -703,7 +703,7 @@ class WalsModelTest(test.TestCase): cols = 11 dims = 3 - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): data = np.dot(np.random.rand(rows, 3), np.random.rand( 3, cols)).astype(np.float32) / 3.0 indices = [[i, j] for i in xrange(rows) for j in xrange(cols)] @@ -736,7 +736,7 @@ class WalsModelTest(test.TestCase): def keep_index(x): return not (x[0] + x[1]) % 4 - with ops.Graph().as_default(), self.test_session(): + with ops.Graph().as_default(), self.cached_session(): row_wts = 0.1 + np.random.rand(rows) col_wts = 0.1 + np.random.rand(cols) data = np.dot(np.random.rand(rows, 3), np.random.rand( diff --git a/tensorflow/contrib/factorization/python/ops/gmm_ops_test.py b/tensorflow/contrib/factorization/python/ops/gmm_ops_test.py index 888c3c238c..112e4d289b 100644 --- a/tensorflow/contrib/factorization/python/ops/gmm_ops_test.py +++ b/tensorflow/contrib/factorization/python/ops/gmm_ops_test.py @@ -99,7 +99,7 @@ class GmmOpsTest(test.TestCase): logging.info('Numpy took %f', time.time() - start_time) start_time = time.time() - with self.test_session() as sess: + with self.cached_session() as sess: op = gmm_ops._covariance( constant_op.constant( data.T, dtype=dtypes.float32), False) @@ -120,7 +120,7 @@ class GmmOpsTest(test.TestCase): graph = ops.Graph() with graph.as_default() as g: g.seed = 5 - with self.test_session() as sess: + with self.cached_session() as sess: data = constant_op.constant(self.data, dtype=dtypes.float32) loss_op, scores, assignments, training_op, init_op, _ = gmm_ops.gmm( data, 'random', num_classes, random_seed=self.seed) @@ -144,7 +144,7 @@ class GmmOpsTest(test.TestCase): def testParams(self): """Tests that the params work as intended.""" num_classes = 2 - with self.test_session() as sess: + with self.cached_session() as sess: # Experiment 1. Update weights only. data = constant_op.constant(self.data, dtype=dtypes.float32) gmm_tool = gmm_ops.GmmAlgorithm([data], num_classes, diff --git a/tensorflow/contrib/factorization/python/ops/kmeans_test.py b/tensorflow/contrib/factorization/python/ops/kmeans_test.py index 88eb9cf692..1ab5418fe4 100644 --- a/tensorflow/contrib/factorization/python/ops/kmeans_test.py +++ b/tensorflow/contrib/factorization/python/ops/kmeans_test.py @@ -232,7 +232,7 @@ class KMeansTest(KMeansTestBase): self.assertEqual(features.shape, parsed_feature_dict.shape) self.assertEqual(features.dtype, parsed_feature_dict.dtype) # Then check that running the tensor yields the original list of points. - with self.test_session() as sess: + with self.cached_session() as sess: parsed_points = sess.run(parsed_feature_dict) self.assertAllEqual(self.points, parsed_points) diff --git a/tensorflow/contrib/factorization/python/ops/wals_test.py b/tensorflow/contrib/factorization/python/ops/wals_test.py index 31820a18b4..9bdbd05015 100644 --- a/tensorflow/contrib/factorization/python/ops/wals_test.py +++ b/tensorflow/contrib/factorization/python/ops/wals_test.py @@ -336,7 +336,7 @@ class WALSMatrixFactorizationTest(test.TestCase): loss = self._model.evaluate( input_fn=eval_input_fn_row, steps=self._num_rows)['loss'] - with self.test_session(): + with self.cached_session(): true_loss = self.calculate_loss() self.assertNear( @@ -354,7 +354,7 @@ class WALSMatrixFactorizationTest(test.TestCase): loss = self._model.evaluate( input_fn=eval_input_fn_col, steps=self._num_cols)['loss'] - with self.test_session(): + with self.cached_session(): true_loss = self.calculate_loss() self.assertNear( @@ -440,7 +440,7 @@ class SweepHookTest(test.TestCase): math_ops.logical_not(is_row_sweep_var))) mark_sweep_done = state_ops.assign(is_sweep_done_var, True) - with self.test_session() as sess: + with self.cached_session() as sess: sweep_hook = wals_lib._SweepHook( is_row_sweep_var, is_sweep_done_var, @@ -491,7 +491,7 @@ class StopAtSweepHookTest(test.TestCase): train_op = state_ops.assign_add(completed_sweeps, 1) hook.begin() - with self.test_session() as sess: + with self.cached_session() as sess: sess.run([variables.global_variables_initializer()]) mon_sess = monitored_session._HookedSession(sess, [hook]) mon_sess.run(train_op) |