aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/model_pruning
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-12-03 17:50:50 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-12-03 17:57:43 -0800
commita22c8219581a6bd597c92b51c5dbe7db706e3100 (patch)
tree68509358bc117ead0034ab6054a1fbe854790d00 /tensorflow/contrib/model_pruning
parent2c385166a4c9eb16d7aca2a7335e96569c59d124 (diff)
Update the pruning library to handle graphs that have both partitioned and non-partitioned variables
PiperOrigin-RevId: 177761638
Diffstat (limited to 'tensorflow/contrib/model_pruning')
-rw-r--r--tensorflow/contrib/model_pruning/python/pruning.py23
-rw-r--r--tensorflow/contrib/model_pruning/python/pruning_test.py2
2 files changed, 10 insertions, 15 deletions
diff --git a/tensorflow/contrib/model_pruning/python/pruning.py b/tensorflow/contrib/model_pruning/python/pruning.py
index 42d91a71fd..39eb79daf0 100644
--- a/tensorflow/contrib/model_pruning/python/pruning.py
+++ b/tensorflow/contrib/model_pruning/python/pruning.py
@@ -74,6 +74,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_impl
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary
from tensorflow.python.training import training_util
@@ -341,11 +342,7 @@ def get_pruning_hparams():
class Pruning(object):
- def __init__(self,
- spec=None,
- global_step=None,
- sparsity=None,
- partitioner=None):
+ def __init__(self, spec=None, global_step=None, sparsity=None):
"""Set up the specification for model pruning.
If a spec is provided, the sparsity is set up based on the sparsity_function
@@ -358,8 +355,6 @@ class Pruning(object):
global_step: A tensorflow variable that is used while setting up the
sparsity function
sparsity: A tensorflow scalar variable storing the sparsity
- partitioner: The tensorflow partitioner function used to distribute
- parameters across shards
"""
# Pruning specification
self._spec = spec if spec else get_pruning_hparams()
@@ -373,9 +368,6 @@ class Pruning(object):
# Built using self._setup_sparsity() or provided externally
self._sparsity = sparsity if sparsity else self._setup_sparsity()
- # Stores the partitioner function uses to partition variables across tasks/
- self._partitioner = partitioner
-
# List of tensorflow assignments ops for new masks and thresholds
self._assign_ops = []
@@ -509,8 +501,10 @@ class Pruning(object):
for index, mask in enumerate(masks):
threshold = thresholds[index]
- weight = weights[index] if self._partitioner is None else weights[
- index].as_tensor()
+ weight = weights[index]
+ is_partitioned = isinstance(weight, variables.PartitionedVariable)
+ if is_partitioned:
+ weight = weight.as_tensor()
if self._spec.do_not_prune:
if self._exists_in_do_not_prune_list(mask.name):
@@ -518,9 +512,10 @@ class Pruning(object):
new_threshold, new_mask = self._update_mask(weight, threshold)
self._assign_ops.append(_variable_assign(threshold, new_threshold))
+
self._assign_ops.append(
- _variable_assign(mask, new_mask) if self._partitioner is None else
- _partitioned_variable_assign(mask, new_mask))
+ _partitioned_variable_assign(mask, new_mask)
+ if is_partitioned else _variable_assign(mask, new_mask))
def mask_update_op(self):
with ops.name_scope(self._spec.name):
diff --git a/tensorflow/contrib/model_pruning/python/pruning_test.py b/tensorflow/contrib/model_pruning/python/pruning_test.py
index c23fd649ce..34b4584f49 100644
--- a/tensorflow/contrib/model_pruning/python/pruning_test.py
+++ b/tensorflow/contrib/model_pruning/python/pruning_test.py
@@ -120,7 +120,7 @@ class PruningTest(test.TestCase):
"weights", initializer=math_ops.linspace(1.0, 100.0, 100))
masked_weights = pruning.apply_mask(
weights, scope=variable_scope.get_variable_scope())
- p = pruning.Pruning(sparsity=sparsity, partitioner=partitioner)
+ p = pruning.Pruning(sparsity=sparsity)
p._spec.threshold_decay = 0.0
mask_update_op = p.mask_update_op()
variables.global_variables_initializer().run()