aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/model_pruning
diff options
context:
space:
mode:
authorGravatar Suyog Gupta <suyoggupta@google.com>2018-08-08 14:30:30 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-08 14:34:58 -0700
commit2b4fd1c2b7a37367e61bbae3d27d194a894cb7bb (patch)
tree17c3c9b3ee1819727e0d37016f9f5d1e6f4f4780 /tensorflow/contrib/model_pruning
parent3d1661826ec668d717122a88463ab9b1c1e6f7ae (diff)
Add helper function for validating the user-provided pruning hparams
PiperOrigin-RevId: 207946581
Diffstat (limited to 'tensorflow/contrib/model_pruning')
-rw-r--r--tensorflow/contrib/model_pruning/README.md2
-rw-r--r--tensorflow/contrib/model_pruning/python/pruning.py36
2 files changed, 32 insertions, 6 deletions
diff --git a/tensorflow/contrib/model_pruning/README.md b/tensorflow/contrib/model_pruning/README.md
index 0761dea900..a5267fd904 100644
--- a/tensorflow/contrib/model_pruning/README.md
+++ b/tensorflow/contrib/model_pruning/README.md
@@ -5,7 +5,7 @@ neural network's weight tensors. The API helps inject necessary tensorflow op
into the training graph so the model can be pruned while it is being trained.
## Table of contents
-1. [Model creation](# model-creation)
+1. [Model creation](#model-creation)
2. [Hyperparameters for pruning](#hyperparameters)
- [Block sparsity](#block-sparsity)
3. [Adding pruning ops to the training graph](#adding-pruning-ops)
diff --git a/tensorflow/contrib/model_pruning/python/pruning.py b/tensorflow/contrib/model_pruning/python/pruning.py
index 723dab9369..cd58526ed3 100644
--- a/tensorflow/contrib/model_pruning/python/pruning.py
+++ b/tensorflow/contrib/model_pruning/python/pruning.py
@@ -237,6 +237,9 @@ class Pruning(object):
# Pruning specification
self._spec = spec if spec else get_pruning_hparams()
+ # Sanity check for pruning hparams
+ self._validate_spec()
+
# A tensorflow variable that tracks the sparsity function.
# If not provided as input, the graph must already contain the global_step
# variable before calling this constructor.
@@ -262,6 +265,34 @@ class Pruning(object):
# Mapping of weight names and target sparsity
self._weight_sparsity_map = self._get_weight_sparsity_map()
+ def _validate_spec(self):
+ spec = self._spec
+ if spec.begin_pruning_step < 0:
+ raise ValueError('Illegal value for begin_pruning_step')
+
+ if spec.begin_pruning_step >= spec.end_pruning_step:
+ if spec.end_pruning_step != -1:
+ raise ValueError(
+ 'Pruning must begin before it can end. begin_step=%d, end_step=%d.'
+ 'Set end_pruning_step to -1 if pruning is required till training'
+ 'stops' % (spec.begin_pruning_step, spec.end_pruning_step))
+
+ if spec.sparsity_function_begin_step < 0:
+ raise ValueError('Illegal value for sparsity_function_begin_step')
+
+ if spec.sparsity_function_begin_step >= spec.sparsity_function_end_step:
+ raise ValueError(
+ 'Sparsity function requires begin_step < end_step')
+
+ if not 0.0 <= spec.threshold_decay < 1.0:
+ raise ValueError('threshold_decay must be in range [0,1)')
+
+ if not 0.0 <= spec.initial_sparsity < 1.0:
+ raise ValueError('initial_sparsity must be in range [0,1)')
+
+ if not 0.0 <= spec.target_sparsity < 1.0:
+ raise ValueError('target_sparsity must be in range [0,1)')
+
def _setup_global_step(self, global_step):
graph_global_step = global_step
if graph_global_step is None:
@@ -276,11 +307,6 @@ class Pruning(object):
target_sparsity = self._spec.target_sparsity
exponent = self._spec.sparsity_function_exponent
- if begin_step >= end_step:
- raise ValueError(
- 'Pruning must begin before it can end. begin_step=%d, end_step=%d' %
- (begin_step, end_step))
-
with ops.name_scope(self._spec.name):
p = math_ops.minimum(
1.0,