author    Suyog Gupta <suyoggupta@google.com> 2018-08-27 17:30:27 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-08-27 17:36:57 -0700
commit    4a596512be206b28120acd3253d022042fa2ce6d (patch)
tree      e564b2468510b5e036d905f468eb1ff20687f8ef /tensorflow/contrib/model_pruning
parent    c7173ca08a06145439362280517bd1e741ee8c7b (diff)
Use nbins as given in hparams when pruning on TPUs
PiperOrigin-RevId: 210461150
Diffstat (limited to 'tensorflow/contrib/model_pruning')
 -rw-r--r--  tensorflow/contrib/model_pruning/README.md               | 2 +-
 -rw-r--r--  tensorflow/contrib/model_pruning/python/pruning_utils.py | 9 ++++-----
 2 files changed, 5 insertions(+), 6 deletions(-)
diff --git a/tensorflow/contrib/model_pruning/README.md b/tensorflow/contrib/model_pruning/README.md
index a5267fd904..15d95896d9 100644
--- a/tensorflow/contrib/model_pruning/README.md
+++ b/tensorflow/contrib/model_pruning/README.md
@@ -53,7 +53,7 @@ The pruning library allows for specification of the following hyper parameters:
 | weight_sparsity_map | list of strings | [""] | list of weight variable name (or layer name):target sparsity pairs. Eg. [conv1:0.9,conv2/kernel:0.8]. For layers/weights not in this list, sparsity as specified by the target_sparsity hyperparameter is used. |
 | threshold_decay | float | 0.9 | The decay factor to use for exponential decay of the thresholds |
 | pruning_frequency | integer | 10 | How often should the masks be updated? (in # of global_steps) |
-| nbins | integer | 256 | Number of bins to use for histogram computation |
+| nbins | integer | 256 | Number of bins to use for histogram computation. Note: When running on TPUs, a large (>1024) value for `nbins` may adversely affect the training time. |
 | block_height|integer | 1 | Number of rows in a block for block sparse matrices|
 | block_width |integer | 1 | Number of cols in a block for block sparse matrices|
 | block_pooling_function| string | AVG | The function to use to pool weight values in a block: average (AVG) or max (MAX)|
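
For context, a minimal sketch of how nbins can be overridden through the pruning hyperparameters described above (the hparams string, the example_pruning name, and the 512-bin value are illustrative assumptions, not part of this change; TF 1.x with contrib is assumed):

  from tensorflow.contrib.model_pruning.python import pruning

  # Start from the library defaults, then override selected hyperparameters.
  # Keeping nbins at or below 1024 sidesteps the TPU slowdown noted above.
  hparams = pruning.get_pruning_hparams().parse(
      'name=example_pruning,target_sparsity=0.9,nbins=512')

  # The parsed spec is then handed to the Pruning helper as usual, e.g.:
  # p = pruning.Pruning(hparams, global_step=global_step)
  # mask_update_op = p.conditional_mask_update_op()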
diff --git a/tensorflow/contrib/model_pruning/python/pruning_utils.py b/tensorflow/contrib/model_pruning/python/pruning_utils.py
index b50a372e9d..91b0bb7f60 100644
--- a/tensorflow/contrib/model_pruning/python/pruning_utils.py
+++ b/tensorflow/contrib/model_pruning/python/pruning_utils.py
@@ -235,19 +235,18 @@ def compute_cdf_from_histogram(values, value_range, **kwargs):
 def compute_cdf(values, value_range, **kwargs):
   """Returns the normalized cumulative distribution of the given values tensor.
 
-  Uses tf.while_loop to directly compute the cdf of the values. Number of bins
-  for histogram is fixed at _NBINS=255
+  Uses tf.while_loop to directly compute the cdf of the values.
 
   Args:
     values: Numeric `Tensor`.
     value_range: Shape [2] `Tensor` of same `dtype` as `values`
-    **kwargs: keyword arguments: name
+    **kwargs: keyword arguments: nbins, name
 
   Returns:
     A 1-D `Tensor` holding normalized cdf of values.
 
   """
-  nbins = _NBINS
+  nbins = kwargs.get('nbins', _NBINS)
   name = kwargs.get('name', None)
   with ops.name_scope(name, 'cdf', [values, value_range, nbins]):
     values = ops.convert_to_tensor(values, name='values')
@@ -281,7 +280,7 @@ def compute_cdf(values, value_range, **kwargs):
       cdf = math_ops.add(
           cdf,
           array_ops.one_hot(
-              loop_count, depth=_NBINS, on_value=temp, off_value=0.0))
+              loop_count, depth=nbins, on_value=temp, off_value=0.0))
       return [loop_count + 1, cdf]
 
     _, cdf = control_flow_ops.while_loop(
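
A minimal usage sketch of compute_cdf with the newly honored keyword argument (the random weights, value range, and 512-bin count are illustrative; TF 1.x with contrib is assumed):

  import tensorflow as tf
  from tensorflow.contrib.model_pruning.python import pruning_utils

  weights = tf.random_normal([1024])
  # Previously an nbins kwarg was silently absorbed and _NBINS was always
  # used; after this change the requested bin count is honored.
  cdf = pruning_utils.compute_cdf(
      tf.abs(weights), tf.constant([0.0, 3.0]), nbins=512)

  with tf.Session() as sess:
    print(sess.run(cdf))  # normalized CDF sampled at 512 points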