aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/model_pruning
diff options
context:
space:
mode:
authorGravatar Suyog Gupta <suyoggupta@google.com>2018-08-20 09:39:16 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-20 09:43:46 -0700
commitfd8df1ce215db6a19fc1623ba2b039781fd1458f (patch)
treec292f220e31904630a5d824b6329c4bf9840de25 /tensorflow/contrib/model_pruning
parentb7127df9da79b8c3c017f5de1b6f571eb3ff487b (diff)
Optimization: implementation of mask expansion for block sparsity that does not depend on kronecker product
PiperOrigin-RevId: 209432310
Diffstat (limited to 'tensorflow/contrib/model_pruning')
-rw-r--r--tensorflow/contrib/model_pruning/BUILD1
-rw-r--r--tensorflow/contrib/model_pruning/python/pruning.py4
-rw-r--r--tensorflow/contrib/model_pruning/python/pruning_utils.py70
-rw-r--r--tensorflow/contrib/model_pruning/python/pruning_utils_test.py62
4 files changed, 114 insertions, 23 deletions
diff --git a/tensorflow/contrib/model_pruning/BUILD b/tensorflow/contrib/model_pruning/BUILD
index 16ddc38f5a..e662b11be8 100644
--- a/tensorflow/contrib/model_pruning/BUILD
+++ b/tensorflow/contrib/model_pruning/BUILD
@@ -119,6 +119,7 @@ py_test(
deps = [
":pruning_utils",
"//tensorflow/python:client_testlib",
+ "@absl_py//absl/testing:parameterized",
],
)
diff --git a/tensorflow/contrib/model_pruning/python/pruning.py b/tensorflow/contrib/model_pruning/python/pruning.py
index cd58526ed3..a81abac2fa 100644
--- a/tensorflow/contrib/model_pruning/python/pruning.py
+++ b/tensorflow/contrib/model_pruning/python/pruning.py
@@ -476,8 +476,8 @@ class Pruning(object):
smoothed_threshold, new_mask = self._update_mask(pooled_weights,
threshold)
- updated_mask = pruning_utils.kronecker_product(
- new_mask, array_ops.ones(self._block_dim))
+
+ updated_mask = pruning_utils.expand_tensor(new_mask, self._block_dim)
sliced_mask = array_ops.slice(
updated_mask, [0, 0],
[squeezed_weights.get_shape()[0],
diff --git a/tensorflow/contrib/model_pruning/python/pruning_utils.py b/tensorflow/contrib/model_pruning/python/pruning_utils.py
index ef6c6a3f5d..b50a372e9d 100644
--- a/tensorflow/contrib/model_pruning/python/pruning_utils.py
+++ b/tensorflow/contrib/model_pruning/python/pruning_utils.py
@@ -69,7 +69,7 @@ def weight_threshold_variable(var, scope):
scope: The variable scope of the variable var
Returns:
- a scalar threshold variable initialized to 0.
+ A scalar threshold variable initialized to 0.
"""
with variable_scope.variable_scope(scope):
threshold = variable_scope.get_variable(
@@ -97,6 +97,74 @@ def kronecker_product(mat1, mat2):
return array_ops.reshape(mat1_rsh * mat2_rsh, [m1 * m2, n1 * n2])
+def expand_tensor(tensor, block_dims):
+ """Expands a 2D tensor by replicating the tensor values.
+
+ This is equivalent to the kronecker product of the tensor and a matrix of
+ ones of size block_dims.
+
+ Example:
+
+ tensor = [[1,2]
+ [3,4]]
+ block_dims = [2,2]
+
+ result = [[1 1 2 2]
+ [1 1 2 2]
+ [3 3 4 4]
+ [3 3 4 4]]
+
+ Args:
+ tensor: A 2D tensor that needs to be expanded.
+ block_dims: List of integers specifying the expansion factor.
+
+ Returns:
+ The expanded tensor
+
+ Raises:
+ ValueError: if tensor is not rank-2 or block_dims is does not have 2
+ elements.
+ """
+ if tensor.get_shape().ndims != 2:
+ raise ValueError('Input tensor must be rank 2')
+
+ if len(block_dims) != 2:
+ raise ValueError('block_dims must have 2 elements')
+
+ block_height, block_width = block_dims
+
+ def _tile_rows(tensor, multiple):
+ """Create a new tensor by tiling the tensor along rows."""
+ return array_ops.tile(tensor, [multiple, 1])
+
+ def _generate_indices(num_rows, block_dim):
+ indices = np.zeros(shape=[num_rows * block_dim, 1], dtype=np.int32)
+ for k in range(block_dim):
+ for r in range(num_rows):
+ indices[k * num_rows + r] = r * block_dim + k
+ return indices
+
+ def _replicate_rows(tensor, multiple):
+ tensor_shape = tensor.shape.as_list()
+ expanded_shape = [tensor_shape[0] * multiple, tensor_shape[1]]
+ indices = constant_op.constant(_generate_indices(tensor_shape[0], multiple))
+ return array_ops.scatter_nd(indices, _tile_rows(tensor, multiple),
+ expanded_shape)
+
+ expanded_tensor = tensor
+
+ # Expand rows by factor block_height.
+ if block_height > 1:
+ expanded_tensor = _replicate_rows(tensor, block_height)
+
+ # Transpose and expand by factor block_width. Transpose the result.
+ if block_width > 1:
+ expanded_tensor = array_ops.transpose(
+ _replicate_rows(array_ops.transpose(expanded_tensor), block_width))
+
+ return expanded_tensor
+
+
def _histogram(values, value_range, nbins=100, dtype=dtypes.int32, name=None):
"""Return histogram of values.
diff --git a/tensorflow/contrib/model_pruning/python/pruning_utils_test.py b/tensorflow/contrib/model_pruning/python/pruning_utils_test.py
index ccde5b4e8a..06d7f97437 100644
--- a/tensorflow/contrib/model_pruning/python/pruning_utils_test.py
+++ b/tensorflow/contrib/model_pruning/python/pruning_utils_test.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from absl.testing import parameterized
import numpy as np
from tensorflow.contrib.model_pruning.python import pruning_utils
@@ -26,6 +27,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
+from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -43,20 +45,6 @@ class PruningUtilsTest(test.TestCase):
cdf = pruning_utils.compute_cdf(abs_values, [0.0, max_value])
self.assertAllEqual(cdf.eval(), cdf_from_histogram.eval())
- def _compare_pooling_methods(self, weights, pooling_kwargs):
- with self.test_session():
- variables.global_variables_initializer().run()
- pooled_weights_tf = array_ops.squeeze(
- nn_ops.pool(
- array_ops.reshape(
- weights,
- [1, weights.get_shape()[0],
- weights.get_shape()[1], 1]), **pooling_kwargs))
- pooled_weights_factorized_pool = pruning_utils.factorized_pool(
- weights, **pooling_kwargs)
- self.assertAllClose(pooled_weights_tf.eval(),
- pooled_weights_factorized_pool.eval())
-
def testHistogram(self):
width = 10
height = 10
@@ -95,26 +83,60 @@ class PruningUtilsTest(test.TestCase):
weights = variable_scope.get_variable("weights", shape=[5, 5, 128, 128])
self._compare_cdf(weights)
- def testFactorizedAvgPool(self):
+
+@parameterized.named_parameters(
+ ("1x1", [1, 1]), ("4x4", [4, 4]), ("6x6", [6, 6]), ("1x4", [1, 4]),
+ ("4x1", [4, 1]), ("1x8", [1, 8]), ("8x1", [8, 1]))
+class PruningUtilsParameterizedTest(test.TestCase, parameterized.TestCase):
+
+ def _compare_pooling_methods(self, weights, pooling_kwargs):
+ with self.test_session():
+ variables.global_variables_initializer().run()
+ pooled_weights_tf = array_ops.squeeze(
+ nn_ops.pool(
+ array_ops.reshape(
+ weights,
+ [1, weights.get_shape()[0],
+ weights.get_shape()[1], 1]), **pooling_kwargs))
+ pooled_weights_factorized_pool = pruning_utils.factorized_pool(
+ weights, **pooling_kwargs)
+ self.assertAllClose(pooled_weights_tf.eval(),
+ pooled_weights_factorized_pool.eval())
+
+ def _compare_expand_tensor_with_kronecker_product(self, tensor, block_dim):
+ with self.test_session() as session:
+ variables.global_variables_initializer().run()
+ expanded_tensor = pruning_utils.expand_tensor(tensor, block_dim)
+ kronecker_product = pruning_utils.kronecker_product(
+ tensor, array_ops.ones(block_dim))
+ expanded_tensor_val, kronecker_product_val = session.run(
+ [expanded_tensor, kronecker_product])
+ self.assertAllEqual(expanded_tensor_val, kronecker_product_val)
+
+ def testFactorizedAvgPool(self, window_shape):
weights = variable_scope.get_variable("weights", shape=[1024, 2048])
pooling_kwargs = {
- "window_shape": [2, 4],
+ "window_shape": window_shape,
"pooling_type": "AVG",
- "strides": [2, 4],
+ "strides": window_shape,
"padding": "SAME"
}
self._compare_pooling_methods(weights, pooling_kwargs)
- def testFactorizedMaxPool(self):
+ def testFactorizedMaxPool(self, window_shape):
weights = variable_scope.get_variable("weights", shape=[1024, 2048])
pooling_kwargs = {
- "window_shape": [2, 4],
+ "window_shape": window_shape,
"pooling_type": "MAX",
- "strides": [2, 4],
+ "strides": window_shape,
"padding": "SAME"
}
self._compare_pooling_methods(weights, pooling_kwargs)
+ def testExpandTensor(self, block_dim):
+ weights = random_ops.random_normal(shape=[1024, 512])
+ self._compare_expand_tensor_with_kronecker_product(weights, block_dim)
+
if __name__ == "__main__":
test.main()