aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/model_pruning
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-08-21 19:24:19 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-21 19:27:54 -0700
commit496023e9dc84a076caeb2e5e8e13b6a3d819ad6d (patch)
tree9776c9865f7b98a15817bc6be4c2b683323a67b1 /tensorflow/contrib/model_pruning
parent361a82d73a50a800510674b3aaa20e4845e56434 (diff)
Move from deprecated self.test_session() to self.cached_session().
self.test_session() has been deprecated in 9962eb5e84b15e309410071b06c2ed2d6148ed44 as its name confuses readers of the test. Moving to cached_session() instead which is more explicit about: * the fact that the session may be reused. * the session is not closed even when doing a "with self.test_session()" statement. PiperOrigin-RevId: 209701635
Diffstat (limited to 'tensorflow/contrib/model_pruning')
-rw-r--r--tensorflow/contrib/model_pruning/python/pruning_test.py16
-rw-r--r--tensorflow/contrib/model_pruning/python/pruning_utils_test.py10
2 files changed, 13 insertions, 13 deletions
diff --git a/tensorflow/contrib/model_pruning/python/pruning_test.py b/tensorflow/contrib/model_pruning/python/pruning_test.py
index 33c4ad58bd..cd3d8e76bb 100644
--- a/tensorflow/contrib/model_pruning/python/pruning_test.py
+++ b/tensorflow/contrib/model_pruning/python/pruning_test.py
@@ -61,14 +61,14 @@ class PruningHParamsTest(test.TestCase):
self.assertEqual(p._weight_sparsity_map["conv2/kernel"], 0.8)
def testInitWithExternalSparsity(self):
- with self.test_session():
+ with self.cached_session():
p = pruning.Pruning(spec=self.pruning_hparams, sparsity=self.sparsity)
variables.global_variables_initializer().run()
sparsity = p._sparsity.eval()
self.assertAlmostEqual(sparsity, 0.5)
def testInitWithVariableReuse(self):
- with self.test_session():
+ with self.cached_session():
p = pruning.Pruning(spec=self.pruning_hparams, sparsity=self.sparsity)
p_copy = pruning.Pruning(
spec=self.pruning_hparams, sparsity=self.sparsity)
@@ -87,7 +87,7 @@ class PruningTest(test.TestCase):
def testCreateMask2D(self):
width = 10
height = 20
- with self.test_session():
+ with self.cached_session():
weights = variables.Variable(
random_ops.random_normal([width, height], stddev=1), name="weights")
masked_weights = pruning.apply_mask(weights,
@@ -98,7 +98,7 @@ class PruningTest(test.TestCase):
self.assertAllEqual(weights_val, masked_weights_val)
def testUpdateSingleMask(self):
- with self.test_session() as session:
+ with self.cached_session() as session:
weights = variables.Variable(
math_ops.linspace(1.0, 100.0, 100), name="weights")
masked_weights = pruning.apply_mask(weights)
@@ -122,7 +122,7 @@ class PruningTest(test.TestCase):
# Set up pruning
p = pruning.Pruning(pruning_hparams, sparsity=sparsity)
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
_, new_mask = p._maybe_update_block_mask(weights, threshold)
# Check if the mask is the same size as the weights
@@ -167,7 +167,7 @@ class PruningTest(test.TestCase):
def testPartitionedVariableMasking(self):
partitioner = partitioned_variables.variable_axis_size_partitioner(40)
- with self.test_session() as session:
+ with self.cached_session() as session:
with variable_scope.variable_scope("", partitioner=partitioner):
sparsity = variables.Variable(0.5, name="Sparsity")
weights = variable_scope.get_variable(
@@ -201,7 +201,7 @@ class PruningTest(test.TestCase):
sparsity_val = math_ops.linspace(0.0, 0.9, 10)
increment_global_step = state_ops.assign_add(self.global_step, 1)
non_zero_count = []
- with self.test_session() as session:
+ with self.cached_session() as session:
variables.global_variables_initializer().run()
for i in range(10):
session.run(state_ops.assign(sparsity, sparsity_val[i]))
@@ -234,7 +234,7 @@ class PruningTest(test.TestCase):
mask_update_op = p.conditional_mask_update_op()
increment_global_step = state_ops.assign_add(self.global_step, 1)
- with self.test_session() as session:
+ with self.cached_session() as session:
variables.global_variables_initializer().run()
for _ in range(110):
session.run(mask_update_op)
diff --git a/tensorflow/contrib/model_pruning/python/pruning_utils_test.py b/tensorflow/contrib/model_pruning/python/pruning_utils_test.py
index 06d7f97437..0aca843497 100644
--- a/tensorflow/contrib/model_pruning/python/pruning_utils_test.py
+++ b/tensorflow/contrib/model_pruning/python/pruning_utils_test.py
@@ -38,7 +38,7 @@ class PruningUtilsTest(test.TestCase):
def _compare_cdf(self, values):
abs_values = math_ops.abs(values)
max_value = math_ops.reduce_max(abs_values)
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
cdf_from_histogram = pruning_utils.compute_cdf_from_histogram(
abs_values, [0.0, max_value], nbins=pruning_utils._NBINS)
@@ -55,7 +55,7 @@ class PruningUtilsTest(test.TestCase):
"weights", [width, height], initializer=init)
histogram = pruning_utils._histogram(
weights, [0, 1.0], nbins, dtype=np.float32)
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
computed_histogram = histogram.eval()
self.assertAllEqual(expected_histogram, computed_histogram)
@@ -67,7 +67,7 @@ class PruningUtilsTest(test.TestCase):
norm_cdf = pruning_utils.compute_cdf_from_histogram(
abs_weights, [0.0, 5.0], nbins=nbins)
expected_cdf = np.array([0.1, 0.4, 0.5, 0.6, 1.0], dtype=np.float32)
- with self.test_session() as sess:
+ with self.cached_session() as sess:
variables.global_variables_initializer().run()
norm_cdf_val = sess.run(norm_cdf)
self.assertAllEqual(len(norm_cdf_val), nbins)
@@ -90,7 +90,7 @@ class PruningUtilsTest(test.TestCase):
class PruningUtilsParameterizedTest(test.TestCase, parameterized.TestCase):
def _compare_pooling_methods(self, weights, pooling_kwargs):
- with self.test_session():
+ with self.cached_session():
variables.global_variables_initializer().run()
pooled_weights_tf = array_ops.squeeze(
nn_ops.pool(
@@ -104,7 +104,7 @@ class PruningUtilsParameterizedTest(test.TestCase, parameterized.TestCase):
pooled_weights_factorized_pool.eval())
def _compare_expand_tensor_with_kronecker_product(self, tensor, block_dim):
- with self.test_session() as session:
+ with self.cached_session() as session:
variables.global_variables_initializer().run()
expanded_tensor = pruning_utils.expand_tensor(tensor, block_dim)
kronecker_product = pruning_utils.kronecker_product(