aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/model_pruning/python/pruning_utils.py
blob: ef6c6a3f5d7aa2980dfd4e59d450ec827eb68f0a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions for adding pruning related ops to the graph.
"""
# pylint: disable=missing-docstring
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope

# Number of histogram bins: the default `nbins` for compute_cdf_from_histogram
# and the fixed bin count used by compute_cdf.
_NBINS = 256


def weight_mask_variable(var, scope):
  """Adds a non-trainable 'mask' variable matching `var`.

  The variable is named 'mask', is created (or fetched) inside `scope`, has
  the same shape and dtype as `var`, and starts out filled with ones.

  Args:
    var: the weight variable that needs to be masked.
    scope: the variable scope under which the mask is created.

  Returns:
    The mask variable.
  """
  with variable_scope.variable_scope(scope):
    return variable_scope.get_variable(
        'mask',
        shape=var.get_shape(),
        dtype=var.dtype,
        initializer=init_ops.ones_initializer(),
        trainable=False)


def weight_threshold_variable(var, scope):
  """Adds a non-trainable scalar 'threshold' variable for `var`.

  The variable is named 'threshold', is created (or fetched) inside `scope`,
  shares the dtype of `var`, and starts at zero.

  Args:
    var: The weight variable that needs to be masked.
    scope: The variable scope under which the threshold is created.

  Returns:
    The scalar threshold variable.
  """
  with variable_scope.variable_scope(scope):
    return variable_scope.get_variable(
        'threshold',
        shape=[],
        dtype=var.dtype,
        initializer=init_ops.zeros_initializer(),
        trainable=False)


def kronecker_product(mat1, mat2):
  """Computes the Kronecker product of two rank-2 tensors.

  Both inputs must have fully static shapes, because the static dimensions
  are used to build the reshapes below.

  Args:
    mat1: A matrix of size m x n
    mat2: A matrix of size p x q
  Returns:
    Kronecker product of matrices mat1 and mat2 of size mp x nq
  """
  rows_a, cols_a = mat1.get_shape().as_list()
  rows_b, cols_b = mat2.get_shape().as_list()

  # Broadcasting [m, 1, n, 1] against [1, p, 1, q] forms every product
  # mat1[i, j] * mat2[k, l] at position [i, k, j, l]. Flattening the
  # (i, k) and (j, l) axis pairs then gives the block layout of the
  # Kronecker product.
  expanded_a = array_ops.reshape(mat1, [rows_a, 1, cols_a, 1])
  expanded_b = array_ops.reshape(mat2, [1, rows_b, 1, cols_b])
  pairwise = expanded_a * expanded_b
  return array_ops.reshape(pairwise, [rows_a * rows_b, cols_a * cols_b])


def _histogram(values, value_range, nbins=100, dtype=dtypes.int32, name=None):
  """Counts how many entries of `values` fall into each of `nbins` bins.

  The interval [value_range[0], value_range[1]] is split into `nbins`
  equal-width bins. Entries below the interval are counted in the first bin.
  Entries at or above value_range[1] are counted in the last bin.

  Args:
    values:  Numeric `Tensor` of any shape; it is flattened before binning.
    value_range:  Shape [2] `Tensor` of same `dtype` as `values`.
      values <= value_range[0] will be mapped to hist[0],
      values >= value_range[1] will be mapped to hist[-1].
    nbins:  Number of histogram bins. It is passed through `np.float32`
      below, so it must be a Python/NumPy scalar here.
    dtype:  dtype for returned histogram.
    name:  A name for this operation (defaults to 'histogram').

  Returns:
    A 1-D `Tensor` with `nbins` entries holding the histogram of values.
  """
  with ops.name_scope(name, 'histogram', [values, value_range, nbins]) as scope:
    flat_values = array_ops.reshape(
        ops.convert_to_tensor(values, name='values'), [-1])
    value_range = ops.convert_to_tensor(value_range, name='value_range')
    bin_total = np.float32(nbins)
    low = value_range[0]
    high = value_range[1]

    # Rescale so that entries inside the range land in [0, 1].
    unit_position = math_ops.truediv(
        flat_values - low, high - low, name='scaled_values')

    # Entries inside the open range map to {0, ..., nbins - 1}. Entries
    # outside it map to <= 0 or >= nbins and are clipped into the end bins.
    bin_ids = math_ops.floor(bin_total * unit_position, name='indices')
    bin_ids = math_ops.cast(
        clip_ops.clip_by_value(bin_ids, 0, bin_total - 1), dtypes.int32)

    # Each entry contributes a single count to the bin it landed in.
    per_entry_counts = array_ops.ones_like(bin_ids, dtype=dtype)
    return math_ops.unsorted_segment_sum(
        per_entry_counts, bin_ids, nbins, name=scope)


def compute_cdf_from_histogram(values, value_range, **kwargs):
  """Returns the normalized cumulative distribution of `values`.

  Builds a float32 histogram with `_histogram`, takes its running sum, and
  divides by the largest running total so the cdf ends at 1. This assumes at
  least one value was counted; otherwise the division is 0 / 0.

  Args:
    values:  Numeric `Tensor`.
    value_range:  Shape [2] `Tensor` of same `dtype` as `values`.
    **kwargs: keyword arguments: nbins (defaults to _NBINS), name

  Returns:
    A 1-D `Tensor` holding normalized cdf of values.
  """
  bin_count = kwargs.get('nbins', _NBINS)
  scope_name = kwargs.get('name', None)
  with ops.name_scope(scope_name, 'cdf', [values, value_range, bin_count]):
    counts = _histogram(
        values, value_range, dtype=dtypes.float32, nbins=bin_count)
    running_total = math_ops.cumsum(counts)
    largest_total = math_ops.reduce_max(running_total)
    return math_ops.div(running_total, largest_total)


def compute_cdf(values, value_range, **kwargs):
  """Returns the normalized cumulative distribution of the given values tensor.

  Uses tf.while_loop to compute the cdf of the values directly, one bin per
  iteration, without building a histogram first. The number of bins is fixed
  at _NBINS (256).

  Args:
    values:  Numeric `Tensor`. Every entry is counted, whatever its shape.
    value_range:  Shape [2] `Tensor` of same `dtype` as `values`. Values
      <= value_range[0] fall into the first bin and values >= value_range[1]
      fall into the last bin.
    **kwargs: keyword arguments: name

  Returns:
    A 1-D float32 `Tensor` of length _NBINS holding the normalized cdf of
    values. The result is divided by its maximum, so it assumes at least one
    value was counted.
  """
  nbins = _NBINS
  name = kwargs.get('name', None)
  with ops.name_scope(name, 'cdf', [values, value_range, nbins]):
    values = ops.convert_to_tensor(values, name='values')
    value_range = ops.convert_to_tensor(value_range, name='value_range')
    nbins_float = np.float32(nbins)

    # Map tensor values that fall within value_range to [0, 1].
    scaled_values = math_ops.truediv(
        values - value_range[0],
        value_range[1] - value_range[0],
        name='scaled_values')

    # Map tensor values within the open interval value_range to
    # {0, ..., nbins - 1}. Values outside the open interval map to zero or
    # less, or to nbins or more.
    indices = math_ops.floor(nbins_float * scaled_values, name='indices')

    # Clip edge cases (e.g. value = value_range[1]) or "outliers."
    indices = math_ops.cast(
        clip_ops.clip_by_value(indices, 0, nbins_float - 1), dtypes.int32)

    cdf = array_ops.zeros(nbins)
    i = constant_op.constant(0)

    def loop_cond(loop_count, _):
      return math_ops.less(loop_count, nbins)

    def loop_body(loop_count, cdf):
      # Number of values whose bin index is <= loop_count, i.e. the
      # unnormalized cdf at bin `loop_count`.
      temp = math_ops.reduce_sum(
          math_ops.cast(
              math_ops.less_equal(indices, loop_count), dtypes.float32))
      # Write that count into slot `loop_count`. Every other slot gets 0.
      cdf = math_ops.add(
          cdf,
          array_ops.one_hot(
              loop_count, depth=nbins, on_value=temp, off_value=0.0))
      return [loop_count + 1, cdf]

    _, cdf = control_flow_ops.while_loop(
        loop_cond, loop_body, [i, cdf], maximum_iterations=nbins)

    # The last bin counts every value, so its entry is the maximum.
    return math_ops.div(cdf, math_ops.reduce_max(cdf))


def factorized_pool(input_tensor,
                    window_shape,
                    pooling_type,
                    strides,
                    padding,
                    name=None):
  """Performs m x n pooling through a combination of 1xm and 1xn pooling.

  The rank-2 input is reshaped to [1, 1, height, width]. The first pass pools
  along `height` using window_shape[0] and strides[0]. The result is then
  transposed so that a second, identical pass pools along `width` using
  window_shape[1] and strides[1], and finally transposed back.

  Args:
    input_tensor: Input tensor. Must be rank 2. Its static shape is used for
      the reshape below.
    window_shape: Pooling window shape [window_height, window_width].
    pooling_type: Either 'MAX' or 'AVG'
    strides: The stride of the pooling window [stride_height, stride_width].
    padding: 'SAME' or 'VALID'.
    name: Name of the op

  Returns:
    A rank 2 tensor containing the pooled output. The result stays rank 2
    even when a pooled dimension has size 1.

  Raises:
    ValueError: if the input tensor is not rank 2
  """
  if input_tensor.get_shape().ndims != 2:
    raise ValueError('factorized_pool() accepts tensors of rank 2 only')

  [height, width] = input_tensor.get_shape()
  with ops.name_scope(name, 'factorized_pool'):
    input_tensor_aligned = array_ops.reshape(
        input_tensor, [1, 1, height, width],
        name=input_tensor.op.name + '_aligned')

    # Pool along the height axis (axis 2 of the aligned tensor).
    height_pooling = nn_ops.pool(
        input_tensor_aligned,
        window_shape=[1, window_shape[0]],
        pooling_type=pooling_type,
        strides=[1, strides[0]],
        padding=padding)
    swap_height_width = array_ops.transpose(height_pooling, perm=[0, 1, 3, 2])

    # After the swap, the same pooling call runs along the original width.
    width_pooling = nn_ops.pool(
        swap_height_width,
        window_shape=[1, window_shape[1]],
        pooling_type=pooling_type,
        strides=[1, strides[1]],
        padding=padding)

    # Squeeze only the two leading singleton axes added by the reshape above.
    # An unrestricted squeeze would also drop a pooled dimension of size 1
    # and return a tensor of rank < 2.
    return array_ops.squeeze(
        array_ops.transpose(width_pooling, perm=[0, 1, 3, 2]), axis=[0, 1])


def determine_partitioned_axis(partitioned_variable):
  """Finds the axis along which a partitioned variable was split.

  Each partition's shape is compared elementwise against the shape of the
  whole variable. Exactly one dimension of every partition must be smaller;
  the index of that dimension is the partitioned axis. If several partitions
  are visited, the axis found for the last one is returned. An empty
  iterable yields 0.

  Args:
    partitioned_variable: An iterable of partitions that also exposes
      `get_shape()` for the full, concatenated shape.

  Returns:
    The index of the partitioned axis.

  Raises:
    ValueError: if some partition is smaller than the full shape in a number
      of dimensions other than one.
  """
  full_shape = partitioned_variable.get_shape()
  axis = 0
  for piece in partitioned_variable:
    shrunk_dims = np.less(piece.get_shape(), full_shape)
    num_shrunk = np.count_nonzero(shrunk_dims)
    # Sanity check: a variable may only be partitioned along one axis.
    if num_shrunk != 1:
      raise ValueError('Number of partitioned axes %s not equal to 1' %
                       num_shrunk)
    axis = np.nonzero(shrunk_dims)[0][0]
  return axis


def variable_assign(var, new_value):
  """Returns an assign op that writes `new_value` into `var`.

  The op is named after `var`'s op with an '_assign' suffix.
  """
  op_name = var.op.name + '_assign'
  return state_ops.assign(var, new_value, name=op_name)


def partitioned_variable_assign(partitioned_var, new_value):
  """Assign op for partitioned variables.

  `new_value` is split along the partitioned axis into pieces sized like the
  variable's partitions. Each piece is assigned to its partition, and all
  the assign ops are grouped into one op.

  Args:
    partitioned_var: A partitioned tensorflow variable
    new_value: Value to be assigned to the variable var

  Returns:
    A tensorflow op that groups the assign ops for each of the variable slices
  """
  # TensorFlow currently partitions a variable along at most one axis. A
  # single partition is treated as a split along axis 0.
  if len(partitioned_var) == 1:
    axis = 0
  else:
    axis = determine_partitioned_axis(partitioned_var)

  slice_sizes = np.array(
      [piece.get_shape()[axis] for piece in partitioned_var])
  value_slices = array_ops.split(
      new_value,
      ops.convert_to_tensor(slice_sizes, dtype=dtypes.int32),
      axis=axis)
  assign_ops = [
      variable_assign(piece, value_slices[idx])
      for idx, piece in enumerate(partitioned_var)
  ]
  return control_flow_ops.group(
      *assign_ops, name=partitioned_var.name + '_group_assign')