commit    73cc1d5b6f95ff56207e4c42b62d383c2427fb75
tree      15c5cb772ceb2851ce145df1a2e2782392a6b5c5 /tensorflow/contrib/model_pruning
parent    68f0f1aadb07ed1e7449b969d8807b5f662be33a
author    A. Unique TensorFlower <gardener@tensorflow.org>  2018-04-13 00:03:48 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>  2018-04-13 00:06:16 -0700
-- Add a new histogram/cdf computation method compatible with the TPU.
-- Refactor utility functions into pruning_utils.py and add tests

PiperOrigin-RevId: 192727737
Diffstat (limited to 'tensorflow/contrib/model_pruning')
 tensorflow/contrib/model_pruning/BUILD                        |  24
 tensorflow/contrib/model_pruning/README.md                    |   2
 tensorflow/contrib/model_pruning/python/pruning.py            | 237
 tensorflow/contrib/model_pruning/python/pruning_test.py       |  15
 tensorflow/contrib/model_pruning/python/pruning_utils.py      | 269
 tensorflow/contrib/model_pruning/python/pruning_utils_test.py |  86
 6 files changed, 430 insertions(+), 203 deletions(-)
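
For orientation, the hparams-driven workflow this change documents in pruning.py looks roughly like the following; a minimal sketch, assuming a graph whose weights were wrapped with apply_mask, and using only names that appear in the diff below:

    from tensorflow.contrib.model_pruning.python import pruning

    # Parse pruning hyperparameters from a comma-separated string.
    pruning_hparams = pruning.get_pruning_hparams().parse(
        "begin_pruning_step=1,end_pruning_step=100,target_sparsity=0.5")

    # Create a pruning object and add mask update ops to the graph.
    p = pruning.Pruning(pruning_hparams)
    mask_update_op = p.conditional_mask_update_op()
    # Run mask_update_op alongside the train op in the session loop.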
diff --git a/tensorflow/contrib/model_pruning/BUILD b/tensorflow/contrib/model_pruning/BUILD
index f50575b2cf..54bd39afac 100644
--- a/tensorflow/contrib/model_pruning/BUILD
+++ b/tensorflow/contrib/model_pruning/BUILD
@@ -72,15 +72,37 @@ py_library(
)
py_library(
+ name = "pruning_utils",
+ srcs = ["python/pruning_utils.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/python:platform",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_library(
name = "pruning",
srcs = ["python/pruning.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
":core_layers",
+ ":pruning_utils",
"//tensorflow/contrib/training:training_py",
"//tensorflow/python:platform",
- "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "pruning_utils_test",
+ size = "small",
+ srcs = ["python/pruning_utils_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":pruning_utils",
+ "//tensorflow/python:client_testlib",
],
)
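
With the targets above in place, pruning_utils gets its own py_library and a dedicated small py_test; from a TensorFlow source checkout the new tests can presumably be run with bazel test //tensorflow/contrib/model_pruning:pruning_utils_test.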
diff --git a/tensorflow/contrib/model_pruning/README.md b/tensorflow/contrib/model_pruning/README.md
index 52b659c69f..86f4fd6adf 100644
--- a/tensorflow/contrib/model_pruning/README.md
+++ b/tensorflow/contrib/model_pruning/README.md
@@ -45,7 +45,7 @@ The pruning library allows for specification of the following hyper parameters:
| do_not_prune | list of strings | [""] | list of layer names that are not pruned |
| threshold_decay | float | 0.9 | The decay factor to use for exponential decay of the thresholds |
| pruning_frequency | integer | 10 | How often should the masks be updated? (in # of global_steps) |
-| nbins | integer | 255 | Number of bins to use for histogram computation |
+| nbins | integer | 256 | Number of bins to use for histogram computation |
| block_height|integer | 1 | Number of rows in a block for block sparse matrices|
| block_width |integer | 1 | Number of cols in a block for block sparse matrices|
| block_pooling_function| string | AVG | The function to use to pool weight values in a block: average (AVG) or max (MAX)|
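
The corrected nbins default (256, matching _NBINS in the new pruning_utils.py) can still be overridden through the hparams string like any other row in this table; a quick sketch, with illustrative values:

    from tensorflow.contrib.model_pruning.python import pruning

    hparams = pruning.get_pruning_hparams().parse(
        "nbins=128,pruning_frequency=10,block_height=1,block_width=1")
    print(hparams.nbins)  # -> 128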
diff --git a/tensorflow/contrib/model_pruning/python/pruning.py b/tensorflow/contrib/model_pruning/python/pruning.py
index 5146a4a2de..ea6032e588 100644
--- a/tensorflow/contrib/model_pruning/python/pruning.py
+++ b/tensorflow/contrib/model_pruning/python/pruning.py
@@ -33,12 +33,14 @@
# Returns a list of all the weight tensors that have been masked
get_weights()
- The Pruning class uses a proto (defined in pruning.proto) to set up the
- parameters for a pruning specification. Here's a typical usage:
+ The Pruning class uses a tf.contrib.training.HParams object to set up the
+ parameters for model pruning. Here's a typical usage:
- # Initialize a pruning spec from a proto
- pruning_spec = '/tmp/pruning.pb'
- p = Pruning(pruning_spec)
+ # Parse pruning hyperparameters
+ pruning_hparams = pruning.get_pruning_hparams().parse(FLAGS.pruning_hparams)
+
+ # Create a pruning object using the pruning_hparams
+ p = pruning.Pruning(pruning_hparams)
# Add mask update ops to the graph
mask_update_op = p.conditional_mask_update_op()
@@ -51,24 +53,20 @@
# A Pruning object also accepts an externally defined sparsity tensor:
sparsity = tf.Variable(0.5, name = "ConstantSparsity")
- pruning_spec = '/tmp/pruning.pb'
- p = Pruning(pruning_spec, sparsity=sparsity)
-
+ p = pruning.Pruning(pruning_hparams, sparsity=sparsity)
"""
# pylint: disable=missing-docstring
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import numpy as np
-
+from tensorflow.contrib.model_pruning.python import pruning_utils
from tensorflow.contrib.model_pruning.python.layers import core_layers as core
from tensorflow.contrib.training.python.training import hparam
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_impl
@@ -87,172 +85,18 @@ _WEIGHT_COLLECTION = core.WEIGHT_COLLECTION
_MASKED_WEIGHT_NAME = core.MASKED_WEIGHT_NAME
-def _weight_mask_variable(var, scope):
- """Create a mask for the weights.
-
- This function adds a variable 'mask' to the graph.
-
- Args:
- var: the weight variable that needs to be masked
- scope: The variable scope of the variable var
-
- Returns:
- the mask variable of the same size and shape as var, initialized to all 1s.
- """
- with variable_scope.variable_scope(scope):
- mask = variable_scope.get_variable(
- 'mask',
- var.get_shape(),
- initializer=init_ops.ones_initializer(),
- trainable=False,
- dtype=var.dtype)
- return mask
-
-
-def _weight_threshold_variable(var, scope):
- """Create a scalar threshold for the weights.
-
- This function adds a variable
- 'threshold' to the graph.
-
- Args:
- var: The weight variable that needs to be masked
- scope: The variable scope of the variable var
-
- Returns:
- a scalar threshold variable initialized to 0.
- """
- with variable_scope.variable_scope(scope):
- threshold = variable_scope.get_variable(
- 'threshold', [],
- initializer=init_ops.zeros_initializer(),
- trainable=False,
- dtype=var.dtype)
- return threshold
-
-
-def _kronecker_product(mat1, mat2):
- """Computes the Kronecker product of two matrices mat1 and mat2.
-
- Args:
- mat1: A matrix of size m x n
- mat2: A matrix of size p x q
- Returns:
- Kronecker product of matrices mat1 and mat2 of size mp x nq
- """
-
- m1, n1 = mat1.get_shape().as_list()
- mat1_rsh = array_ops.reshape(mat1, [m1, 1, n1, 1])
- m2, n2 = mat2.get_shape().as_list()
- mat2_rsh = array_ops.reshape(mat2, [1, m2, 1, n2])
- return array_ops.reshape(mat1_rsh * mat2_rsh, [m1 * m2, n1 * n2])
-
-
-def _histogram(values, value_range, nbins=100, dtype=np.int32, name=None):
- """Return histogram of values.
-
- Given the tensor `values`, this operation returns a rank 1 histogram counting
- the number of entries in `values` that fell into every bin. The bins are
- equal width and determined by the arguments `value_range` and `nbins`.
-
- Args:
- values: Numeric `Tensor`.
- value_range: Shape [2] `Tensor` of same `dtype` as `values`.
- values <= value_range[0] will be mapped to hist[0],
- values >= value_range[1] will be mapped to hist[-1].
- nbins: Scalar `int32 Tensor`. Number of histogram bins.
- dtype: dtype for returned histogram.
- name: A name for this operation (defaults to 'histogram').
-
- Returns:
- A 1-D `Tensor` holding histogram of values.
-
- """
- with ops.name_scope(name, 'histogram', [values, value_range, nbins]) as scope:
- values = ops.convert_to_tensor(values, name='values')
- values = gen_array_ops.reshape(values, [-1])
- value_range = ops.convert_to_tensor(value_range, name='value_range')
- nbins = ops.convert_to_tensor(nbins, dtype=np.int32, name='nbins')
- nbins_float = math_ops.cast(nbins, values.dtype)
-
- # Map tensor values that fall within value_range to [0, 1].
- scaled_values = math_ops.truediv(
- values - value_range[0],
- value_range[1] - value_range[0],
- name='scaled_values')
-
- # map tensor values within the open interval value_range to {0,.., nbins-1},
- # values outside the open interval will be zero or less, or nbins or more.
- indices = math_ops.floor(nbins_float * scaled_values, name='indices')
-
- # Clip edge cases (e.g. value = value_range[1]) or "outliers."
- indices = math_ops.cast(
- clip_ops.clip_by_value(indices, 0, nbins_float - 1), np.int32)
-
- return math_ops.unsorted_segment_sum(
- array_ops.ones_like(indices, dtype=dtype), indices, nbins, name=scope)
-
-
-def _determine_partitioned_axis(partitioned_variable):
- partitioned_axis = 0
- concatenated_variable_shape = partitioned_variable.get_shape()
- for partition in partitioned_variable:
- partition_shape = partition.get_shape()
- maybe_partitioned_axis = np.less(partition_shape,
- concatenated_variable_shape)
- # Sanity check: make sure number of partitioned axis == 1
- if np.count_nonzero(maybe_partitioned_axis) != 1:
- raise ValueError('Number of partitioned axes %s not equal to 1' %
- np.count_nonzero(maybe_partitioned_axis))
- partitioned_axis = np.where(maybe_partitioned_axis)[0][0]
- return partitioned_axis
-
-
-def _variable_assign(var, new_value):
- return state_ops.assign(var, new_value, name=var.op.name + '_assign')
-
-
-def _partitioned_variable_assign(partitioned_var, new_value):
- """Assign op for partitioned variables.
-
- Args:
- partitioned_var: A partitioned tensorflow variable
- new_value: Value to be assigned to the variable var
-
- Returns:
- A tensorflow op that groups the assign ops for each of the variable slices
- """
- # Determine which axis was used to partition the variable. Currently
- # tensorflow allows partitioning variable only along 1 axis.
- axis = 0 if len(partitioned_var) == 1 else _determine_partitioned_axis(
- partitioned_var)
-
- partition_sizes = np.array(
- [partition.get_shape()[axis] for partition in partitioned_var])
- new_partitioned_values = array_ops.split(
- new_value,
- ops.convert_to_tensor(partition_sizes, dtype=np.int32),
- axis=axis)
- op_list = []
- for partition in partitioned_var:
- op_list.append(
- _variable_assign(partition, new_partitioned_values[len(op_list)]))
- return control_flow_ops.group(
- *op_list, name=partitioned_var.name + '_group_assign')
-
-
def apply_mask(x, scope=''):
"""Apply mask to a given weight tensor.
Args:
x: Input weight tensor
- scope: The current variable scope. Defaults to ""
+ scope: The current variable scope. Defaults to "".
Returns:
Tensor representing masked_weights
"""
- mask = _weight_mask_variable(x, scope)
- threshold = _weight_threshold_variable(x, scope)
+ mask = pruning_utils.weight_mask_variable(x, scope)
+ threshold = pruning_utils.weight_threshold_variable(x, scope)
# Add masked_weights in the weights namescope so as to make it easier
# for the quantization library to add quant ops.
masked_weights = math_ops.multiply(mask, x, _MASKED_WEIGHT_NAME)
@@ -335,6 +179,8 @@ def get_pruning_hparams():
sparsity_function_exponent: float
exponent = 1 is linearly varying sparsity between initial and final.
exponent > 1 varies more slowly towards the end than the beginning
+ use_tpu: bool
+ Indicates whether to use TPU-compatible operations (defaults to False)
We use the following sparsity function:
@@ -357,7 +203,7 @@ def get_pruning_hparams():
do_not_prune=[''],
threshold_decay=0.9,
pruning_frequency=10,
- nbins=255,
+ nbins=256,
block_height=1,
block_width=1,
block_pooling_function='AVG',
@@ -365,7 +211,8 @@ def get_pruning_hparams():
target_sparsity=0.5,
sparsity_function_begin_step=0,
sparsity_function_end_step=100,
- sparsity_function_exponent=3)
+ sparsity_function_exponent=3,
+ use_tpu=False)
class Pruning(object):
@@ -414,7 +261,7 @@ class Pruning(object):
if graph_global_step is None:
graph_global_step = training_util.get_global_step()
- return math_ops.cast(graph_global_step, np.int32)
+ return math_ops.cast(graph_global_step, dtypes.int32)
def _setup_sparsity(self):
begin_step = self._spec.sparsity_function_begin_step
@@ -429,13 +276,13 @@ class Pruning(object):
(begin_step, end_step))
with ops.name_scope(self._spec.name):
- p = math_ops.minimum(1.0,
- math_ops.maximum(
- 0.0,
- math_ops.div(
- math_ops.cast(self._global_step - begin_step,
- np.float32),
- end_step - begin_step)))
+ p = math_ops.minimum(
+ 1.0,
+ math_ops.maximum(
+ 0.0,
+ math_ops.div(
+ math_ops.cast(self._global_step - begin_step, dtypes.float32),
+ end_step - begin_step)))
sparsity = math_ops.add(
math_ops.multiply(initial_sparsity - target_sparsity,
math_ops.pow(1 - p, exponent)),
@@ -445,17 +292,18 @@ class Pruning(object):
return sparsity
def _setup_last_update_step(self):
- with variable_scope.variable_scope(self._spec.name) as scope:
+ with variable_scope.variable_scope(
+ self._spec.name, use_resource=self._spec.use_tpu) as scope:
try:
last_update_step = variable_scope.get_variable(
'last_mask_update_step', [],
initializer=init_ops.zeros_initializer(),
trainable=False,
- dtype=np.int32)
+ dtype=dtypes.int32)
except ValueError:
scope.reuse_variables()
last_update_step = variable_scope.get_variable(
- 'last_mask_update_step', dtype=np.int32)
+ 'last_mask_update_step', dtype=dtypes.int32)
return last_update_step
def _exists_in_do_not_prune_list(self, tensor_name):
@@ -497,18 +345,16 @@ class Pruning(object):
with ops.name_scope(weights.op.name + '_pruning_ops'):
abs_weights = math_ops.abs(weights)
max_value = math_ops.reduce_max(abs_weights)
- histogram = _histogram(
- abs_weights, [0.0, max_value],
- nbins=self._spec.nbins,
- dtype=np.float32)
+ cdf_fn = pruning_utils.compute_cdf_from_histogram
+ if self._spec.use_tpu:
+ cdf_fn = pruning_utils.compute_cdf
- cdf = math_ops.cumsum(histogram)
- norm_cdf = math_ops.div(cdf, math_ops.reduce_sum(histogram))
+ norm_cdf = cdf_fn(abs_weights, [0.0, max_value], nbins=self._spec.nbins)
current_threshold = math_ops.multiply(
math_ops.div(
math_ops.reduce_sum(
math_ops.cast(
- math_ops.less(norm_cdf, self._sparsity), np.float32)),
+ math_ops.less(norm_cdf, self._sparsity), dtypes.float32)),
float(self._spec.nbins)), max_value)
smoothed_threshold = math_ops.add_n([
@@ -516,7 +362,7 @@ class Pruning(object):
math_ops.multiply(threshold, self._spec.threshold_decay)
])
new_mask = math_ops.cast(
- math_ops.greater(abs_weights, smoothed_threshold), np.float32)
+ math_ops.greater(abs_weights, smoothed_threshold), dtypes.float32)
return smoothed_threshold, new_mask
def _maybe_update_block_mask(self, weights, threshold):
@@ -572,8 +418,8 @@ class Pruning(object):
new_mask,
[pooled_weights.get_shape()[1],
pooled_weights.get_shape()[2]])
- updated_mask = _kronecker_product(reshaped_mask,
- array_ops.ones(self._block_dim))
+ updated_mask = pruning_utils.kronecker_product(
+ reshaped_mask, array_ops.ones(self._block_dim))
sliced_mask = array_ops.slice(
updated_mask, [0, 0],
[squeezed_weights.get_shape()[0],
@@ -608,11 +454,12 @@ class Pruning(object):
continue
new_threshold, new_mask = self._maybe_update_block_mask(weight, threshold)
- self._assign_ops.append(_variable_assign(threshold, new_threshold))
+ self._assign_ops.append(
+ pruning_utils.variable_assign(threshold, new_threshold))
self._assign_ops.append(
- _partitioned_variable_assign(mask, new_mask)
- if is_partitioned else _variable_assign(mask, new_mask))
+ pruning_utils.partitioned_variable_assign(mask, new_mask)
+ if is_partitioned else pruning_utils.variable_assign(mask, new_mask))
def mask_update_op(self):
with ops.name_scope(self._spec.name):
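
The _maybe_update_mask hunk above derives the pruning threshold from the normalized CDF: it counts the bins whose cumulative mass is still below the target sparsity and rescales that bin count back to the weight-magnitude range. A NumPy sketch of the same arithmetic, for illustration only:

    import numpy as np

    def threshold_from_cdf(abs_weights, sparsity, nbins=256):
        # Histogram of |w| over [0, max|w|], then a normalized CDF.
        max_value = np.max(abs_weights)
        hist, _ = np.histogram(abs_weights, bins=nbins, range=(0.0, max_value))
        norm_cdf = np.cumsum(hist) / float(np.sum(hist))
        # Bins with cumulative mass below the target sparsity set the cutoff.
        return np.sum(norm_cdf < sparsity) / float(nbins) * max_value

    weights = np.random.randn(1000)
    t = threshold_from_cdf(np.abs(weights), sparsity=0.5)
    # Roughly half the entries of |weights| fall below t.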
diff --git a/tensorflow/contrib/model_pruning/python/pruning_test.py b/tensorflow/contrib/model_pruning/python/pruning_test.py
index 89e6571319..f80b7c52c0 100644
--- a/tensorflow/contrib/model_pruning/python/pruning_test.py
+++ b/tensorflow/contrib/model_pruning/python/pruning_test.py
@@ -110,12 +110,12 @@ class PruningTest(test.TestCase):
self.assertAllEqual(np.count_nonzero(masked_weights_val), 100)
session.run(mask_update_op)
masked_weights_val = masked_weights.eval()
- self.assertAllEqual(np.count_nonzero(masked_weights_val), 51)
+ self.assertAllEqual(np.count_nonzero(masked_weights_val), 50)
def _blockMasking(self, hparams, weights, expected_mask):
threshold = variables.Variable(0.0, name="threshold")
- sparsity = variables.Variable(0.51, name="sparsity")
+ sparsity = variables.Variable(0.5, name="sparsity")
test_spec = ",".join(hparams)
pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)
@@ -138,7 +138,8 @@ class PruningTest(test.TestCase):
weights_max = constant_op.constant(
[[0.1, 0.0, 0.2, 0.0], [0.0, -0.1, 0.0, -0.2], [0.3, 0.0, 0.4, 0.0],
[0.0, -0.3, 0.0, -0.4]])
- expected_mask = [[0, 0, 0, 0], [0, 0, 0, 0], [1, 1, 1, 1], [1, 1, 1, 1]]
+ expected_mask = [[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0],
+ [1., 1., 1., 1.], [1., 1., 1., 1.]]
self._blockMasking(param_list + ["block_pooling_function=MAX"], weights_max,
expected_mask)
@@ -155,7 +156,8 @@ class PruningTest(test.TestCase):
weights_max = constant_op.constant(
[[[0.1, 0.0, 0.2, 0.0], [0.0, -0.1, 0.0, -0.2], [0.3, 0.0, 0.4, 0.0],
[0.0, -0.3, 0.0, -0.4]]])
- expected_mask = [[[0, 0, 0, 0], [0, 0, 0, 0], [1, 1, 1, 1], [1, 1, 1, 1]]]
+ expected_mask = [[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0],
+ [1., 1., 1., 1.], [1., 1., 1., 1.]]]
self._blockMasking(param_list + ["block_pooling_function=MAX"], weights_max,
expected_mask)
@@ -178,11 +180,12 @@ class PruningTest(test.TestCase):
masked_weights_val = masked_weights.eval()
session.run(mask_update_op)
masked_weights_val = masked_weights.eval()
- self.assertAllEqual(np.count_nonzero(masked_weights_val), 51)
+ self.assertAllEqual(np.count_nonzero(masked_weights_val), 50)
def testConditionalMaskUpdate(self):
param_list = [
- "pruning_frequency=2", "begin_pruning_step=1", "end_pruning_step=6"
+ "pruning_frequency=2", "begin_pruning_step=1", "end_pruning_step=6",
+ "nbins=100"
]
test_spec = ",".join(param_list)
pruning_hparams = pruning.get_pruning_hparams().parse(test_spec)
diff --git a/tensorflow/contrib/model_pruning/python/pruning_utils.py b/tensorflow/contrib/model_pruning/python/pruning_utils.py
new file mode 100644
index 0000000000..56d3dcef20
--- /dev/null
+++ b/tensorflow/contrib/model_pruning/python/pruning_utils.py
@@ -0,0 +1,269 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Utility functions for adding pruning related ops to the graph.
+"""
+# pylint: disable=missing-docstring
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import clip_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope
+
+_NBINS = 256
+
+
+def weight_mask_variable(var, scope):
+ """Create a mask for the weights.
+
+ This function adds a variable 'mask' to the graph.
+
+ Args:
+ var: the weight variable that needs to be masked
+ scope: The variable scope of the variable var
+
+ Returns:
+ the mask variable of the same size and shape as var, initialized to all 1s.
+ """
+ with variable_scope.variable_scope(scope):
+ mask = variable_scope.get_variable(
+ 'mask',
+ var.get_shape(),
+ initializer=init_ops.ones_initializer(),
+ trainable=False,
+ dtype=var.dtype)
+ return mask
+
+
+def weight_threshold_variable(var, scope):
+ """Create a scalar threshold for the weights.
+
+ This function adds a variable
+ 'threshold' to the graph.
+
+ Args:
+ var: The weight variable that needs to be masked
+ scope: The variable scope of the variable var
+
+ Returns:
+ a scalar threshold variable initialized to 0.
+ """
+ with variable_scope.variable_scope(scope):
+ threshold = variable_scope.get_variable(
+ 'threshold', [],
+ initializer=init_ops.zeros_initializer(),
+ trainable=False,
+ dtype=var.dtype)
+ return threshold
+
+
+def kronecker_product(mat1, mat2):
+ """Computes the Kronecker product of two matrices mat1 and mat2.
+
+ Args:
+ mat1: A matrix of size m x n
+ mat2: A matrix of size p x q
+ Returns:
+ Kronecker product of matrices mat1 and mat2 of size mp x nq
+ """
+
+ m1, n1 = mat1.get_shape().as_list()
+ mat1_rsh = array_ops.reshape(mat1, [m1, 1, n1, 1])
+ m2, n2 = mat2.get_shape().as_list()
+ mat2_rsh = array_ops.reshape(mat2, [1, m2, 1, n2])
+ return array_ops.reshape(mat1_rsh * mat2_rsh, [m1 * m2, n1 * n2])
+
+
+def _histogram(values, value_range, nbins=100, dtype=dtypes.int32, name=None):
+ """Return histogram of values.
+
+ Given the tensor `values`, this operation returns a rank 1 histogram counting
+ the number of entries in `values` that fell into every bin. The bins are
+ equal width and determined by the arguments `value_range` and `nbins`.
+
+ Args:
+ values: Numeric `Tensor`.
+ value_range: Shape [2] `Tensor` of same `dtype` as `values`.
+ values <= value_range[0] will be mapped to hist[0],
+ values >= value_range[1] will be mapped to hist[-1].
+ nbins: Scalar `int32 Tensor`. Number of histogram bins.
+ dtype: dtype for returned histogram.
+ name: A name for this operation (defaults to 'histogram').
+
+ Returns:
+ A 1-D `Tensor` holding histogram of values.
+
+ """
+ with ops.name_scope(name, 'histogram', [values, value_range, nbins]) as scope:
+ values = ops.convert_to_tensor(values, name='values')
+ values = array_ops.reshape(values, [-1])
+ value_range = ops.convert_to_tensor(value_range, name='value_range')
+ nbins_float = np.float32(nbins)
+
+ # Map tensor values that fall within value_range to [0, 1].
+ scaled_values = math_ops.truediv(
+ values - value_range[0],
+ value_range[1] - value_range[0],
+ name='scaled_values')
+
+ # Map tensor values within the open interval value_range to {0, ..., nbins-1};
+ # values outside the open interval will be zero or less, or nbins or more.
+ indices = math_ops.floor(nbins_float * scaled_values, name='indices')
+
+ # Clip edge cases (e.g. value = value_range[1]) or "outliers."
+ indices = math_ops.cast(
+ clip_ops.clip_by_value(indices, 0, nbins_float - 1), dtypes.int32)
+
+ return math_ops.unsorted_segment_sum(
+ array_ops.ones_like(indices, dtype=dtype), indices, nbins, name=scope)
+
+
+def compute_cdf_from_histogram(values, value_range, **kwargs):
+ """Returns the normalized cumulative distribution of the given values tensor.
+
+ Computes the histogram and uses tf.cumsum to arrive at the cdf.
+
+ Args:
+ values: Numeric `Tensor`.
+ value_range: Shape [2] `Tensor` of same `dtype` as `values`.
+ **kwargs: keyword arguments: nbins, name
+
+ Returns:
+ A 1-D `Tensor` holding normalized cdf of values.
+
+ """
+ nbins = kwargs.get('nbins', _NBINS)
+ name = kwargs.get('name', None)
+ with ops.name_scope(name, 'cdf', [values, value_range, nbins]):
+ histogram = _histogram(
+ values, value_range, dtype=dtypes.float32, nbins=nbins)
+ cdf = math_ops.cumsum(histogram)
+ return math_ops.div(cdf, math_ops.reduce_max(cdf))
+
+
+def compute_cdf(values, value_range, **kwargs):
+ """Returns the normalized cumulative distribution of the given values tensor.
+
+ Uses tf.while_loop to directly compute the cdf of the values. The number of
+ bins for the histogram is fixed at _NBINS=256
+
+ Args:
+ values: Numeric `Tensor`.
+ value_range: Shape [2] `Tensor` of same `dtype` as `values`
+ **kwargs: keyword arguments: name
+
+ Returns:
+ A 1-D `Tensor` holding normalized cdf of values.
+
+ """
+ nbins = _NBINS
+ name = kwargs.get('name', None)
+ with ops.name_scope(name, 'cdf', [values, value_range, nbins]):
+ values = ops.convert_to_tensor(values, name='values')
+ value_range = ops.convert_to_tensor(value_range, name='value_range')
+ nbins_float = np.float32(nbins)
+
+ # Map tensor values that fall within value_range to [0, 1].
+ scaled_values = math_ops.truediv(
+ values - value_range[0],
+ value_range[1] - value_range[0],
+ name='scaled_values')
+
+ # Map tensor values within the open interval value_range to {0, ..., nbins-1};
+ # values outside the open interval will be zero or less, or nbins or more.
+ indices = math_ops.floor(nbins_float * scaled_values, name='indices')
+
+ # Clip edge cases (e.g. value = value_range[1]) or "outliers."
+ indices = math_ops.cast(
+ clip_ops.clip_by_value(indices, 0, nbins_float - 1), dtypes.int32)
+
+ cdf = array_ops.zeros(nbins)
+ i = constant_op.constant(0)
+
+ def loop_cond(loop_count, _):
+ return math_ops.less(loop_count, nbins)
+
+ def loop_body(loop_count, cdf):
+ temp = math_ops.reduce_sum(
+ math_ops.cast(
+ math_ops.less_equal(indices, loop_count), dtypes.float32))
+ cdf = math_ops.add(
+ cdf,
+ array_ops.one_hot(
+ loop_count, depth=_NBINS, on_value=temp, off_value=0.0))
+ return [loop_count + 1, cdf]
+
+ _, cdf = control_flow_ops.while_loop(
+ loop_cond, loop_body, [i, cdf], maximum_iterations=nbins)
+
+ return math_ops.div(cdf, math_ops.reduce_max(cdf))
+
+
+def determine_partitioned_axis(partitioned_variable):
+ partitioned_axis = 0
+ concatenated_variable_shape = partitioned_variable.get_shape()
+ for partition in partitioned_variable:
+ partition_shape = partition.get_shape()
+ maybe_partitioned_axis = np.less(partition_shape,
+ concatenated_variable_shape)
+ # Sanity check: make sure the number of partitioned axes == 1
+ if np.count_nonzero(maybe_partitioned_axis) != 1:
+ raise ValueError('Number of partitioned axes %s not equal to 1' %
+ np.count_nonzero(maybe_partitioned_axis))
+ partitioned_axis = np.where(maybe_partitioned_axis)[0][0]
+ return partitioned_axis
+
+
+def variable_assign(var, new_value):
+ return state_ops.assign(var, new_value, name=var.op.name + '_assign')
+
+
+def partitioned_variable_assign(partitioned_var, new_value):
+ """Assign op for partitioned variables.
+
+ Args:
+ partitioned_var: A partitioned tensorflow variable
+ new_value: Value to be assigned to the variable var
+
+ Returns:
+ A tensorflow op that groups the assign ops for each of the variable slices
+ """
+ # Determine which axis was used to partition the variable. Currently
+ # TensorFlow allows partitioning a variable only along one axis.
+ axis = 0 if len(partitioned_var) == 1 else determine_partitioned_axis(
+ partitioned_var)
+
+ partition_sizes = np.array(
+ [partition.get_shape()[axis] for partition in partitioned_var])
+ new_partitioned_values = array_ops.split(
+ new_value,
+ ops.convert_to_tensor(partition_sizes, dtype=dtypes.int32),
+ axis=axis)
+ op_list = []
+ for partition in partitioned_var:
+ op_list.append(
+ variable_assign(partition, new_partitioned_values[len(op_list)]))
+ return control_flow_ops.group(
+ *op_list, name=partitioned_var.name + '_group_assign')
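
The numbers in testCDF below can be reproduced by hand from the histogram path just added; a NumPy sketch of compute_cdf_from_histogram's semantics (equal-width bins over value_range, outliers clipped into the edge bins, cumulative sum normalized by the total count):

    import numpy as np

    values = np.abs(np.array([-1, 0, 1, 1.5, 2, 3, 4, 5, 10, 100], np.float32))
    lo, hi, nbins = 0.0, 5.0, 5

    indices = np.clip(np.floor(nbins * (values - lo) / (hi - lo)), 0, nbins - 1)
    hist = np.bincount(indices.astype(np.int64), minlength=nbins)
    norm_cdf = np.cumsum(hist) / float(np.sum(hist))
    print(norm_cdf)  # [0.1, 0.4, 0.5, 0.6, 1.0], matching expected_cdf below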
diff --git a/tensorflow/contrib/model_pruning/python/pruning_utils_test.py b/tensorflow/contrib/model_pruning/python/pruning_utils_test.py
new file mode 100644
index 0000000000..10e1dd0a8e
--- /dev/null
+++ b/tensorflow/contrib/model_pruning/python/pruning_utils_test.py
@@ -0,0 +1,86 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for utility functions in pruning_utils.py."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.model_pruning.python import pruning_utils
+from tensorflow.python.framework import constant_op
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+class PruningUtilsTest(test.TestCase):
+
+ def testHistogram(self):
+ width = 10
+ height = 10
+ nbins = 100
+ expected_histogram = np.full(nbins, 1.0)
+ init = init_ops.constant_initializer(np.linspace(0.0, 1.0, width * height))
+ weights = variable_scope.get_variable(
+ "weights", [width, height], initializer=init)
+ histogram = pruning_utils._histogram(
+ weights, [0, 1.0], nbins, dtype=np.float32)
+ with self.test_session():
+ variables.global_variables_initializer().run()
+ computed_histogram = histogram.eval()
+ self.assertAllEqual(expected_histogram, computed_histogram)
+
+ def testCDF(self):
+ nbins = 5
+ weights = constant_op.constant([-1, 0, 1, 1.5, 2, 3, 4, 5, 10, 100])
+ abs_weights = math_ops.abs(weights)
+ norm_cdf = pruning_utils.compute_cdf_from_histogram(
+ abs_weights, [0.0, 5.0], nbins=nbins)
+ expected_cdf = np.array([0.1, 0.4, 0.5, 0.6, 1.0], dtype=np.float32)
+ with self.test_session() as sess:
+ variables.global_variables_initializer().run()
+ norm_cdf_val = sess.run(norm_cdf)
+ self.assertAllEqual(len(norm_cdf_val), nbins)
+ self.assertAllEqual(expected_cdf, norm_cdf_val)
+
+ def _compare_cdf(self, values):
+ abs_values = math_ops.abs(values)
+ max_value = math_ops.reduce_max(abs_values)
+ with self.test_session():
+ variables.global_variables_initializer().run()
+ cdf_from_histogram = pruning_utils.compute_cdf_from_histogram(
+ abs_values, [0.0, max_value], nbins=pruning_utils._NBINS)
+ cdf = pruning_utils.compute_cdf(abs_values, [0.0, max_value])
+ return cdf.eval(), cdf_from_histogram.eval()
+
+ def testCDFEquivalence2D(self):
+ width = 100
+ height = 100
+ weights = variable_scope.get_variable("weights", shape=[width, height])
+ cdf_val, cdf_from_histogram_val = self._compare_cdf(weights)
+ self.assertAllEqual(cdf_val, cdf_from_histogram_val)
+
+ def testCDFEquivalence4D(self):
+ weights = variable_scope.get_variable("weights", shape=[5, 5, 128, 128])
+ cdf_val, cdf_from_histogram_val = self._compare_cdf(weights)
+ self.assertAllEqual(cdf_val, cdf_from_histogram_val)
+
+
+if __name__ == "__main__":
+ test.main()
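
As a cross-check on the refactored helpers, kronecker_product in pruning_utils.py is the standard reshape-and-broadcast construction and agrees with NumPy's np.kron for 2-D inputs; a sketch of the equivalence:

    import numpy as np

    def kron_via_reshape(mat1, mat2):
        # Same reshape/broadcast trick as pruning_utils.kronecker_product.
        m1, n1 = mat1.shape
        m2, n2 = mat2.shape
        out = mat1.reshape(m1, 1, n1, 1) * mat2.reshape(1, m2, 1, n2)
        return out.reshape(m1 * m2, n1 * n2)

    a = np.arange(6.0).reshape(2, 3)
    b = np.ones((4, 2))
    assert np.array_equal(kron_via_reshape(a, b), np.kron(a, b))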