diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-12-14 18:42:59 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-12-14 18:48:24 -0800 |
commit | aadc84cce45cccce0c6967cbb50793276bcf4874 (patch) | |
tree | 85fc7db1dddb5f35bdc214560e843b424474d9d6 /tensorflow/contrib/model_pruning | |
parent | 147f4acd4b4f7b1c81d780adce698b2056837796 (diff) |
Add block sparsity support for 2D weight tensors only.
PiperOrigin-RevId: 179130257
Diffstat (limited to 'tensorflow/contrib/model_pruning')
-rw-r--r-- | tensorflow/contrib/model_pruning/README.md | 14 | ||||
-rw-r--r-- | tensorflow/contrib/model_pruning/python/pruning.py | 100 | ||||
-rw-r--r-- | tensorflow/contrib/model_pruning/python/pruning_test.py | 34 |
3 files changed, 143 insertions, 5 deletions
diff --git a/tensorflow/contrib/model_pruning/README.md b/tensorflow/contrib/model_pruning/README.md index 764e126e0d..d286750c25 100644 --- a/tensorflow/contrib/model_pruning/README.md +++ b/tensorflow/contrib/model_pruning/README.md @@ -42,10 +42,13 @@ The pruning library allows for specification of the following hyper parameters: | name | string | model_pruning | Name of the pruning specification. Used for adding summaries and ops under a common tensorflow name_scope | | begin_pruning_step | integer | 0 | The global step at which to begin pruning | | end_pruning_step | integer | -1 | The global step at which to terminate pruning. Defaults to -1 implying that pruning continues till the training stops | -| do_not_prune | list of strings | [""] | list of layers strings that are not pruned | +| do_not_prune | list of strings | [""] | list of layers names that are not pruned | | threshold_decay | float | 0.9 | The decay factor to use for exponential decay of the thresholds | | pruning_frequency | integer | 10 | How often should the masks be updated? (in # of global_steps) | | nbins | integer | 255 | Number of bins to use for histogram computation | +| block_height|integer | 1 | Number of rows in a block for block sparse matrices| +| block_width |integer | 1 | Number of cols in a block for block sparse matrices| +| block_pooling_function| string | AVG | The function to use to pool weight values in a block: average (AVG) or max (MAX)| | initial_sparsity | float | 0.0 | Initial sparsity value | | target_sparsity | float | 0.5 | Target sparsity value | | sparsity_function_begin_step | integer | 0 | The global step at this which the gradual sparsity function begins to take effect | @@ -128,3 +131,12 @@ Eval: ```shell $ bazel-bin/$examples_dir/cifar10/cifar10_eval --run_once ``` + +### Block Sparsity + +For some hardware architectures, it may be beneficial to induce spatially correlated sparsity. 
To train models in which the weight tensors have block sparse structure, set *block_height* and *block_width* hyperparameters to the desired block configuration (2x2, 4x4, 4x1, 1x8, etc). Currently, block sparsity is supported for weight tensors with rank 2 only. The matrix is partitioned into non-overlapping blocks of size *[block_height, block_width]* and either the average or max absolute value in this block is taken as a proxy for the entire block (set by *block_pooling_function* hyperparameter). +The convolution layer tensors are always pruned using block dimensions of [1,1]. + +## References + +Michael Zhu and Suyog Gupta, “To prune, or not to prune: exploring the efficacy of pruning for model compression”, *2017 NIPS Workshop on Machine Learning on the Phone and other Consumer Devices* (https://arxiv.org/pdf/1710.01878.pdf) diff --git a/tensorflow/contrib/model_pruning/python/pruning.py b/tensorflow/contrib/model_pruning/python/pruning.py index 39eb79daf0..d16af9da19 100644 --- a/tensorflow/contrib/model_pruning/python/pruning.py +++ b/tensorflow/contrib/model_pruning/python/pruning.py @@ -72,6 +72,7 @@ from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_impl +from tensorflow.python.ops import nn_ops from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables @@ -130,6 +131,23 @@ def _weight_threshold_variable(var, scope): return threshold +def _kronecker_product(mat1, mat2): + """Computes the Kronecker product of two matrices mat1 and mat2. 
+ + Args: + mat1: A matrix of size m x n + mat2: A matrix of size p x q + Returns: + Kronecker product of matrices mat1 and mat2 of size mp x nq + """ + + m1, n1 = mat1.get_shape().as_list() + mat1_rsh = array_ops.reshape(mat1, [m1, 1, n1, 1]) + m2, n2 = mat2.get_shape().as_list() + mat2_rsh = array_ops.reshape(mat2, [1, m2, 1, n2]) + return array_ops.reshape(mat1_rsh * mat2_rsh, [m1 * m2, n1 * n2]) + + def _histogram(values, value_range, nbins=100, dtype=np.int32, name=None): """Return histogram of values. @@ -298,6 +316,13 @@ def get_pruning_hparams(): How often should the masks be updated? (in # of global_steps) nbins: integer number of bins to use for histogram computation + block_height: integer + number of rows in a block (defaults to 1) + block_width: integer + number of cols in a block (defaults to 1) + block_pooling_function: string + Whether to perform average (AVG) or max (MAX) pooling in the block + (default: AVG) initial_sparsity: float initial sparsity value target_sparsity: float @@ -333,6 +358,9 @@ def get_pruning_hparams(): threshold_decay=0.9, pruning_frequency=10, nbins=255, + block_height=1, + block_width=1, + block_pooling_function='AVG', initial_sparsity=0, target_sparsity=0.5, sparsity_function_begin_step=0, @@ -375,6 +403,12 @@ class Pruning(object): # were updated self._last_update_step = self._setup_last_update_step() + # Block dimensions + self._block_dim = [self._spec.block_height, self._spec.block_width] + + # Block pooling function + self._block_pooling_function = self._spec.block_pooling_function + def _setup_global_step(self, global_step): graph_global_step = global_step if graph_global_step is None: @@ -449,9 +483,10 @@ class Pruning(object): Returns: new_threshold: The new value of the threshold based on weights, and - desired_sparsity - new_mask: A n-D numpy array containing 0 or 1 to indicate which of the - values in weights falls below the threshold + sparsity at the current global_step + new_mask: A numpy array of the same size 
and shape as weights containing + 0 or 1 to indicate which of the values in weights falls below + the threshold Raises: ValueError: if sparsity is not defined @@ -484,6 +519,63 @@ class Pruning(object): math_ops.greater(abs_weights, smoothed_threshold), np.float32) return smoothed_threshold, new_mask + def _maybe_update_block_mask(self, weights, threshold): + """Performs block-granular masking of the weights. + + Block pruning occurs only if the block_height or block_width is > 1 and + if the weight tensor has ndims = 2. Otherwise, elementwise pruning occurs. + Args: + weights: The weight tensor that needs to be masked. + threshold: The current threshold value. The function will compute a new + threshold and return the exponential moving average using the current + value of threshold + + Returns: + new_threshold: The new value of the threshold based on weights, and + sparsity at the current global_step + new_mask: A numpy array of the same size and shape as weights containing + 0 or 1 to indicate which of the values in weights falls below + the threshold + + Raises: + ValueError: if block pooling function is not AVG or MAX + """ + if weights.get_shape().ndims != 2 or self._block_dim == [1, 1]: + return self._update_mask(weights, threshold) + + if self._block_pooling_function not in ['AVG', 'MAX']: + raise ValueError('Unknown pooling function for block sparsity: %s' % + self._block_pooling_function) + + with ops.name_scope(weights.op.name + '_pruning_ops'): + abs_weights = math_ops.abs( + array_ops.reshape( + weights, [1, weights.get_shape()[0], + weights.get_shape()[1], 1])) + pool_window = [self._block_dim[0], self._block_dim[1]] + pooled_weights = nn_ops.pool( + abs_weights, + window_shape=pool_window, + pooling_type=self._block_pooling_function, + strides=pool_window, + padding='SAME', + name=weights.op.name + '_pooled') + + smoothed_threshold, new_mask = self._update_mask(pooled_weights, + threshold) + + reshaped_mask = array_ops.reshape( + new_mask, + 
[pooled_weights.get_shape()[1], + pooled_weights.get_shape()[2]]) + updated_mask = _kronecker_product(reshaped_mask, + array_ops.ones(self._block_dim)) + sliced_mask = array_ops.slice( + updated_mask, [0, 0], + [weights.get_shape()[0], + weights.get_shape()[1]]) + return smoothed_threshold, sliced_mask + def _get_mask_assign_ops(self): # Make sure the assignment ops have not already been added to the list if self._assign_ops: @@ -510,7 +602,7 @@ class Pruning(object): if self._exists_in_do_not_prune_list(mask.name): continue - new_threshold, new_mask = self._update_mask(weight, threshold) + new_threshold, new_mask = self._maybe_update_block_mask(weight, threshold) self._assign_ops.append(_variable_assign(threshold, new_threshold)) self._assign_ops.append( diff --git a/tensorflow/contrib/model_pruning/python/pruning_test.py b/tensorflow/contrib/model_pruning/python/pruning_test.py index 34b4584f49..1767b4bb94 100644 --- a/tensorflow/contrib/model_pruning/python/pruning_test.py +++ b/tensorflow/contrib/model_pruning/python/pruning_test.py @@ -21,6 +21,7 @@ from __future__ import print_function import numpy as np from tensorflow.contrib.model_pruning.python import pruning +from tensorflow.python.framework import constant_op from tensorflow.python.ops import math_ops from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import random_ops @@ -111,6 +112,39 @@ class PruningTest(test.TestCase): masked_weights_val = masked_weights.eval() self.assertAllEqual(np.count_nonzero(masked_weights_val), 51) + def _blockMasking(self, hparams, weights, expected_mask): + + threshold = variables.Variable(0.0, name="threshold") + sparsity = variables.Variable(0.51, name="sparsity") + test_spec = ",".join(hparams) + pruning_hparams = pruning.get_pruning_hparams().parse(test_spec) + + # Set up pruning + p = pruning.Pruning(pruning_hparams, sparsity=sparsity) + with self.test_session(): + variables.global_variables_initializer().run() + _, new_mask = 
p._maybe_update_block_mask(weights, threshold) + # Check if the mask is the same size as the weights + self.assertAllEqual(new_mask.get_shape(), weights.get_shape()) + mask_val = new_mask.eval() + self.assertAllEqual(mask_val, expected_mask) + + def testBlockMasking(self): + param_list = ["block_height=2", "block_width=2", "threshold_decay=0"] + + weights_avg = constant_op.constant( + [[0.1, 0.1, 0.2, 0.2], [0.1, 0.1, 0.2, 0.2], [0.3, 0.3, 0.4, 0.4], + [0.3, 0.3, 0.4, 0.4]]) + weights_max = constant_op.constant( + [[0.1, 0.0, 0.2, 0.0], [0.0, -0.1, 0.0, -0.2], [0.3, 0.0, 0.4, 0.0], + [0.0, -0.3, 0.0, -0.4]]) + expected_mask = [[0, 0, 0, 0], [0, 0, 0, 0], [1, 1, 1, 1], [1, 1, 1, 1]] + + self._blockMasking(param_list + ["block_pooling_function=MAX"], weights_max, + expected_mask) + self._blockMasking(param_list + ["block_pooling_function=AVG"], + weights_avg, expected_mask) + def testPartitionedVariableMasking(self): partitioner = partitioned_variables.variable_axis_size_partitioner(40) with self.test_session() as session: |