about summary refs log tree commit diff homepage
path: root/tensorflow/contrib/model_pruning
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-07-17 22:29:31 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-17 22:32:34 -0700
commitaa15692e54390cf3967d51bc60acf5f783df9c08 (patch)
treefaadaa4781271cbbc927b1bb2a483624e17321cd /tensorflow/contrib/model_pruning
parent81161f9d9987a8eb70793d95048c20be34292859 (diff)
Update documentation for using pruning and contrib/slim training utility
PiperOrigin-RevId: 205027982
Diffstat (limited to 'tensorflow/contrib/model_pruning')
-rw-r--r--tensorflow/contrib/model_pruning/python/learning.py11
1 file changed, 7 insertions(+), 4 deletions(-)
diff --git a/tensorflow/contrib/model_pruning/python/learning.py b/tensorflow/contrib/model_pruning/python/learning.py
index 2b79c23cef..26695237c2 100644
--- a/tensorflow/contrib/model_pruning/python/learning.py
+++ b/tensorflow/contrib/model_pruning/python/learning.py
@@ -33,11 +33,14 @@ to support training of pruned models
# Create the train_op
train_op = slim.learning.create_train_op(total_loss, optimizer)
- # Set up sparsity
- sparsity = pruning.setup_gradual_sparsity(self.global_step)
+ # Parse pruning hyperparameters
+ pruning_hparams = pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams)
- # Create mask update op
- mask_update_op = pruning.add_mask_update_ip(sparsity)
+ # Create a pruning object using the pruning_hparams
+ p = pruning.Pruning(pruning_hparams)
+
+ # Add mask update ops to the graph
+ mask_update_op = p.conditional_mask_update_op()
# Run training.
learning.train(train_op,