aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/model_pruning
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-12-14 18:42:59 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-14 18:48:24 -0800
commitaadc84cce45cccce0c6967cbb50793276bcf4874 (patch)
tree85fc7db1dddb5f35bdc214560e843b424474d9d6 /tensorflow/contrib/model_pruning
parent147f4acd4b4f7b1c81d780adce698b2056837796 (diff)
Add block sparsity support for 2D weight tensors only.
PiperOrigin-RevId: 179130257
Diffstat (limited to 'tensorflow/contrib/model_pruning')
-rw-r--r--tensorflow/contrib/model_pruning/README.md14
-rw-r--r--tensorflow/contrib/model_pruning/python/pruning.py100
-rw-r--r--tensorflow/contrib/model_pruning/python/pruning_test.py34
3 files changed, 143 insertions, 5 deletions
diff --git a/tensorflow/contrib/model_pruning/README.md b/tensorflow/contrib/model_pruning/README.md
index 764e126e0d..d286750c25 100644
--- a/tensorflow/contrib/model_pruning/README.md
+++ b/tensorflow/contrib/model_pruning/README.md
@@ -42,10 +42,13 @@ The pruning library allows for specification of the following hyper parameters:
| name | string | model_pruning | Name of the pruning specification. Used for adding summaries and ops under a common tensorflow name_scope |
| begin_pruning_step | integer | 0 | The global step at which to begin pruning |
| end_pruning_step | integer | -1 | The global step at which to terminate pruning. Defaults to -1 implying that pruning continues till the training stops |
-| do_not_prune | list of strings | [""] | list of layers strings that are not pruned |
+| do_not_prune | list of strings | [""] | list of layer names that are not pruned |
| threshold_decay | float | 0.9 | The decay factor to use for exponential decay of the thresholds |
| pruning_frequency | integer | 10 | How often should the masks be updated? (in # of global_steps) |
| nbins | integer | 255 | Number of bins to use for histogram computation |
+| block_height|integer | 1 | Number of rows in a block for block sparse matrices|
+| block_width |integer | 1 | Number of cols in a block for block sparse matrices|
+| block_pooling_function| string | AVG | The function to use to pool weight values in a block: average (AVG) or max (MAX)|
| initial_sparsity | float | 0.0 | Initial sparsity value |
| target_sparsity | float | 0.5 | Target sparsity value |
| sparsity_function_begin_step | integer | 0 | The global step at this which the gradual sparsity function begins to take effect |
@@ -128,3 +131,12 @@ Eval:
```shell
$ bazel-bin/$examples_dir/cifar10/cifar10_eval --run_once
```
+
+### Block Sparsity
+
+For some hardware architectures, it may be beneficial to induce spatially correlated sparsity. To train models in which the weight tensors have block sparse structure, set the *block_height* and *block_width* hyperparameters to the desired block configuration (2x2, 4x4, 4x1, 1x8, etc). Currently, block sparsity is supported for weight tensors with rank 2 only. The matrix is partitioned into non-overlapping blocks of size *[block_height, block_width]*, and either the average or the max absolute value in each block is taken as a proxy for the entire block (selected by the *block_pooling_function* hyperparameter).
+The convolution layer tensors are always pruned using block dimensions of [1,1].
+
+## References
+
+Michael Zhu and Suyog Gupta, “To prune, or not to prune: exploring the efficacy of pruning for model compression”, *2017 NIPS Workshop on Machine Learning on the Phone and other Consumer Devices* (https://arxiv.org/pdf/1710.01878.pdf)
diff --git a/tensorflow/contrib/model_pruning/python/pruning.py b/tensorflow/contrib/model_pruning/python/pruning.py
index 39eb79daf0..d16af9da19 100644
--- a/tensorflow/contrib/model_pruning/python/pruning.py
+++ b/tensorflow/contrib/model_pruning/python/pruning.py
@@ -72,6 +72,7 @@ from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_impl
+from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
@@ -130,6 +131,23 @@ def _weight_threshold_variable(var, scope):
return threshold
+def _kronecker_product(mat1, mat2):
+ """Computes the Kronecker product of two matrices mat1 and mat2.
+
+ Args:
+ mat1: A matrix of size m x n
+ mat2: A matrix of size p x q
+ Returns:
+ Kronecker product of matrices mat1 and mat2 of size mp x nq
+ """
+
+ m1, n1 = mat1.get_shape().as_list()
+ mat1_rsh = array_ops.reshape(mat1, [m1, 1, n1, 1])
+ m2, n2 = mat2.get_shape().as_list()
+ mat2_rsh = array_ops.reshape(mat2, [1, m2, 1, n2])
+ return array_ops.reshape(mat1_rsh * mat2_rsh, [m1 * m2, n1 * n2])
+
+
def _histogram(values, value_range, nbins=100, dtype=np.int32, name=None):
"""Return histogram of values.
@@ -298,6 +316,13 @@ def get_pruning_hparams():
How often should the masks be updated? (in # of global_steps)
nbins: integer
number of bins to use for histogram computation
+ block_height: integer
+ number of rows in a block (defaults to 1)
+ block_width: integer
+ number of cols in a block (defaults to 1)
+ block_pooling_function: string
+ Whether to perform average (AVG) or max (MAX) pooling in the block
+ (default: AVG)
initial_sparsity: float
initial sparsity value
target_sparsity: float
@@ -333,6 +358,9 @@ def get_pruning_hparams():
threshold_decay=0.9,
pruning_frequency=10,
nbins=255,
+ block_height=1,
+ block_width=1,
+ block_pooling_function='AVG',
initial_sparsity=0,
target_sparsity=0.5,
sparsity_function_begin_step=0,
@@ -375,6 +403,12 @@ class Pruning(object):
# were updated
self._last_update_step = self._setup_last_update_step()
+ # Block dimensions
+ self._block_dim = [self._spec.block_height, self._spec.block_width]
+
+ # Block pooling function
+ self._block_pooling_function = self._spec.block_pooling_function
+
def _setup_global_step(self, global_step):
graph_global_step = global_step
if graph_global_step is None:
@@ -449,9 +483,10 @@ class Pruning(object):
Returns:
new_threshold: The new value of the threshold based on weights, and
- desired_sparsity
- new_mask: A n-D numpy array containing 0 or 1 to indicate which of the
- values in weights falls below the threshold
+ sparsity at the current global_step
+ new_mask: A numpy array of the same size and shape as weights containing
+ 0 or 1 to indicate which of the values in weights falls below
+ the threshold
Raises:
ValueError: if sparsity is not defined
@@ -484,6 +519,63 @@ class Pruning(object):
math_ops.greater(abs_weights, smoothed_threshold), np.float32)
return smoothed_threshold, new_mask
+ def _maybe_update_block_mask(self, weights, threshold):
+ """Performs block-granular masking of the weights.
+
+ Block pruning occurs only if the block_height or block_width is > 1 and
+ if the weight tensor has ndims = 2. Otherwise, elementwise pruning occurs.
+ Args:
+ weights: The weight tensor that needs to be masked.
+ threshold: The current threshold value. The function will compute a new
+ threshold and return the exponential moving average using the current
+ value of threshold
+
+ Returns:
+ new_threshold: The new value of the threshold based on weights, and
+ sparsity at the current global_step
+ new_mask: A numpy array of the same size and shape as weights containing
+ 0 or 1 to indicate which of the values in weights falls below
+ the threshold
+
+ Raises:
+ ValueError: if block pooling function is not AVG or MAX
+ """
+ if weights.get_shape().ndims != 2 or self._block_dim == [1, 1]:
+ return self._update_mask(weights, threshold)
+
+ if self._block_pooling_function not in ['AVG', 'MAX']:
+ raise ValueError('Unknown pooling function for block sparsity: %s' %
+ self._block_pooling_function)
+
+ with ops.name_scope(weights.op.name + '_pruning_ops'):
+ abs_weights = math_ops.abs(
+ array_ops.reshape(
+ weights, [1, weights.get_shape()[0],
+ weights.get_shape()[1], 1]))
+ pool_window = [self._block_dim[0], self._block_dim[1]]
+ pooled_weights = nn_ops.pool(
+ abs_weights,
+ window_shape=pool_window,
+ pooling_type=self._block_pooling_function,
+ strides=pool_window,
+ padding='SAME',
+ name=weights.op.name + '_pooled')
+
+ smoothed_threshold, new_mask = self._update_mask(pooled_weights,
+ threshold)
+
+ reshaped_mask = array_ops.reshape(
+ new_mask,
+ [pooled_weights.get_shape()[1],
+ pooled_weights.get_shape()[2]])
+ updated_mask = _kronecker_product(reshaped_mask,
+ array_ops.ones(self._block_dim))
+ sliced_mask = array_ops.slice(
+ updated_mask, [0, 0],
+ [weights.get_shape()[0],
+ weights.get_shape()[1]])
+ return smoothed_threshold, sliced_mask
+
def _get_mask_assign_ops(self):
# Make sure the assignment ops have not already been added to the list
if self._assign_ops:
@@ -510,7 +602,7 @@ class Pruning(object):
if self._exists_in_do_not_prune_list(mask.name):
continue
- new_threshold, new_mask = self._update_mask(weight, threshold)
+ new_threshold, new_mask = self._maybe_update_block_mask(weight, threshold)
self._assign_ops.append(_variable_assign(threshold, new_threshold))
self._assign_ops.append(
diff --git a/tensorflow/contrib/model_pruning/python/pruning_test.py b/tensorflow/contrib/model_pruning/python/pruning_test.py
index 34b4584f49..1767b4bb94 100644
--- a/tensorflow/contrib/model_pruning/python/pruning_test.py
+++ b/tensorflow/contrib/model_pruning/python/pruning_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import numpy as np
from tensorflow.contrib.model_pruning.python import pruning
+from tensorflow.python.framework import constant_op
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import random_ops
@@ -111,6 +112,39 @@ class PruningTest(test.TestCase):
masked_weights_val = masked_weights.eval()
self.assertAllEqual(np.count_nonzero(masked_weights_val), 51)
+ def _blockMasking(self, hparams, weights, expected_mask):
+
+ threshold = variables.Variable(0.0, name="threshold")
+ sparsity = variables.Variable(0.51, name="sparsity")
+ test_spec = ",".join(hparams)
+ pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)
+
+ # Set up pruning
+ p = pruning.Pruning(pruning_hparams, sparsity=sparsity)
+ with self.test_session():
+ variables.global_variables_initializer().run()
+ _, new_mask = p._maybe_update_block_mask(weights, threshold)
+ # Check if the mask is the same size as the weights
+ self.assertAllEqual(new_mask.get_shape(), weights.get_shape())
+ mask_val = new_mask.eval()
+ self.assertAllEqual(mask_val, expected_mask)
+
+ def testBlockMasking(self):
+ param_list = ["block_height=2", "block_width=2", "threshold_decay=0"]
+
+ weights_avg = constant_op.constant(
+ [[0.1, 0.1, 0.2, 0.2], [0.1, 0.1, 0.2, 0.2], [0.3, 0.3, 0.4, 0.4],
+ [0.3, 0.3, 0.4, 0.4]])
+ weights_max = constant_op.constant(
+ [[0.1, 0.0, 0.2, 0.0], [0.0, -0.1, 0.0, -0.2], [0.3, 0.0, 0.4, 0.0],
+ [0.0, -0.3, 0.0, -0.4]])
+ expected_mask = [[0, 0, 0, 0], [0, 0, 0, 0], [1, 1, 1, 1], [1, 1, 1, 1]]
+
+ self._blockMasking(param_list + ["block_pooling_function=MAX"], weights_max,
+ expected_mask)
+ self._blockMasking(param_list + ["block_pooling_function=AVG"],
+ weights_avg, expected_mask)
+
def testPartitionedVariableMasking(self):
partitioner = partitioned_variables.variable_axis_size_partitioner(40)
with self.test_session() as session: