about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/contrib/model_pruning/python
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/model_pruning/python')
-rw-r--r-- tensorflow/contrib/model_pruning/python/pruning.py | 21
-rw-r--r-- tensorflow/contrib/model_pruning/python/pruning_test.py | 17
2 files changed, 30 insertions, 8 deletions
diff --git a/tensorflow/contrib/model_pruning/python/pruning.py b/tensorflow/contrib/model_pruning/python/pruning.py
index d16af9da19..86963be4b8 100644
--- a/tensorflow/contrib/model_pruning/python/pruning.py
+++ b/tensorflow/contrib/model_pruning/python/pruning.py
@@ -523,7 +523,8 @@ class Pruning(object):
"""Performs block-granular masking of the weights.
Block pruning occurs only if the block_height or block_width is > 1 and
- if the weight tensor has ndims = 2. Otherwise, elementwise pruning occurs.
+ if the weight tensor, when squeezed, has ndims = 2. Otherwise, elementwise
+ pruning occurs.
Args:
weights: The weight tensor that needs to be masked.
threshold: The current threshold value. The function will compute a new
@@ -540,7 +541,8 @@ class Pruning(object):
Raises:
ValueError: if block pooling function is not AVG or MAX
"""
- if weights.get_shape().ndims != 2 or self._block_dim == [1, 1]:
+ squeezed_weights = array_ops.squeeze(weights)
+ if squeezed_weights.get_shape().ndims != 2 or self._block_dim == [1, 1]:
return self._update_mask(weights, threshold)
if self._block_pooling_function not in ['AVG', 'MAX']:
@@ -549,9 +551,11 @@ class Pruning(object):
with ops.name_scope(weights.op.name + '_pruning_ops'):
abs_weights = math_ops.abs(
- array_ops.reshape(
- weights, [1, weights.get_shape()[0],
- weights.get_shape()[1], 1]))
+ array_ops.reshape(weights, [
+ 1,
+ squeezed_weights.get_shape()[0],
+ squeezed_weights.get_shape()[1], 1
+ ]))
pool_window = [self._block_dim[0], self._block_dim[1]]
pooled_weights = nn_ops.pool(
abs_weights,
@@ -572,9 +576,10 @@ class Pruning(object):
array_ops.ones(self._block_dim))
sliced_mask = array_ops.slice(
updated_mask, [0, 0],
- [weights.get_shape()[0],
- weights.get_shape()[1]])
- return smoothed_threshold, sliced_mask
+ [squeezed_weights.get_shape()[0],
+ squeezed_weights.get_shape()[1]])
+ return smoothed_threshold, array_ops.reshape(sliced_mask,
+ array_ops.shape(weights))
def _get_mask_assign_ops(self):
# Make sure the assignment ops have not already been added to the list
diff --git a/tensorflow/contrib/model_pruning/python/pruning_test.py b/tensorflow/contrib/model_pruning/python/pruning_test.py
index 1767b4bb94..89e6571319 100644
--- a/tensorflow/contrib/model_pruning/python/pruning_test.py
+++ b/tensorflow/contrib/model_pruning/python/pruning_test.py
@@ -142,6 +142,23 @@ class PruningTest(test.TestCase):
self._blockMasking(param_list + ["block_pooling_function=MAX"], weights_max,
expected_mask)
+ self._blockMasking(param_list + ["block_pooling_function=AVG"], weights_avg,
+ expected_mask)
+
+ def testBlockMaskingWithHigherDimensions(self):
+ param_list = ["block_height=2", "block_width=2", "threshold_decay=0"]
+
+ # Weights as in testBlockMasking, but with one extra dimension.
+ weights_avg = constant_op.constant(
+ [[[0.1, 0.1, 0.2, 0.2], [0.1, 0.1, 0.2, 0.2], [0.3, 0.3, 0.4, 0.4],
+ [0.3, 0.3, 0.4, 0.4]]])
+ weights_max = constant_op.constant(
+ [[[0.1, 0.0, 0.2, 0.0], [0.0, -0.1, 0.0, -0.2], [0.3, 0.0, 0.4, 0.0],
+ [0.0, -0.3, 0.0, -0.4]]])
+ expected_mask = [[[0, 0, 0, 0], [0, 0, 0, 0], [1, 1, 1, 1], [1, 1, 1, 1]]]
+
+ self._blockMasking(param_list + ["block_pooling_function=MAX"], weights_max,
+ expected_mask)
self._blockMasking(param_list + ["block_pooling_function=AVG"],
weights_avg, expected_mask)