diff options
author | Suyog Gupta <suyoggupta@google.com> | 2018-08-27 17:30:27 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-27 17:36:57 -0700 |
commit | 4a596512be206b28120acd3253d022042fa2ce6d (patch) | |
tree | e564b2468510b5e036d905f468eb1ff20687f8ef /tensorflow/contrib/model_pruning | |
parent | c7173ca08a06145439362280517bd1e741ee8c7b (diff) |
Use nbins as given in hparams when pruning on TPUs
PiperOrigin-RevId: 210461150
Diffstat (limited to 'tensorflow/contrib/model_pruning')
-rw-r--r-- | tensorflow/contrib/model_pruning/README.md | 2 | ||||
-rw-r--r-- | tensorflow/contrib/model_pruning/python/pruning_utils.py | 9 |
2 files changed, 5 insertions, 6 deletions
diff --git a/tensorflow/contrib/model_pruning/README.md b/tensorflow/contrib/model_pruning/README.md
index a5267fd904..15d95896d9 100644
--- a/tensorflow/contrib/model_pruning/README.md
+++ b/tensorflow/contrib/model_pruning/README.md
@@ -53,7 +53,7 @@ The pruning library allows for specification of the following hyper parameters:
 | weight_sparsity_map | list of strings | [""] | list of weight variable name (or layer name):target sparsity pairs. Eg. [conv1:0.9,conv2/kernel:0.8]. For layers/weights not in this list, sparsity as specified by the target_sparsity hyperparameter is used. |
 | threshold_decay | float | 0.9 | The decay factor to use for exponential decay of the thresholds |
 | pruning_frequency | integer | 10 | How often should the masks be updated? (in # of global_steps) |
-| nbins | integer | 256 | Number of bins to use for histogram computation |
+| nbins | integer | 256 | Number of bins to use for histogram computation. Note: When running on TPUs, a large (>1024) value for `nbins` may adversely affect the training time. |
 | block_height|integer | 1 | Number of rows in a block for block sparse matrices|
 | block_width |integer | 1 | Number of cols in a block for block sparse matrices|
 | block_pooling_function| string | AVG | The function to use to pool weight values in a block: average (AVG) or max (MAX)|
diff --git a/tensorflow/contrib/model_pruning/python/pruning_utils.py b/tensorflow/contrib/model_pruning/python/pruning_utils.py
index b50a372e9d..91b0bb7f60 100644
--- a/tensorflow/contrib/model_pruning/python/pruning_utils.py
+++ b/tensorflow/contrib/model_pruning/python/pruning_utils.py
@@ -235,19 +235,18 @@ def compute_cdf_from_histogram(values, value_range, **kwargs):
 def compute_cdf(values, value_range, **kwargs):
   """Returns the normalized cumulative distribution of the given values tensor.
 
-  Uses tf.while_loop to directly compute the cdf of the values. Number of bins
-  for histogram is fixed at _NBINS=255
+  Uses tf.while_loop to directly compute the cdf of the values.
 
   Args:
     values: Numeric `Tensor`.
     value_range: Shape [2] `Tensor` of same `dtype` as `values`
-    **kwargs: keyword arguments: name
+    **kwargs: keyword arguments: nbins, name
 
   Returns:
     A 1-D `Tensor` holding normalized cdf of values.
 
   """
-  nbins = _NBINS
+  nbins = kwargs.get('nbins', _NBINS)
   name = kwargs.get('name', None)
   with ops.name_scope(name, 'cdf', [values, value_range, nbins]):
     values = ops.convert_to_tensor(values, name='values')
@@ -281,7 +280,7 @@ def compute_cdf(values, value_range, **kwargs):
       cdf = math_ops.add(
           cdf,
           array_ops.one_hot(
-              loop_count, depth=_NBINS, on_value=temp, off_value=0.0))
+              loop_count, depth=nbins, on_value=temp, off_value=0.0))
       return [loop_count + 1, cdf]
 
     _, cdf = control_flow_ops.while_loop(