diff options
author | Suyog Gupta <suyoggupta@google.com> | 2018-08-08 14:30:30 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-08 14:34:58 -0700 |
commit | 2b4fd1c2b7a37367e61bbae3d27d194a894cb7bb (patch) | |
tree | 17c3c9b3ee1819727e0d37016f9f5d1e6f4f4780 /tensorflow/contrib/model_pruning | |
parent | 3d1661826ec668d717122a88463ab9b1c1e6f7ae (diff) |
Add helper function for validating the user-provided pruning hparams
PiperOrigin-RevId: 207946581
Diffstat (limited to 'tensorflow/contrib/model_pruning')
-rw-r--r-- | tensorflow/contrib/model_pruning/README.md | 2 | ||||
-rw-r--r-- | tensorflow/contrib/model_pruning/python/pruning.py | 36 |
2 files changed, 32 insertions, 6 deletions
diff --git a/tensorflow/contrib/model_pruning/README.md b/tensorflow/contrib/model_pruning/README.md index 0761dea900..a5267fd904 100644 --- a/tensorflow/contrib/model_pruning/README.md +++ b/tensorflow/contrib/model_pruning/README.md @@ -5,7 +5,7 @@ neural network's weight tensors. The API helps inject necessary tensorflow op into the training graph so the model can be pruned while it is being trained. ## Table of contents -1. [Model creation](# model-creation) +1. [Model creation](#model-creation) 2. [Hyperparameters for pruning](#hyperparameters) - [Block sparsity](#block-sparsity) 3. [Adding pruning ops to the training graph](#adding-pruning-ops) diff --git a/tensorflow/contrib/model_pruning/python/pruning.py b/tensorflow/contrib/model_pruning/python/pruning.py index 723dab9369..cd58526ed3 100644 --- a/tensorflow/contrib/model_pruning/python/pruning.py +++ b/tensorflow/contrib/model_pruning/python/pruning.py @@ -237,6 +237,9 @@ class Pruning(object): # Pruning specification self._spec = spec if spec else get_pruning_hparams() + # Sanity check for pruning hparams + self._validate_spec() + # A tensorflow variable that tracks the sparsity function. # If not provided as input, the graph must already contain the global_step # variable before calling this constructor. @@ -262,6 +265,34 @@ class Pruning(object): # Mapping of weight names and target sparsity self._weight_sparsity_map = self._get_weight_sparsity_map() + def _validate_spec(self): + spec = self._spec + if spec.begin_pruning_step < 0: + raise ValueError('Illegal value for begin_pruning_step') + + if spec.begin_pruning_step >= spec.end_pruning_step: + if spec.end_pruning_step != -1: + raise ValueError( + 'Pruning must begin before it can end. begin_step=%d, end_step=%d.' 
+ ' Set end_pruning_step to -1 if pruning is required till training ' + 'stops' % (spec.begin_pruning_step, spec.end_pruning_step)) + + if spec.sparsity_function_begin_step < 0: + raise ValueError('Illegal value for sparsity_function_begin_step') + + if spec.sparsity_function_begin_step >= spec.sparsity_function_end_step: + raise ValueError( + 'Sparsity function requires begin_step < end_step') + + if not 0.0 <= spec.threshold_decay < 1.0: + raise ValueError('threshold_decay must be in range [0,1)') + + if not 0.0 <= spec.initial_sparsity < 1.0: + raise ValueError('initial_sparsity must be in range [0,1)') + + if not 0.0 <= spec.target_sparsity < 1.0: + raise ValueError('target_sparsity must be in range [0,1)') + def _setup_global_step(self, global_step): graph_global_step = global_step if graph_global_step is None: @@ -276,11 +307,6 @@ class Pruning(object): target_sparsity = self._spec.target_sparsity exponent = self._spec.sparsity_function_exponent - if begin_step >= end_step: - raise ValueError( - 'Pruning must begin before it can end. begin_step=%d, end_step=%d' % - (begin_step, end_step)) - with ops.name_scope(self._spec.name): p = math_ops.minimum( 1.0, |