diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-04-13 00:03:48 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-13 00:06:16 -0700 |
commit | 73cc1d5b6f95ff56207e4c42b62d383c2427fb75 (patch) | |
tree | 15c5cb772ceb2851ce145df1a2e2782392a6b5c5 /tensorflow/contrib/model_pruning | |
parent | 68f0f1aadb07ed1e7449b969d8807b5f662be33a (diff) |
-- Add a new histogram/cdf computation method compatible with the TPU.
-- Refactor utility functions into pruning_utils.py and add tests
PiperOrigin-RevId: 192727737
Diffstat (limited to 'tensorflow/contrib/model_pruning')
-rw-r--r-- | tensorflow/contrib/model_pruning/BUILD | 24 | ||||
-rw-r--r-- | tensorflow/contrib/model_pruning/README.md | 2 | ||||
-rw-r--r-- | tensorflow/contrib/model_pruning/python/pruning.py | 237 | ||||
-rw-r--r-- | tensorflow/contrib/model_pruning/python/pruning_test.py | 15 | ||||
-rw-r--r-- | tensorflow/contrib/model_pruning/python/pruning_utils.py | 269 | ||||
-rw-r--r-- | tensorflow/contrib/model_pruning/python/pruning_utils_test.py | 86 |
6 files changed, 430 insertions, 203 deletions
diff --git a/tensorflow/contrib/model_pruning/BUILD b/tensorflow/contrib/model_pruning/BUILD index f50575b2cf..54bd39afac 100644 --- a/tensorflow/contrib/model_pruning/BUILD +++ b/tensorflow/contrib/model_pruning/BUILD @@ -72,15 +72,37 @@ py_library( ) py_library( + name = "pruning_utils", + srcs = ["python/pruning_utils.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/python:platform", + "//third_party/py/numpy", + ], +) + +py_library( name = "pruning", srcs = ["python/pruning.py"], srcs_version = "PY2AND3", visibility = ["//visibility:public"], deps = [ ":core_layers", + ":pruning_utils", "//tensorflow/contrib/training:training_py", "//tensorflow/python:platform", - "//third_party/py/numpy", + ], +) + +py_test( + name = "pruning_utils_test", + size = "small", + srcs = ["python/pruning_utils_test.py"], + srcs_version = "PY2AND3", + deps = [ + ":pruning_utils", + "//tensorflow/python:client_testlib", ], ) diff --git a/tensorflow/contrib/model_pruning/README.md b/tensorflow/contrib/model_pruning/README.md index 52b659c69f..86f4fd6adf 100644 --- a/tensorflow/contrib/model_pruning/README.md +++ b/tensorflow/contrib/model_pruning/README.md @@ -45,7 +45,7 @@ The pruning library allows for specification of the following hyper parameters: | do_not_prune | list of strings | [""] | list of layers names that are not pruned | | threshold_decay | float | 0.9 | The decay factor to use for exponential decay of the thresholds | | pruning_frequency | integer | 10 | How often should the masks be updated? (in # of global_steps) | -| nbins | integer | 255 | Number of bins to use for histogram computation | +| nbins | integer | 256 | Number of bins to use for histogram computation | | block_height|integer | 1 | Number of rows in a block for block sparse matrices| | block_width |integer | 1 | Number of cols in a block for block sparse matrices| | block_pooling_function| string | AVG | The function to use to pool weight values in a block: average (AVG) or max (MAX)| diff --git a/tensorflow/contrib/model_pruning/python/pruning.py b/tensorflow/contrib/model_pruning/python/pruning.py index 5146a4a2de..ea6032e588 100644 --- a/tensorflow/contrib/model_pruning/python/pruning.py +++ b/tensorflow/contrib/model_pruning/python/pruning.py @@ -33,12 +33,14 @@ # Returns a list of all the weight tensors that have been masked get_weights() - The Pruning class uses a proto (defined in pruning.proto) to set up the - parameters for a pruning specification. Here's a typical usage: + The Pruning class uses a tf.hparams object to set up the + parameters for a model pruning. Here's a typical usage: - # Initialize a pruning spec from a proto - pruning_spec = '/tmp/pruning.pb' - p = Pruning(pruning_spec) + # Parse pruning hyperparameters + pruning_hparams = pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams) + + # Create a pruning object using the pruning_hparams + p = pruning.Pruning(pruning_hparams) # Add mask update ops to the graph mask_update_op = p.conditional_mask_update_op() @@ -51,24 +53,20 @@ # An object of the pruning also accepts externally defined sparsity: sparsity = tf.Variable(0.5, name = "ConstantSparsity") - pruning_spec = '/tmp/pruning.pb' - p = Pruning(pruning_spec, sparsity=sparsity) - + p = pruning.Pruning(pruning_hparams, sparsity=sparsity) """ # pylint: disable=missing-docstring from __future__ import absolute_import from __future__ import division from __future__ import print_function -import numpy as np - +from tensorflow.contrib.model_pruning.python import pruning_utils from tensorflow.contrib.model_pruning.python.layers import core_layers as core from tensorflow.contrib.training.python.training import hparam +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops -from tensorflow.python.ops import clip_ops from tensorflow.python.ops import control_flow_ops -from tensorflow.python.ops import gen_array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_impl @@ -87,172 +85,18 @@ _WEIGHT_COLLECTION = core.WEIGHT_COLLECTION _MASKED_WEIGHT_NAME = core.MASKED_WEIGHT_NAME -def _weight_mask_variable(var, scope): - """Create a mask for the weights. - - This function adds a variable 'mask' to the graph. - - Args: - var: the weight variable that needs to be masked - scope: The variable scope of the variable var - - Returns: - the mask variable of the same size and shape as var, initialized to all 1s. - """ - with variable_scope.variable_scope(scope): - mask = variable_scope.get_variable( - 'mask', - var.get_shape(), - initializer=init_ops.ones_initializer(), - trainable=False, - dtype=var.dtype) - return mask - - -def _weight_threshold_variable(var, scope): - """Create a scalar threshold for the weights. - - This function adds a variable - 'threshold' to the graph. - - Args: - var: The weight variable that needs to be masked - scope: The variable scope of the variable var - - Returns: - a scalar threshold variable initialized to 0. - """ - with variable_scope.variable_scope(scope): - threshold = variable_scope.get_variable( - 'threshold', [], - initializer=init_ops.zeros_initializer(), - trainable=False, - dtype=var.dtype) - return threshold - - -def _kronecker_product(mat1, mat2): - """Computes the Kronecker product of two matrices mat1 and mat2. - - Args: - mat1: A matrix of size m x n - mat2: A matrix of size p x q - Returns: - Kronecker product of matrices mat1 and mat2 of size mp x nq - """ - - m1, n1 = mat1.get_shape().as_list() - mat1_rsh = array_ops.reshape(mat1, [m1, 1, n1, 1]) - m2, n2 = mat2.get_shape().as_list() - mat2_rsh = array_ops.reshape(mat2, [1, m2, 1, n2]) - return array_ops.reshape(mat1_rsh * mat2_rsh, [m1 * m2, n1 * n2]) - - -def _histogram(values, value_range, nbins=100, dtype=np.int32, name=None): - """Return histogram of values. - - Given the tensor `values`, this operation returns a rank 1 histogram counting - the number of entries in `values` that fell into every bin. The bins are - equal width and determined by the arguments `value_range` and `nbins`. - - Args: - values: Numeric `Tensor`. - value_range: Shape [2] `Tensor` of same `dtype` as `values`. - values <= value_range[0] will be mapped to hist[0], - values >= value_range[1] will be mapped to hist[-1]. - nbins: Scalar `int32 Tensor`. Number of histogram bins. - dtype: dtype for returned histogram. - name: A name for this operation (defaults to 'histogram'). - - Returns: - A 1-D `Tensor` holding histogram of values. - - """ - with ops.name_scope(name, 'histogram', [values, value_range, nbins]) as scope: - values = ops.convert_to_tensor(values, name='values') - values = gen_array_ops.reshape(values, [-1]) - value_range = ops.convert_to_tensor(value_range, name='value_range') - nbins = ops.convert_to_tensor(nbins, dtype=np.int32, name='nbins') - nbins_float = math_ops.cast(nbins, values.dtype) - - # Map tensor values that fall within value_range to [0, 1]. - scaled_values = math_ops.truediv( - values - value_range[0], - value_range[1] - value_range[0], - name='scaled_values') - - # map tensor values within the open interval value_range to {0,.., nbins-1}, - # values outside the open interval will be zero or less, or nbins or more. - indices = math_ops.floor(nbins_float * scaled_values, name='indices') - - # Clip edge cases (e.g. value = value_range[1]) or "outliers." - indices = math_ops.cast( - clip_ops.clip_by_value(indices, 0, nbins_float - 1), np.int32) - - return math_ops.unsorted_segment_sum( - array_ops.ones_like(indices, dtype=dtype), indices, nbins, name=scope) - - -def _determine_partitioned_axis(partitioned_variable): - partitioned_axis = 0 - concatenated_variable_shape = partitioned_variable.get_shape() - for partition in partitioned_variable: - partition_shape = partition.get_shape() - maybe_partitioned_axis = np.less(partition_shape, - concatenated_variable_shape) - # Sanity check: make sure number of partitioned axis == 1 - if np.count_nonzero(maybe_partitioned_axis) != 1: - raise ValueError('Number of partitioned axes %s not equal to 1' % - np.count_nonzero(maybe_partitioned_axis)) - partitioned_axis = np.where(maybe_partitioned_axis)[0][0] - return partitioned_axis - - -def _variable_assign(var, new_value): - return state_ops.assign(var, new_value, name=var.op.name + '_assign') - - -def _partitioned_variable_assign(partitioned_var, new_value): - """Assign op for partitioned variables. - - Args: - partitioned_var: A partitioned tensorflow variable - new_value: Value to be assigned to the variable var - - Returns: - A tensorflow op that groups the assign ops for each of the variable slices - """ - # Determine which axis was used to partition the variable. Currently - # tensorflow allows partitioning variable only along 1 axis. - axis = 0 if len(partitioned_var) == 1 else _determine_partitioned_axis( - partitioned_var) - - partition_sizes = np.array( - [partition.get_shape()[axis] for partition in partitioned_var]) - new_partitioned_values = array_ops.split( - new_value, - ops.convert_to_tensor(partition_sizes, dtype=np.int32), - axis=axis) - op_list = [] - for partition in partitioned_var: - op_list.append( - _variable_assign(partition, new_partitioned_values[len(op_list)])) - return control_flow_ops.group( - *op_list, name=partitioned_var.name + '_group_assign') - - def apply_mask(x, scope=''): """Apply mask to a given weight tensor. Args: x: Input weight tensor - scope: The current variable scope. Defaults to "" + scope: The current variable scope. Defaults to "". Returns: Tensor representing masked_weights """ - mask = _weight_mask_variable(x, scope) - threshold = _weight_threshold_variable(x, scope) + mask = pruning_utils.weight_mask_variable(x, scope) + threshold = pruning_utils.weight_threshold_variable(x, scope) # Add masked_weights in the weights namescope so as to make it easier # for the quantization library to add quant ops. masked_weights = math_ops.multiply(mask, x, _MASKED_WEIGHT_NAME) @@ -335,6 +179,8 @@ def get_pruning_hparams(): sparsity_function_exponent: float exponent = 1 is linearly varying sparsity between initial and final. exponent > 1 varies more slowly towards the end than the beginning + use_tpu: False + Indicates whether to use TPU We use the following sparsity function: @@ -357,7 +203,7 @@ def get_pruning_hparams(): do_not_prune=[''], threshold_decay=0.9, pruning_frequency=10, - nbins=255, + nbins=256, block_height=1, block_width=1, block_pooling_function='AVG', @@ -365,7 +211,8 @@ def get_pruning_hparams(): target_sparsity=0.5, sparsity_function_begin_step=0, sparsity_function_end_step=100, - sparsity_function_exponent=3) + sparsity_function_exponent=3, + use_tpu=False) class Pruning(object): @@ -414,7 +261,7 @@ class Pruning(object): if graph_global_step is None: graph_global_step = training_util.get_global_step() - return math_ops.cast(graph_global_step, np.int32) + return math_ops.cast(graph_global_step, dtypes.int32) def _setup_sparsity(self): begin_step = self._spec.sparsity_function_begin_step @@ -429,13 +276,13 @@ class Pruning(object): (begin_step, end_step)) with ops.name_scope(self._spec.name): - p = math_ops.minimum(1.0, - math_ops.maximum( - 0.0, - math_ops.div( - math_ops.cast(self._global_step - begin_step, - np.float32), - end_step - begin_step))) + p = math_ops.minimum( + 1.0, + math_ops.maximum( + 0.0, + math_ops.div( + math_ops.cast(self._global_step - begin_step, dtypes.float32), + end_step - begin_step))) sparsity = math_ops.add( math_ops.multiply(initial_sparsity - target_sparsity, math_ops.pow(1 - p, exponent)), @@ -445,17 +292,18 @@ class Pruning(object): return sparsity def _setup_last_update_step(self): - with variable_scope.variable_scope(self._spec.name) as scope: + with variable_scope.variable_scope( + self._spec.name, use_resource=self._spec.use_tpu) as scope: try: last_update_step = variable_scope.get_variable( 'last_mask_update_step', [], initializer=init_ops.zeros_initializer(), trainable=False, - dtype=np.int32) + dtype=dtypes.int32) except ValueError: scope.reuse_variables() last_update_step = variable_scope.get_variable( - 'last_mask_update_step', dtype=np.int32) + 'last_mask_update_step', dtype=dtypes.int32) return last_update_step def _exists_in_do_not_prune_list(self, tensor_name): @@ -497,18 +345,16 @@ class Pruning(object): with ops.name_scope(weights.op.name + '_pruning_ops'): abs_weights = math_ops.abs(weights) max_value = math_ops.reduce_max(abs_weights) - histogram = _histogram( - abs_weights, [0.0, max_value], - nbins=self._spec.nbins, - dtype=np.float32) + cdf_fn = pruning_utils.compute_cdf_from_histogram + if self._spec.use_tpu: + cdf_fn = pruning_utils.compute_cdf - cdf = math_ops.cumsum(histogram) - norm_cdf = math_ops.div(cdf, math_ops.reduce_sum(histogram)) + norm_cdf = cdf_fn(abs_weights, [0.0, max_value], nbins=self._spec.nbins) current_threshold = math_ops.multiply( math_ops.div( math_ops.reduce_sum( math_ops.cast( - math_ops.less(norm_cdf, self._sparsity), np.float32)), + math_ops.less(norm_cdf, self._sparsity), dtypes.float32)), float(self._spec.nbins)), max_value) smoothed_threshold = math_ops.add_n([ @@ -516,7 +362,7 @@ class Pruning(object): math_ops.multiply(threshold, self._spec.threshold_decay) ]) new_mask = math_ops.cast( - math_ops.greater(abs_weights, smoothed_threshold), np.float32) + math_ops.greater(abs_weights, smoothed_threshold), dtypes.float32) return smoothed_threshold, new_mask def _maybe_update_block_mask(self, weights, threshold): @@ -572,8 +418,8 @@ class Pruning(object): new_mask, [pooled_weights.get_shape()[1], pooled_weights.get_shape()[2]]) - updated_mask = _kronecker_product(reshaped_mask, - array_ops.ones(self._block_dim)) + updated_mask = pruning_utils.kronecker_product( + reshaped_mask, array_ops.ones(self._block_dim)) sliced_mask = array_ops.slice( updated_mask, [0, 0], [squeezed_weights.get_shape()[0], @@ -608,11 +454,12 @@ class Pruning(object): continue new_threshold, new_mask = self._maybe_update_block_mask(weight, threshold) - self._assign_ops.append(_variable_assign(threshold, new_threshold)) + self._assign_ops.append( + pruning_utils.variable_assign(threshold, new_threshold)) self._assign_ops.append( - _partitioned_variable_assign(mask, new_mask) - if is_partitioned else _variable_assign(mask, new_mask)) + pruning_utils.partitioned_variable_assign(mask, new_mask) + if is_partitioned else pruning_utils.variable_assign(mask, new_mask)) def mask_update_op(self): with ops.name_scope(self._spec.name): diff --git a/tensorflow/contrib/model_pruning/python/pruning_test.py b/tensorflow/contrib/model_pruning/python/pruning_test.py index 89e6571319..f80b7c52c0 100644 --- a/tensorflow/contrib/model_pruning/python/pruning_test.py +++ b/tensorflow/contrib/model_pruning/python/pruning_test.py @@ -110,12 +110,12 @@ class PruningTest(test.TestCase): self.assertAllEqual(np.count_nonzero(masked_weights_val), 100) session.run(mask_update_op) masked_weights_val = masked_weights.eval() - self.assertAllEqual(np.count_nonzero(masked_weights_val), 51) + self.assertAllEqual(np.count_nonzero(masked_weights_val), 50) def _blockMasking(self, hparams, weights, expected_mask): threshold = variables.Variable(0.0, name="threshold") - sparsity = variables.Variable(0.51, name="sparsity") + sparsity = variables.Variable(0.5, name="sparsity") test_spec = ",".join(hparams) pruning_hparams = pruning.get_pruning_hparams().parse(test_spec) @@ -138,7 +138,8 @@ class PruningTest(test.TestCase): weights_max = constant_op.constant( [[0.1, 0.0, 0.2, 0.0], [0.0, -0.1, 0.0, -0.2], [0.3, 0.0, 0.4, 0.0], [0.0, -0.3, 0.0, -0.4]]) - expected_mask = [[0, 0, 0, 0], [0, 0, 0, 0], [1, 1, 1, 1], [1, 1, 1, 1]] + expected_mask = [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], + [1., 1., 1., 1.], [1., 1., 1., 1.]] self._blockMasking(param_list + ["block_pooling_function=MAX"], weights_max, expected_mask) @@ -155,7 +156,8 @@ class PruningTest(test.TestCase): weights_max = constant_op.constant( [[[0.1, 0.0, 0.2, 0.0], [0.0, -0.1, 0.0, -0.2], [0.3, 0.0, 0.4, 0.0], [0.0, -0.3, 0.0, -0.4]]]) - expected_mask = [[[0, 0, 0, 0], [0, 0, 0, 0], [1, 1, 1, 1], [1, 1, 1, 1]]] + expected_mask = [[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0], + [1., 1., 1., 1.], [1., 1., 1., 1.]]] self._blockMasking(param_list + ["block_pooling_function=MAX"], weights_max, expected_mask) @@ -178,11 +180,12 @@ class PruningTest(test.TestCase): masked_weights_val = masked_weights.eval() session.run(mask_update_op) masked_weights_val = masked_weights.eval() - self.assertAllEqual(np.count_nonzero(masked_weights_val), 51) + self.assertAllEqual(np.count_nonzero(masked_weights_val), 50) def testConditionalMaskUpdate(self): param_list = [ - "pruning_frequency=2", "begin_pruning_step=1", "end_pruning_step=6" + "pruning_frequency=2", "begin_pruning_step=1", "end_pruning_step=6", + "nbins=100" ] test_spec = ",".join(param_list) pruning_hparams = pruning.get_pruning_hparams().parse(test_spec) diff --git a/tensorflow/contrib/model_pruning/python/pruning_utils.py b/tensorflow/contrib/model_pruning/python/pruning_utils.py new file mode 100644 index 0000000000..56d3dcef20 --- /dev/null +++ b/tensorflow/contrib/model_pruning/python/pruning_utils.py @@ -0,0 +1,269 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Utility functions for adding pruning related ops to the graph. +""" +# pylint: disable=missing-docstring +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import clip_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import state_ops +from tensorflow.python.ops import variable_scope + +_NBINS = 256 + + +def weight_mask_variable(var, scope): + """Create a mask for the weights. + + This function adds a variable 'mask' to the graph. + + Args: + var: the weight variable that needs to be masked + scope: The variable scope of the variable var + + Returns: + the mask variable of the same size and shape as var, initialized to all 1s. + """ + with variable_scope.variable_scope(scope): + mask = variable_scope.get_variable( + 'mask', + var.get_shape(), + initializer=init_ops.ones_initializer(), + trainable=False, + dtype=var.dtype) + return mask + + +def weight_threshold_variable(var, scope): + """Create a scalar threshold for the weights. + + This function adds a variable + 'threshold' to the graph. + + Args: + var: The weight variable that needs to be masked + scope: The variable scope of the variable var + + Returns: + a scalar threshold variable initialized to 0. + """ + with variable_scope.variable_scope(scope): + threshold = variable_scope.get_variable( + 'threshold', [], + initializer=init_ops.zeros_initializer(), + trainable=False, + dtype=var.dtype) + return threshold + + +def kronecker_product(mat1, mat2): + """Computes the Kronecker product of two matrices mat1 and mat2. + + Args: + mat1: A matrix of size m x n + mat2: A matrix of size p x q + Returns: + Kronecker product of matrices mat1 and mat2 of size mp x nq + """ + + m1, n1 = mat1.get_shape().as_list() + mat1_rsh = array_ops.reshape(mat1, [m1, 1, n1, 1]) + m2, n2 = mat2.get_shape().as_list() + mat2_rsh = array_ops.reshape(mat2, [1, m2, 1, n2]) + return array_ops.reshape(mat1_rsh * mat2_rsh, [m1 * m2, n1 * n2]) + + +def _histogram(values, value_range, nbins=100, dtype=dtypes.int32, name=None): + """Return histogram of values. + + Given the tensor `values`, this operation returns a rank 1 histogram counting + the number of entries in `values` that fell into every bin. The bins are + equal width and determined by the arguments `value_range` and `nbins`. + + Args: + values: Numeric `Tensor`. + value_range: Shape [2] `Tensor` of same `dtype` as `values`. + values <= value_range[0] will be mapped to hist[0], + values >= value_range[1] will be mapped to hist[-1]. + nbins: Scalar `int32 Tensor`. Number of histogram bins. + dtype: dtype for returned histogram. + name: A name for this operation (defaults to 'histogram'). + + Returns: + A 1-D `Tensor` holding histogram of values. + + """ + with ops.name_scope(name, 'histogram', [values, value_range, nbins]) as scope: + values = ops.convert_to_tensor(values, name='values') + values = array_ops.reshape(values, [-1]) + value_range = ops.convert_to_tensor(value_range, name='value_range') + nbins_float = np.float32(nbins) + + # Map tensor values that fall within value_range to [0, 1]. + scaled_values = math_ops.truediv( + values - value_range[0], + value_range[1] - value_range[0], + name='scaled_values') + + # map tensor values within the open interval value_range to {0,.., nbins-1}, + # values outside the open interval will be zero or less, or nbins or more. + indices = math_ops.floor(nbins_float * scaled_values, name='indices') + + # Clip edge cases (e.g. value = value_range[1]) or "outliers." + indices = math_ops.cast( + clip_ops.clip_by_value(indices, 0, nbins_float - 1), dtypes.int32) + + return math_ops.unsorted_segment_sum( + array_ops.ones_like(indices, dtype=dtype), indices, nbins, name=scope) + + +def compute_cdf_from_histogram(values, value_range, **kwargs): + """Returns the normalized cumulative distribution of the given values tensor. + + Computes the histogram and uses tf.cumsum to arrive at cdf + + Args: + values: Numeric `Tensor`. + value_range: Shape [2] `Tensor` of same `dtype` as `values`. + **kwargs: keyword arguments: nbins, name + + Returns: + A 1-D `Tensor` holding normalized cdf of values. + + """ + nbins = kwargs.get('nbins', _NBINS) + name = kwargs.get('name', None) + with ops.name_scope(name, 'cdf', [values, value_range, nbins]): + histogram = _histogram( + values, value_range, dtype=dtypes.float32, nbins=nbins) + cdf = math_ops.cumsum(histogram) + return math_ops.div(cdf, math_ops.reduce_max(cdf)) + + +def compute_cdf(values, value_range, **kwargs): + """Returns the normalized cumulative distribution of the given values tensor. + + Uses tf.while_loop to directly compute the cdf of the values. Number of bins + for histogram is fixed at _NBINS=255 + + Args: + values: Numeric `Tensor`. + value_range: Shape [2] `Tensor` of same `dtype` as `values` + **kwargs: keyword arguments: name + + Returns: + A 1-D `Tensor` holding normalized cdf of values. + + """ + nbins = _NBINS + name = kwargs.get('name', None) + with ops.name_scope(name, 'cdf', [values, value_range, nbins]): + values = ops.convert_to_tensor(values, name='values') + value_range = ops.convert_to_tensor(value_range, name='value_range') + nbins_float = np.float32(nbins) + + # Map tensor values that fall within value_range to [0, 1]. + scaled_values = math_ops.truediv( + values - value_range[0], + value_range[1] - value_range[0], + name='scaled_values') + + # map tensor values within the open interval value_range to {0,.., nbins-1}, + # values outside the open interval will be zero or less, or nbins or more. + indices = math_ops.floor(nbins_float * scaled_values, name='indices') + + # Clip edge cases (e.g. value = value_range[1]) or "outliers." + indices = math_ops.cast( + clip_ops.clip_by_value(indices, 0, nbins_float - 1), dtypes.int32) + + cdf = array_ops.zeros(nbins) + i = constant_op.constant(0) + + def loop_cond(loop_count, _): + return math_ops.less(loop_count, nbins) + + def loop_body(loop_count, cdf): + temp = math_ops.reduce_sum( + math_ops.cast( + math_ops.less_equal(indices, loop_count), dtypes.float32)) + cdf = math_ops.add( + cdf, + array_ops.one_hot( + loop_count, depth=_NBINS, on_value=temp, off_value=0.0)) + return [loop_count + 1, cdf] + + _, cdf = control_flow_ops.while_loop( + loop_cond, loop_body, [i, cdf], maximum_iterations=nbins) + + return math_ops.div(cdf, math_ops.reduce_max(cdf)) + + +def determine_partitioned_axis(partitioned_variable): + partitioned_axis = 0 + concatenated_variable_shape = partitioned_variable.get_shape() + for partition in partitioned_variable: + partition_shape = partition.get_shape() + maybe_partitioned_axis = np.less(partition_shape, + concatenated_variable_shape) + # Sanity check: make sure number of partitioned axis == 1 + if np.count_nonzero(maybe_partitioned_axis) != 1: + raise ValueError('Number of partitioned axes %s not equal to 1' % + np.count_nonzero(maybe_partitioned_axis)) + partitioned_axis = np.where(maybe_partitioned_axis)[0][0] + return partitioned_axis + + +def variable_assign(var, new_value): + return state_ops.assign(var, new_value, name=var.op.name + '_assign') + + +def partitioned_variable_assign(partitioned_var, new_value): + """Assign op for partitioned variables. + + Args: + partitioned_var: A partitioned tensorflow variable + new_value: Value to be assigned to the variable var + + Returns: + A tensorflow op that groups the assign ops for each of the variable slices + """ + # Determine which axis was used to partition the variable. Currently + # tensorflow allows partitioning variable only along 1 axis. + axis = 0 if len(partitioned_var) == 1 else determine_partitioned_axis( + partitioned_var) + + partition_sizes = np.array( + [partition.get_shape()[axis] for partition in partitioned_var]) + new_partitioned_values = array_ops.split( + new_value, + ops.convert_to_tensor(partition_sizes, dtype=dtypes.int32), + axis=axis) + op_list = [] + for partition in partitioned_var: + op_list.append( + variable_assign(partition, new_partitioned_values[len(op_list)])) + return control_flow_ops.group( + *op_list, name=partitioned_var.name + '_group_assign') diff --git a/tensorflow/contrib/model_pruning/python/pruning_utils_test.py b/tensorflow/contrib/model_pruning/python/pruning_utils_test.py new file mode 100644 index 0000000000..10e1dd0a8e --- /dev/null +++ b/tensorflow/contrib/model_pruning/python/pruning_utils_test.py @@ -0,0 +1,86 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for utility functions in pruning_utils.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np + +from tensorflow.contrib.model_pruning.python import pruning_utils +from tensorflow.python.framework import constant_op +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops import variables +from tensorflow.python.platform import test + + +class PruningUtilsTest(test.TestCase): + + def testHistogram(self): + width = 10 + height = 10 + nbins = 100 + expected_histogram = np.full(nbins, 1.0) + init = init_ops.constant_initializer(np.linspace(0.0, 1.0, width * height)) + weights = variable_scope.get_variable( + "weights", [width, height], initializer=init) + histogram = pruning_utils._histogram( + weights, [0, 1.0], nbins, dtype=np.float32) + with self.test_session(): + variables.global_variables_initializer().run() + computed_histogram = histogram.eval() + self.assertAllEqual(expected_histogram, computed_histogram) + + def testCDF(self): + nbins = 5 + weights = constant_op.constant([-1, 0, 1, 1.5, 2, 3, 4, 5, 10, 100]) + abs_weights = math_ops.abs(weights) + norm_cdf = pruning_utils.compute_cdf_from_histogram( + abs_weights, [0.0, 5.0], nbins=nbins) + expected_cdf = np.array([0.1, 0.4, 0.5, 0.6, 1.0], dtype=np.float32) + with self.test_session() as sess: + variables.global_variables_initializer().run() + norm_cdf_val = sess.run(norm_cdf) + self.assertAllEqual(len(norm_cdf_val), nbins) + self.assertAllEqual(expected_cdf, norm_cdf_val) + + def _compare_cdf(self, values): + abs_values = math_ops.abs(values) + max_value = math_ops.reduce_max(abs_values) + with self.test_session(): + variables.global_variables_initializer().run() + cdf_from_histogram = pruning_utils.compute_cdf_from_histogram( + abs_values, [0.0, max_value], nbins=pruning_utils._NBINS) + cdf = pruning_utils.compute_cdf(abs_values, [0.0, max_value]) + return cdf.eval(), cdf_from_histogram.eval() + + def testCDFEquivalence2D(self): + width = 100 + height = 100 + weights = variable_scope.get_variable("weights", shape=[width, height]) + cdf_val, cdf_from_histogram_val = self._compare_cdf(weights) + self.assertAllEqual(cdf_val, cdf_from_histogram_val) + + def testCDFEquivalence4D(self): + weights = variable_scope.get_variable("weights", shape=[5, 5, 128, 128]) + cdf_val, cdf_from_histogram_val = self._compare_cdf(weights) + self.assertAllEqual(cdf_val, cdf_from_histogram_val) + + +if __name__ == "__main__": + test.main() |