aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/model_pruning/README.md
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/model_pruning/README.md')
-rw-r--r--tensorflow/contrib/model_pruning/README.md46
1 files changed, 38 insertions, 8 deletions
diff --git a/tensorflow/contrib/model_pruning/README.md b/tensorflow/contrib/model_pruning/README.md
index dbe4e124fd..a5267fd904 100644
--- a/tensorflow/contrib/model_pruning/README.md
+++ b/tensorflow/contrib/model_pruning/README.md
@@ -4,7 +4,15 @@ This document describes the API that facilitates magnitude-based pruning of
neural network's weight tensors. The API helps inject necessary tensorflow op
into the training graph so the model can be pruned while it is being trained.
-### Model creation
+## Table of contents
+1. [Model creation](#model-creation)
+2. [Hyperparameters for pruning](#hyperparameters)
+ - [Block sparsity](#block-sparsity)
+3. [Adding pruning ops to the training graph](#adding-pruning-ops)
+4. [Removing pruning ops from trained model](#remove)
+5. [Example](#example)
+
+### Model creation <a name="model-creation"></a>
The first step involves adding mask and threshold variables to the layers that
need to undergo pruning. The variable mask is the same shape as the layer's
@@ -33,7 +41,7 @@ auxiliary variables built-in (see
* [rnn_cells.MaskedLSTMCell](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/model_pruning/python/layers/rnn_cells.py?l=154)
-### Adding pruning ops to the training graph
+### Pruning-related hyperparameters <a name="hyperparameters"></a>
The pruning library allows for specification of the following hyper parameters:
@@ -64,7 +72,13 @@ is divided into $$n$$ intervals of size equal to the pruning_frequency ($$\Delta
t$$). $$s_f$$ is the target_sparsity, $$s_i$$ is the initial_sparsity, $$t_0$$
is the sparsity_function_begin_step. In this equation, the
sparsity_function_exponent is set to 3.
-### Adding pruning ops to the training graph
+
+#### Block Sparsity <a name="block-sparsity"></a>
+
+For some hardware architectures, it may be beneficial to induce spatially correlated sparsity. To train models in which the weight tensors have block sparse structure, set *block_height* and *block_width* hyperparameters to the desired block configuration (2x2, 4x4, 4x1, 1x8, etc). Currently, block sparsity is only supported for weight tensors which can be squeezed to rank 2. The matrix is partitioned into non-overlapping blocks of size *[block_height, block_dim]* and the either the average or max absolute value in this block is taken as a proxy for the entire block (set by *block_pooling_function* hyperparameter).
+The convolution layer tensors are always pruned used block dimensions of [1,1].
+
+### Adding pruning ops to the training graph <a name="adding-pruning-ops"></a>
The final step involves adding ops to the training graph that monitor the
distribution of the layer's weight magnitudes and determine the layer threshold,
@@ -105,7 +119,19 @@ with tf.graph.as_default():
```
Ensure that `global_step` is being [incremented](https://www.tensorflow.org/api_docs/python/tf/train/Optimizer#minimize), otherwise pruning will not work!
-## Example: Pruning and training deep CNNs on the cifar10 dataset
+### Removing pruning ops from the trained graph <a name="remove"></a>
+Once the model is trained, it is necessary to remove the auxiliary variables (mask, threshold) and pruning ops added to the graph in the steps above. This can be accomplished using the `strip_pruning_vars` utility.
+
+This utility generates a binary GraphDef in which the variables have been converted to constants. In particular, the threshold variables are removed from the graph and the mask variable is fused with the corresponding weight tensor to produce a `masked_weight` tensor. This tensor is sparse, has the same size as the weight tensor, and the sparsity is as set by the `target_sparsity` or the `weight_sparsity_map` hyperparameters above.
+
+```shell
+$ bazel build -c opt contrib/model_pruning:strip_pruning_vars
+$ bazel-bin/contrib/model_pruning/strip_pruning_vars --checkpoint_dir=/path/to/checkpoints/ --output_node_names=graph_node1,graph_node2 --output_dir=/tmp --filename=pruning_stripped.pb
+```
+
+For now, it is assumed that the underlying hardware platform will provide mechanisms for compressing the sparse tensors and/or accelerating the sparse tensor computations.
+
+## Example: Pruning and training deep CNNs on the cifar10 dataset <a name="example"></a>
Please see https://www.tensorflow.org/tutorials/deep_cnn for details on neural
network architecture, setting up inputs etc. The additional changes needed to
@@ -121,7 +147,7 @@ incorporate pruning are captured in the following:
To train the pruned version of cifar10:
-```bash
+```shell
$ examples_dir=contrib/model_pruning/examples
$ bazel build -c opt $examples_dir/cifar10:cifar10_{train,eval}
$ bazel-bin/$examples_dir/cifar10/cifar10_train --pruning_hparams=name=cifar10_pruning,begin_pruning_step=10000,end_pruning_step=100000,target_sparsity=0.9,sparsity_function_begin_step=10000,sparsity_function_end_step=100000
@@ -133,10 +159,14 @@ Eval:
$ bazel-bin/$examples_dir/cifar10/cifar10_eval --run_once
```
-### Block Sparsity
+Removing pruning nodes from the trained graph:
-For some hardware architectures, it may be beneficial to induce spatially correlated sparsity. To train models in which the weight tensors have block sparse structure, set *block_height* and *block_width* hyperparameters to the desired block configuration (2x2, 4x4, 4x1, 1x8, etc). Currently, block sparsity is only supported for weight tensors which can be squeezed to rank 2. The matrix is partitioned into non-overlapping blocks of size *[block_height, block_dim]* and the either the average or max absolute value in this block is taken as a proxy for the entire block (set by *block_pooling_function* hyperparameter).
-The convolution layer tensors are always pruned used block dimensions of [1,1].
+```shell
+$ bazel build -c opt contrib/model_pruning:strip_pruning_vars
+$ bazel-bin/contrib/model_pruning/strip_pruning_vars --checkpoint_path=/tmp/cifar10_train --output_node_names=softmax_linear/softmax_linear_2 --filename=cifar_pruned.pb
+```
+
+The generated GraphDef (cifar_pruned.pb) may be visualized using the [`import_pb_to_tensorboard`](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/python/tools/import_pb_to_tensorboard.py) utility
## References