author    TensorFlower Gardener <gardener@tensorflow.org>  2018-07-19 11:01:42 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-07-19 11:01:42 -0700
commit 9fa89160a422433b295e22edda3776bcbc309e45 (patch)
tree   e9d74fcc08140ce5bc0ff784467d12b9f9dda62b /tensorflow/contrib/model_pruning
parent 1044888430b34353f54266bf0674144dfe675687 (diff)
parent 11cd70438e7d7104904bf8f3b24fcaf6fd88eab5 (diff)
Merge pull request #19779 from Huizerd:master
PiperOrigin-RevId: 205266716
Diffstat (limited to 'tensorflow/contrib/model_pruning')
-rw-r--r--  tensorflow/contrib/model_pruning/README.md        | 11
-rw-r--r--  tensorflow/contrib/model_pruning/python/pruning.py |  8
2 files changed, 10 insertions(+), 9 deletions(-)
diff --git a/tensorflow/contrib/model_pruning/README.md b/tensorflow/contrib/model_pruning/README.md
index 86f4fd6adf..9143d082bf 100644
--- a/tensorflow/contrib/model_pruning/README.md
+++ b/tensorflow/contrib/model_pruning/README.md
@@ -66,10 +66,10 @@ is the sparsity_function_begin_step. In this equation, the
sparsity_function_exponent is set to 3.
### Adding pruning ops to the training graph
-The final step involves adding ops to the training graph that monitors the
-distribution of the layer's weight magnitudes and determines the layer threshold
-such masking all the weights below this threshold achieves the sparsity level
-desired for the current training step. This can be achieved as follows:
+The final step involves adding ops to the training graph that monitor the
+distribution of the layer's weight magnitudes and determine the layer threshold,
+such that masking all the weights below this threshold achieves the sparsity
+level desired for the current training step. This can be achieved as follows:
```python
tf.app.flags.DEFINE_string(
@@ -79,7 +79,7 @@ tf.app.flags.DEFINE_string(
with tf.graph.as_default():
# Create global step variable
- global_step = tf.train.get_global_step()
+ global_step = tf.train.get_or_create_global_step()
# Parse pruning hyperparameters
pruning_hparams = pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams)
@@ -103,6 +103,7 @@ with tf.graph.as_default():
mon_sess.run(mask_update_op)
```
+Ensure that `global_step` is being [incremented](https://www.tensorflow.org/api_docs/python/tf/train/Optimizer#minimize), otherwise pruning will not work!
## Example: Pruning and training deep CNNs on the cifar10 dataset
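The README note added above matters because the pruning schedule is driven entirely by `global_step`: if it never advances, the mask update ops never fire. A minimal sketch of the usual wiring (TF 1.x API; the variable and `loss` below are hypothetical stand-ins for the real model):

```python
import tensorflow as tf

# Minimal sketch (TF 1.x API). The variable and loss are stand-ins
# for the real model; only the global_step wiring matters here.
w = tf.Variable(1.0)
loss = tf.square(w)  # hypothetical loss

global_step = tf.train.get_or_create_global_step()
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
# Passing global_step makes minimize() increment it on every run of
# train_op, which is what moves the pruning schedule forward.
train_op = optimizer.minimize(loss, global_step=global_step)
```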
diff --git a/tensorflow/contrib/model_pruning/python/pruning.py b/tensorflow/contrib/model_pruning/python/pruning.py
index 4b7af18b33..da9d398cbc 100644
--- a/tensorflow/contrib/model_pruning/python/pruning.py
+++ b/tensorflow/contrib/model_pruning/python/pruning.py
@@ -518,11 +518,11 @@ class Pruning(object):
summary.scalar('last_mask_update_step', self._last_update_step)
masks = get_masks()
thresholds = get_thresholds()
- for index, mask in enumerate(masks):
+ for mask, threshold in zip(masks, thresholds):
if not self._exists_in_do_not_prune_list(mask.name):
- summary.scalar(mask.name + '/sparsity', nn_impl.zero_fraction(mask))
- summary.scalar(thresholds[index].op.name + '/threshold',
- thresholds[index])
+ summary.scalar(mask.op.name + '/sparsity',
+ nn_impl.zero_fraction(mask))
+ summary.scalar(threshold.op.name + '/threshold', threshold)
def print_hparams(self):
logging.info(self._spec.to_json())
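The pruning.py change above replaces index bookkeeping with `zip`, pairing each mask directly with its threshold and consistently using `op.name` for summary tags. A standalone sketch of the same pattern, with hypothetical tensors standing in for `get_masks()`/`get_thresholds()`:

```python
import tensorflow as tf

# Hypothetical stand-ins for the masks and thresholds that
# get_masks()/get_thresholds() return inside the Pruning class.
masks = [tf.Variable([1.0, 0.0, 1.0], name='conv1/mask'),
         tf.Variable([0.0, 0.0, 1.0], name='conv2/mask')]
thresholds = [tf.Variable(0.1, name='conv1/threshold'),
              tf.Variable(0.2, name='conv2/threshold')]

# zip() keeps each mask aligned with its own threshold, with no
# manual indexing into the thresholds list.
for mask, threshold in zip(masks, thresholds):
    tf.summary.scalar(mask.op.name + '/sparsity',
                      tf.nn.zero_fraction(mask))
    tf.summary.scalar(threshold.op.name + '/threshold', threshold)
```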