diff options
author | 2018-08-21 19:24:19 -0700 | |
---|---|---|
committer | 2018-08-21 19:27:54 -0700 | |
commit | 496023e9dc84a076caeb2e5e8e13b6a3d819ad6d (patch) | |
tree | 9776c9865f7b98a15817bc6be4c2b683323a67b1 /tensorflow/contrib/model_pruning | |
parent | 361a82d73a50a800510674b3aaa20e4845e56434 (diff) |
Move from deprecated self.test_session() to self.cached_session().
self.test_session() has been deprecated in 9962eb5e84b15e309410071b06c2ed2d6148ed44 as its name confuses readers of the test. Moving to cached_session() instead which is more explicit about:
* the fact that the session may be reused.
* the session is not closed even when doing a "with self.test_session()" statement.
PiperOrigin-RevId: 209701635
Diffstat (limited to 'tensorflow/contrib/model_pruning')
-rw-r--r-- | tensorflow/contrib/model_pruning/python/pruning_test.py | 16 | ||||
-rw-r--r-- | tensorflow/contrib/model_pruning/python/pruning_utils_test.py | 10 |
2 files changed, 13 insertions, 13 deletions
diff --git a/tensorflow/contrib/model_pruning/python/pruning_test.py b/tensorflow/contrib/model_pruning/python/pruning_test.py index 33c4ad58bd..cd3d8e76bb 100644 --- a/tensorflow/contrib/model_pruning/python/pruning_test.py +++ b/tensorflow/contrib/model_pruning/python/pruning_test.py @@ -61,14 +61,14 @@ class PruningHParamsTest(test.TestCase): self.assertEqual(p._weight_sparsity_map["conv2/kernel"], 0.8) def testInitWithExternalSparsity(self): - with self.test_session(): + with self.cached_session(): p = pruning.Pruning(spec=self.pruning_hparams, sparsity=self.sparsity) variables.global_variables_initializer().run() sparsity = p._sparsity.eval() self.assertAlmostEqual(sparsity, 0.5) def testInitWithVariableReuse(self): - with self.test_session(): + with self.cached_session(): p = pruning.Pruning(spec=self.pruning_hparams, sparsity=self.sparsity) p_copy = pruning.Pruning( spec=self.pruning_hparams, sparsity=self.sparsity) @@ -87,7 +87,7 @@ class PruningTest(test.TestCase): def testCreateMask2D(self): width = 10 height = 20 - with self.test_session(): + with self.cached_session(): weights = variables.Variable( random_ops.random_normal([width, height], stddev=1), name="weights") masked_weights = pruning.apply_mask(weights, @@ -98,7 +98,7 @@ class PruningTest(test.TestCase): self.assertAllEqual(weights_val, masked_weights_val) def testUpdateSingleMask(self): - with self.test_session() as session: + with self.cached_session() as session: weights = variables.Variable( math_ops.linspace(1.0, 100.0, 100), name="weights") masked_weights = pruning.apply_mask(weights) @@ -122,7 +122,7 @@ class PruningTest(test.TestCase): # Set up pruning p = pruning.Pruning(pruning_hparams, sparsity=sparsity) - with self.test_session(): + with self.cached_session(): variables.global_variables_initializer().run() _, new_mask = p._maybe_update_block_mask(weights, threshold) # Check if the mask is the same size as the weights @@ -167,7 +167,7 @@ class PruningTest(test.TestCase): def testPartitionedVariableMasking(self): partitioner = partitioned_variables.variable_axis_size_partitioner(40) - with self.test_session() as session: + with self.cached_session() as session: with variable_scope.variable_scope("", partitioner=partitioner): sparsity = variables.Variable(0.5, name="Sparsity") weights = variable_scope.get_variable( @@ -201,7 +201,7 @@ class PruningTest(test.TestCase): sparsity_val = math_ops.linspace(0.0, 0.9, 10) increment_global_step = state_ops.assign_add(self.global_step, 1) non_zero_count = [] - with self.test_session() as session: + with self.cached_session() as session: variables.global_variables_initializer().run() for i in range(10): session.run(state_ops.assign(sparsity, sparsity_val[i])) @@ -234,7 +234,7 @@ class PruningTest(test.TestCase): mask_update_op = p.conditional_mask_update_op() increment_global_step = state_ops.assign_add(self.global_step, 1) - with self.test_session() as session: + with self.cached_session() as session: variables.global_variables_initializer().run() for _ in range(110): session.run(mask_update_op) diff --git a/tensorflow/contrib/model_pruning/python/pruning_utils_test.py b/tensorflow/contrib/model_pruning/python/pruning_utils_test.py index 06d7f97437..0aca843497 100644 --- a/tensorflow/contrib/model_pruning/python/pruning_utils_test.py +++ b/tensorflow/contrib/model_pruning/python/pruning_utils_test.py @@ -38,7 +38,7 @@ class PruningUtilsTest(test.TestCase): def _compare_cdf(self, values): abs_values = math_ops.abs(values) max_value = math_ops.reduce_max(abs_values) - with self.test_session(): + with self.cached_session(): variables.global_variables_initializer().run() cdf_from_histogram = pruning_utils.compute_cdf_from_histogram( abs_values, [0.0, max_value], nbins=pruning_utils._NBINS) @@ -55,7 +55,7 @@ class PruningUtilsTest(test.TestCase): "weights", [width, height], initializer=init) histogram = pruning_utils._histogram( weights, [0, 1.0], nbins, dtype=np.float32) - with self.test_session(): + with self.cached_session(): variables.global_variables_initializer().run() computed_histogram = histogram.eval() self.assertAllEqual(expected_histogram, computed_histogram) @@ -67,7 +67,7 @@ class PruningUtilsTest(test.TestCase): norm_cdf = pruning_utils.compute_cdf_from_histogram( abs_weights, [0.0, 5.0], nbins=nbins) expected_cdf = np.array([0.1, 0.4, 0.5, 0.6, 1.0], dtype=np.float32) - with self.test_session() as sess: + with self.cached_session() as sess: variables.global_variables_initializer().run() norm_cdf_val = sess.run(norm_cdf) self.assertAllEqual(len(norm_cdf_val), nbins) @@ -90,7 +90,7 @@ class PruningUtilsTest(test.TestCase): class PruningUtilsParameterizedTest(test.TestCase, parameterized.TestCase): def _compare_pooling_methods(self, weights, pooling_kwargs): - with self.test_session(): + with self.cached_session(): variables.global_variables_initializer().run() pooled_weights_tf = array_ops.squeeze( nn_ops.pool( @@ -104,7 +104,7 @@ class PruningUtilsParameterizedTest(test.TestCase, parameterized.TestCase): pooled_weights_factorized_pool.eval()) def _compare_expand_tensor_with_kronecker_product(self, tensor, block_dim): - with self.test_session() as session: + with self.cached_session() as session: variables.global_variables_initializer().run() expanded_tensor = pruning_utils.expand_tensor(tensor, block_dim) kronecker_product = pruning_utils.kronecker_product( |