aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/model_pruning
diff options
context:
space:
mode:
author: Suyog Gupta <suyoggupta@google.com> 2018-07-31 16:48:46 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-07-31 16:53:21 -0700
commit: c87a0db716c906fcd698a599680608bc5085e56b (patch)
tree: 5464b35487527fedfcaeb82d0ebe51b1a91af551 /tensorflow/contrib/model_pruning
parent: 2bf582914d09207ad7276e2f471ea9776415e8e0 (diff)
Add support for layer-dependent sparsity. Accept layer_name:target_sparsity mapping as hyperparameter
Deprecate do_not_prune hyperparameter.

PiperOrigin-RevId: 206851318
Diffstat (limited to 'tensorflow/contrib/model_pruning')
-rw-r--r-- tensorflow/contrib/model_pruning/README.md | 2
-rw-r--r-- tensorflow/contrib/model_pruning/python/pruning.py | 72
-rw-r--r-- tensorflow/contrib/model_pruning/python/pruning_test.py | 39
3 files changed, 82 insertions, 31 deletions
diff --git a/tensorflow/contrib/model_pruning/README.md b/tensorflow/contrib/model_pruning/README.md
index 9143d082bf..dbe4e124fd 100644
--- a/tensorflow/contrib/model_pruning/README.md
+++ b/tensorflow/contrib/model_pruning/README.md
@@ -42,7 +42,7 @@ The pruning library allows for specification of the following hyper parameters:
| name | string | model_pruning | Name of the pruning specification. Used for adding summaries and ops under a common tensorflow name_scope |
| begin_pruning_step | integer | 0 | The global step at which to begin pruning |
| end_pruning_step | integer | -1 | The global step at which to terminate pruning. Defaults to -1 implying that pruning continues till the training stops |
-| do_not_prune | list of strings | [""] | list of layers names that are not pruned |
+| weight_sparsity_map | list of strings | [""] | list of weight variable name (or layer name):target sparsity pairs. Eg. [conv1:0.9,conv2/kernel:0.8]. For layers/weights not in this list, sparsity as specified by the target_sparsity hyperparameter is used. |
| threshold_decay | float | 0.9 | The decay factor to use for exponential decay of the thresholds |
| pruning_frequency | integer | 10 | How often should the masks be updated? (in # of global_steps) |
| nbins | integer | 256 | Number of bins to use for histogram computation |
diff --git a/tensorflow/contrib/model_pruning/python/pruning.py b/tensorflow/contrib/model_pruning/python/pruning.py
index da9d398cbc..723dab9369 100644
--- a/tensorflow/contrib/model_pruning/python/pruning.py
+++ b/tensorflow/contrib/model_pruning/python/pruning.py
@@ -152,8 +152,11 @@ def get_pruning_hparams():
end_pruning_step: integer
the global step at which to terminate pruning. Defaults to -1 implying
that pruning continues till the training stops
- do_not_prune: list of strings
- list of layers that are not pruned
+ weight_sparsity_map: list of strings
+ comma separated list of weight variable name:target sparsity pairs.
+ For layers/weights not in this list, sparsity as specified by the
+ target_sparsity hyperparameter is used.
+ Eg. [conv1:0.9,conv2/kernel:0.8]
threshold_decay: float
the decay factor to use for exponential decay of the thresholds
pruning_frequency: integer
@@ -200,7 +203,7 @@ def get_pruning_hparams():
name='model_pruning',
begin_pruning_step=0,
end_pruning_step=-1,
- do_not_prune=[''],
+ weight_sparsity_map=[''],
threshold_decay=0.9,
pruning_frequency=10,
nbins=256,
@@ -256,6 +259,9 @@ class Pruning(object):
# Block pooling function
self._block_pooling_function = self._spec.block_pooling_function
+ # Mapping of weight names and target sparsity
+ self._weight_sparsity_map = self._get_weight_sparsity_map()
+
def _setup_global_step(self, global_step):
graph_global_step = global_step
if graph_global_step is None:
@@ -306,15 +312,36 @@ class Pruning(object):
'last_mask_update_step', dtype=dtypes.int32)
return last_update_step
- def _exists_in_do_not_prune_list(self, tensor_name):
- do_not_prune_list = self._spec.do_not_prune
- if not do_not_prune_list[0]:
- return False
- for layer_name in do_not_prune_list:
- if tensor_name.find(layer_name) != -1:
- return True
-
- return False
+ def _get_weight_sparsity_map(self):
+ """Return the map of weight_name:sparsity parsed from the hparams."""
+ weight_sparsity_map = {}
+ val_list = self._spec.weight_sparsity_map
+ filtered_val_list = [l for l in val_list if l]
+ for val in filtered_val_list:
+ weight_name, sparsity = val.split(':')
+ if float(sparsity) >= 1.0:
+ raise ValueError('Weight sparsity can not exceed 1.0')
+ weight_sparsity_map[weight_name] = float(sparsity)
+
+ return weight_sparsity_map
+
+ def _get_sparsity(self, weight_name):
+ """Return target sparsity for the given layer/weight name."""
+ target_sparsity = [
+ sparsity for name, sparsity in self._weight_sparsity_map.items()
+ if weight_name.find(name) != -1
+ ]
+ if not target_sparsity:
+ return self._sparsity
+
+ if len(target_sparsity) > 1:
+ raise ValueError(
+ 'Multiple matches in weight_sparsity_map for weight %s' % weight_name)
+ # TODO(suyoggupta): This will work when initial_sparsity = 0. Generalize
+ # to handle other cases as well.
+ return math_ops.mul(
+ self._sparsity,
+ math_ops.div(target_sparsity[0], self._spec.target_sparsity))
def _update_mask(self, weights, threshold):
"""Updates the mask for a given weight tensor.
@@ -342,6 +369,8 @@ class Pruning(object):
if self._sparsity is None:
raise ValueError('Sparsity variable undefined')
+ sparsity = self._get_sparsity(weights.op.name)
+
with ops.name_scope(weights.op.name + '_pruning_ops'):
abs_weights = math_ops.abs(weights)
max_value = math_ops.reduce_max(abs_weights)
@@ -354,7 +383,7 @@ class Pruning(object):
math_ops.div(
math_ops.reduce_sum(
math_ops.cast(
- math_ops.less(norm_cdf, self._sparsity), dtypes.float32)),
+ math_ops.less(norm_cdf, sparsity), dtypes.float32)),
float(self._spec.nbins)), max_value)
smoothed_threshold = math_ops.add_n([
@@ -453,10 +482,6 @@ class Pruning(object):
if is_partitioned:
weight = weight.as_tensor()
- if self._spec.do_not_prune:
- if self._exists_in_do_not_prune_list(mask.name):
- continue
-
new_threshold, new_mask = self._maybe_update_block_mask(weight, threshold)
self._assign_ops.append(
pruning_utils.variable_assign(threshold, new_threshold))
@@ -507,22 +532,15 @@ class Pruning(object):
no_update_op)
def add_pruning_summaries(self):
- """Adds summaries for this pruning spec.
-
- Args: none
-
- Returns: none
- """
+ """Adds summaries of weight sparsities and thresholds."""
with ops.name_scope(self._spec.name + '_summaries'):
summary.scalar('sparsity', self._sparsity)
summary.scalar('last_mask_update_step', self._last_update_step)
masks = get_masks()
thresholds = get_thresholds()
for mask, threshold in zip(masks, thresholds):
- if not self._exists_in_do_not_prune_list(mask.name):
- summary.scalar(mask.op.name + '/sparsity',
- nn_impl.zero_fraction(mask))
- summary.scalar(threshold.op.name + '/threshold', threshold)
+ summary.scalar(mask.op.name + '/sparsity', nn_impl.zero_fraction(mask))
+ summary.scalar(threshold.op.name + '/threshold', threshold)
def print_hparams(self):
logging.info(self._spec.to_json())
diff --git a/tensorflow/contrib/model_pruning/python/pruning_test.py b/tensorflow/contrib/model_pruning/python/pruning_test.py
index f80b7c52c0..5b67656e9f 100644
--- a/tensorflow/contrib/model_pruning/python/pruning_test.py
+++ b/tensorflow/contrib/model_pruning/python/pruning_test.py
@@ -35,8 +35,8 @@ from tensorflow.python.training import training_util
class PruningHParamsTest(test.TestCase):
PARAM_LIST = [
"name=test", "threshold_decay=0.9", "pruning_frequency=10",
- "do_not_prune=[conv1,conv2]", "sparsity_function_end_step=100",
- "target_sparsity=0.9"
+ "sparsity_function_end_step=100", "target_sparsity=0.9",
+ "weight_sparsity_map=[conv1:0.8,conv2/kernel:0.8]"
]
TEST_HPARAMS = ",".join(PARAM_LIST)
@@ -55,9 +55,11 @@ class PruningHParamsTest(test.TestCase):
self.assertEqual(p._spec.name, "test")
self.assertAlmostEqual(p._spec.threshold_decay, 0.9)
self.assertEqual(p._spec.pruning_frequency, 10)
- self.assertAllEqual(p._spec.do_not_prune, ["conv1", "conv2"])
self.assertEqual(p._spec.sparsity_function_end_step, 100)
self.assertAlmostEqual(p._spec.target_sparsity, 0.9)
+ self.assertEqual(p._weight_sparsity_map["conv1"], 0.8)
+ self.assertEqual(p._weight_sparsity_map["conv2/kernel"], 0.8)
+
def testInitWithExternalSparsity(self):
with self.test_session():
@@ -211,6 +213,37 @@ class PruningTest(test.TestCase):
expected_non_zero_count = [100, 100, 80, 80, 60, 60, 40, 40, 40, 40]
self.assertAllEqual(expected_non_zero_count, non_zero_count)
+ def testWeightSpecificSparsity(self):
+ param_list = [
+ "begin_pruning_step=1", "pruning_frequency=1", "end_pruning_step=100",
+ "target_sparsity=0.5", "weight_sparsity_map=[layer2/weights:0.75]",
+ "threshold_decay=0.0"
+ ]
+ test_spec = ",".join(param_list)
+ pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)
+
+ with variable_scope.variable_scope("layer1"):
+ w1 = variables.Variable(
+ math_ops.linspace(1.0, 100.0, 100), name="weights")
+ _ = pruning.apply_mask(w1)
+ with variable_scope.variable_scope("layer2"):
+ w2 = variables.Variable(
+ math_ops.linspace(1.0, 100.0, 100), name="weights")
+ _ = pruning.apply_mask(w2)
+
+ p = pruning.Pruning(pruning_hparams)
+ mask_update_op = p.conditional_mask_update_op()
+ increment_global_step = state_ops.assign_add(self.global_step, 1)
+
+ with self.test_session() as session:
+ variables.global_variables_initializer().run()
+ for _ in range(110):
+ session.run(mask_update_op)
+ session.run(increment_global_step)
+
+ self.assertAllEqual(
+ session.run(pruning.get_weight_sparsity()), [0.5, 0.75])
+
if __name__ == "__main__":
test.main()