diff options
Diffstat (limited to 'tensorflow/contrib/model_pruning/python')
-rw-r--r-- | tensorflow/contrib/model_pruning/python/pruning.py | 23 | ||||
-rw-r--r-- | tensorflow/contrib/model_pruning/python/pruning_test.py | 2 |
2 files changed, 10 insertions, 15 deletions
diff --git a/tensorflow/contrib/model_pruning/python/pruning.py b/tensorflow/contrib/model_pruning/python/pruning.py index 42d91a71fd..39eb79daf0 100644 --- a/tensorflow/contrib/model_pruning/python/pruning.py +++ b/tensorflow/contrib/model_pruning/python/pruning.py @@ -74,6 +74,7 @@ from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_impl from tensorflow.python.ops import state_ops from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables from tensorflow.python.platform import tf_logging as logging from tensorflow.python.summary import summary from tensorflow.python.training import training_util @@ -341,11 +342,7 @@ def get_pruning_hparams(): class Pruning(object): - def __init__(self, - spec=None, - global_step=None, - sparsity=None, - partitioner=None): + def __init__(self, spec=None, global_step=None, sparsity=None): """Set up the specification for model pruning. If a spec is provided, the sparsity is set up based on the sparsity_function @@ -358,8 +355,6 @@ class Pruning(object): global_step: A tensorflow variable that is used while setting up the sparsity function sparsity: A tensorflow scalar variable storing the sparsity - partitioner: The tensorflow partitioner function used to distribute - parameters across shards """ # Pruning specification self._spec = spec if spec else get_pruning_hparams() @@ -373,9 +368,6 @@ class Pruning(object): # Built using self._setup_sparsity() or provided externally self._sparsity = sparsity if sparsity else self._setup_sparsity() - # Stores the partitioner function uses to partition variables across tasks/ - self._partitioner = partitioner - # List of tensorflow assignments ops for new masks and thresholds self._assign_ops = [] @@ -509,8 +501,10 @@ class Pruning(object): for index, mask in enumerate(masks): threshold = thresholds[index] - weight = weights[index] if self._partitioner is None else weights[ - index].as_tensor() + weight = weights[index] + is_partitioned = isinstance(weight, variables.PartitionedVariable) + if is_partitioned: + weight = weight.as_tensor() if self._spec.do_not_prune: if self._exists_in_do_not_prune_list(mask.name): @@ -518,9 +512,10 @@ class Pruning(object): new_threshold, new_mask = self._update_mask(weight, threshold) self._assign_ops.append(_variable_assign(threshold, new_threshold)) + self._assign_ops.append( - _variable_assign(mask, new_mask) if self._partitioner is None else - _partitioned_variable_assign(mask, new_mask)) + _partitioned_variable_assign(mask, new_mask) + if is_partitioned else _variable_assign(mask, new_mask)) def mask_update_op(self): with ops.name_scope(self._spec.name): diff --git a/tensorflow/contrib/model_pruning/python/pruning_test.py b/tensorflow/contrib/model_pruning/python/pruning_test.py index c23fd649ce..34b4584f49 100644 --- a/tensorflow/contrib/model_pruning/python/pruning_test.py +++ b/tensorflow/contrib/model_pruning/python/pruning_test.py @@ -120,7 +120,7 @@ class PruningTest(test.TestCase): "weights", initializer=math_ops.linspace(1.0, 100.0, 100)) masked_weights = pruning.apply_mask( weights, scope=variable_scope.get_variable_scope()) - p = pruning.Pruning(sparsity=sparsity, partitioner=partitioner) + p = pruning.Pruning(sparsity=sparsity) p._spec.threshold_decay = 0.0 mask_update_op = p.mask_update_op() variables.global_variables_initializer().run() |