diff options
Diffstat (limited to 'tensorflow/contrib/model_pruning/python/learning.py')
-rw-r--r-- | tensorflow/contrib/model_pruning/python/learning.py | 11 |
1 files changed, 7 insertions, 4 deletions
diff --git a/tensorflow/contrib/model_pruning/python/learning.py b/tensorflow/contrib/model_pruning/python/learning.py index 2b79c23cef..26695237c2 100644 --- a/tensorflow/contrib/model_pruning/python/learning.py +++ b/tensorflow/contrib/model_pruning/python/learning.py @@ -33,11 +33,14 @@ to support training of pruned models # Create the train_op train_op = slim.learning.create_train_op(total_loss, optimizer) - # Set up sparsity - sparsity = pruning.setup_gradual_sparsity(self.global_step) + # Parse pruning hyperparameters + pruning_hparams = pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams) - # Create mask update op - mask_update_op = pruning.add_mask_update_ip(sparsity) + # Create a pruning object using the pruning_hparams + p = pruning.Pruning(pruning_hparams) + + # Add mask update ops to the graph + mask_update_op = p.conditional_mask_update_op() # Run training. learning.train(train_op, |