path: root/tensorflow/contrib/model_pruning/python/learning.py
author    Avijit <Avijit.Chakraborty@intel.com>    2018-07-25 01:08:01 -0700
committer Avijit <Avijit.Chakraborty@intel.com>    2018-07-25 01:08:01 -0700
commit 1cdacb8b10d0b4687387be5fd8be978d68602a1d (patch)
tree   a2bf88798854a426f073325eb85d85b3ab914418 /tensorflow/contrib/model_pruning/python/learning.py
parent f88a6f93bee89c610fa8b399d037c7a33c1a0a3e (diff)
parent 3f454e4060d855f43eebe0cdc27d8c24f906d430 (diff)
Merge remote-tracking branch 'upstream/master'
Diffstat (limited to 'tensorflow/contrib/model_pruning/python/learning.py')
-rw-r--r--  tensorflow/contrib/model_pruning/python/learning.py | 11 +++++++----
1 file changed, 7 insertions(+), 4 deletions(-)
diff --git a/tensorflow/contrib/model_pruning/python/learning.py b/tensorflow/contrib/model_pruning/python/learning.py
index 2b79c23cef..26695237c2 100644
--- a/tensorflow/contrib/model_pruning/python/learning.py
+++ b/tensorflow/contrib/model_pruning/python/learning.py
@@ -33,11 +33,14 @@ to support training of pruned models
   # Create the train_op
   train_op = slim.learning.create_train_op(total_loss, optimizer)
 
-  # Set up sparsity
-  sparsity = pruning.setup_gradual_sparsity(self.global_step)
+  # Parse pruning hyperparameters
+  pruning_hparams = pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams)
 
-  # Create mask update op
-  mask_update_op = pruning.add_mask_update_ip(sparsity)
+  # Create a pruning object using the pruning_hparams
+  p = pruning.Pruning(pruning_hparams)
+
+  # Add mask update ops to the graph
+  mask_update_op = p.conditional_mask_update_op()
 
   # Run training.
   learning.train(train_op,
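
For orientation, below is a minimal end-to-end sketch of the API that the
revised docstring describes, assembled from the calls shown in the diff. It
assumes TensorFlow 1.x with tf.contrib available; the toy regression model,
the hparams string, the step count, and the '/tmp/prune_demo' log directory
are illustrative placeholders, not part of this commit.

    # Minimal usage sketch (assumptions: TF 1.x; the model, hparams
    # string, and paths below are placeholders).
    import tensorflow as tf
    from tensorflow.contrib import model_pruning
    from tensorflow.contrib.model_pruning.python import learning
    from tensorflow.contrib.model_pruning.python import pruning

    slim = tf.contrib.slim

    # A toy regression problem on random data so the graph is
    # self-contained.
    x = tf.random_normal([32, 4])
    labels = tf.random_normal([32, 1])

    # masked_fully_connected attaches a binary mask to the layer's
    # weights, which is what the pruning ops update during training.
    net = model_pruning.masked_fully_connected(x, 1)
    total_loss = tf.losses.mean_squared_error(labels, net)
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01)

    # Create the train_op
    train_op = slim.learning.create_train_op(total_loss, optimizer)

    # Parse pruning hyperparameters (illustrative example string)
    pruning_hparams = pruning.get_pruning_hparams().parse(
        'name=demo,target_sparsity=0.5')

    # Create a pruning object using the pruning_hparams
    p = pruning.Pruning(pruning_hparams,
                        global_step=tf.train.get_or_create_global_step())

    # Add mask update ops to the graph
    mask_update_op = p.conditional_mask_update_op()

    # learning.train wraps slim.learning.train and runs mask_update_op
    # alongside the train_op; number_of_steps is forwarded to slim.
    learning.train(train_op, '/tmp/prune_demo', mask_update_op,
                   number_of_steps=100)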