aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/model_pruning
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-07 10:49:26 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-07 16:37:48 -0700
commit9ba26ca0d59989592051fdb5c7a2caabe4f399f3 (patch)
treea7e66fce409d3647ab6d664b33248639a6b8bf9d /tensorflow/contrib/model_pruning
parentb2888c66e67d584756bb50850ae77acede7ba8bf (diff)
Extend block sparsity support for TPUs
PiperOrigin-RevId: 195685740
Diffstat (limited to 'tensorflow/contrib/model_pruning')
-rw-r--r--tensorflow/contrib/model_pruning/python/pruning.py30
-rw-r--r--tensorflow/contrib/model_pruning/python/pruning_utils.py51
-rw-r--r--tensorflow/contrib/model_pruning/python/pruning_utils_test.py62
3 files changed, 116 insertions, 27 deletions
diff --git a/tensorflow/contrib/model_pruning/python/pruning.py b/tensorflow/contrib/model_pruning/python/pruning.py
index ea6032e588..4b7af18b33 100644
--- a/tensorflow/contrib/model_pruning/python/pruning.py
+++ b/tensorflow/contrib/model_pruning/python/pruning.py
@@ -396,14 +396,19 @@ class Pruning(object):
self._block_pooling_function)
with ops.name_scope(weights.op.name + '_pruning_ops'):
- abs_weights = math_ops.abs(
- array_ops.reshape(weights, [
- 1,
- squeezed_weights.get_shape()[0],
- squeezed_weights.get_shape()[1], 1
- ]))
+ abs_weights = math_ops.abs(squeezed_weights)
+
pool_window = [self._block_dim[0], self._block_dim[1]]
- pooled_weights = nn_ops.pool(
+ pool_fn = pruning_utils.factorized_pool
+
+ if not self._spec.use_tpu:
+ pool_fn = nn_ops.pool
+ abs_weights = array_ops.reshape(
+ abs_weights,
+ [1, abs_weights.get_shape()[0],
+ abs_weights.get_shape()[1], 1])
+
+ pooled_weights = pool_fn(
abs_weights,
window_shape=pool_window,
pooling_type=self._block_pooling_function,
@@ -411,19 +416,18 @@ class Pruning(object):
padding='SAME',
name=weights.op.name + '_pooled')
+ if pooled_weights.get_shape().ndims != 2:
+ pooled_weights = array_ops.squeeze(pooled_weights)
+
smoothed_threshold, new_mask = self._update_mask(pooled_weights,
threshold)
-
- reshaped_mask = array_ops.reshape(
- new_mask,
- [pooled_weights.get_shape()[1],
- pooled_weights.get_shape()[2]])
updated_mask = pruning_utils.kronecker_product(
- reshaped_mask, array_ops.ones(self._block_dim))
+ new_mask, array_ops.ones(self._block_dim))
sliced_mask = array_ops.slice(
updated_mask, [0, 0],
[squeezed_weights.get_shape()[0],
squeezed_weights.get_shape()[1]])
+
return smoothed_threshold, array_ops.reshape(sliced_mask,
array_ops.shape(weights))
diff --git a/tensorflow/contrib/model_pruning/python/pruning_utils.py b/tensorflow/contrib/model_pruning/python/pruning_utils.py
index 56d3dcef20..ef6c6a3f5d 100644
--- a/tensorflow/contrib/model_pruning/python/pruning_utils.py
+++ b/tensorflow/contrib/model_pruning/python/pruning_utils.py
@@ -29,6 +29,7 @@ from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
@@ -221,6 +222,56 @@ def compute_cdf(values, value_range, **kwargs):
return math_ops.div(cdf, math_ops.reduce_max(cdf))
+def factorized_pool(input_tensor,
+ window_shape,
+ pooling_type,
+ strides,
+ padding,
+ name=None):
+ """Performs m x n pooling through a combination of 1xm and 1xn pooling.
+
+ Args:
+ input_tensor: Input tensor. Must be rank 2
+ window_shape: Pooling window shape
+ pooling_type: Either 'MAX' or 'AVG'
+ strides: The stride of the pooling window
+ padding: 'SAME' or 'VALID'.
+ name: Name of the op
+
+ Returns:
+ A rank 2 tensor containing the pooled output
+
+ Raises:
+ ValueError: if the input tensor is not rank 2
+ """
+ if input_tensor.get_shape().ndims != 2:
+ raise ValueError('factorized_pool() accepts tensors of rank 2 only')
+
+ [height, width] = input_tensor.get_shape()
+ with ops.name_scope(name, 'factorized_pool'):
+ input_tensor_aligned = array_ops.reshape(
+ input_tensor, [1, 1, height, width],
+ name=input_tensor.op.name + '_aligned')
+
+ height_pooling = nn_ops.pool(
+ input_tensor_aligned,
+ window_shape=[1, window_shape[0]],
+ pooling_type=pooling_type,
+ strides=[1, strides[0]],
+ padding=padding)
+ swap_height_width = array_ops.transpose(height_pooling, perm=[0, 1, 3, 2])
+
+ width_pooling = nn_ops.pool(
+ swap_height_width,
+ window_shape=[1, window_shape[1]],
+ pooling_type=pooling_type,
+ strides=[1, strides[1]],
+ padding=padding)
+
+ return array_ops.squeeze(
+ array_ops.transpose(width_pooling, perm=[0, 1, 3, 2]))
+
+
def determine_partitioned_axis(partitioned_variable):
partitioned_axis = 0
concatenated_variable_shape = partitioned_variable.get_shape()
diff --git a/tensorflow/contrib/model_pruning/python/pruning_utils_test.py b/tensorflow/contrib/model_pruning/python/pruning_utils_test.py
index 10e1dd0a8e..ccde5b4e8a 100644
--- a/tensorflow/contrib/model_pruning/python/pruning_utils_test.py
+++ b/tensorflow/contrib/model_pruning/python/pruning_utils_test.py
@@ -22,8 +22,10 @@ import numpy as np
from tensorflow.contrib.model_pruning.python import pruning_utils
from tensorflow.python.framework import constant_op
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -31,6 +33,30 @@ from tensorflow.python.platform import test
class PruningUtilsTest(test.TestCase):
+ def _compare_cdf(self, values):
+ abs_values = math_ops.abs(values)
+ max_value = math_ops.reduce_max(abs_values)
+ with self.test_session():
+ variables.global_variables_initializer().run()
+ cdf_from_histogram = pruning_utils.compute_cdf_from_histogram(
+ abs_values, [0.0, max_value], nbins=pruning_utils._NBINS)
+ cdf = pruning_utils.compute_cdf(abs_values, [0.0, max_value])
+ self.assertAllEqual(cdf.eval(), cdf_from_histogram.eval())
+
+ def _compare_pooling_methods(self, weights, pooling_kwargs):
+ with self.test_session():
+ variables.global_variables_initializer().run()
+ pooled_weights_tf = array_ops.squeeze(
+ nn_ops.pool(
+ array_ops.reshape(
+ weights,
+ [1, weights.get_shape()[0],
+ weights.get_shape()[1], 1]), **pooling_kwargs))
+ pooled_weights_factorized_pool = pruning_utils.factorized_pool(
+ weights, **pooling_kwargs)
+ self.assertAllClose(pooled_weights_tf.eval(),
+ pooled_weights_factorized_pool.eval())
+
def testHistogram(self):
width = 10
height = 10
@@ -59,27 +85,35 @@ class PruningUtilsTest(test.TestCase):
self.assertAllEqual(len(norm_cdf_val), nbins)
self.assertAllEqual(expected_cdf, norm_cdf_val)
- def _compare_cdf(self, values):
- abs_values = math_ops.abs(values)
- max_value = math_ops.reduce_max(abs_values)
- with self.test_session():
- variables.global_variables_initializer().run()
- cdf_from_histogram = pruning_utils.compute_cdf_from_histogram(
- abs_values, [0.0, max_value], nbins=pruning_utils._NBINS)
- cdf = pruning_utils.compute_cdf(abs_values, [0.0, max_value])
- return cdf.eval(), cdf_from_histogram.eval()
-
def testCDFEquivalence2D(self):
width = 100
height = 100
weights = variable_scope.get_variable("weights", shape=[width, height])
- cdf_val, cdf_from_histogram_val = self._compare_cdf(weights)
- self.assertAllEqual(cdf_val, cdf_from_histogram_val)
+ self._compare_cdf(weights)
def testCDFEquivalence4D(self):
weights = variable_scope.get_variable("weights", shape=[5, 5, 128, 128])
- cdf_val, cdf_from_histogram_val = self._compare_cdf(weights)
- self.assertAllEqual(cdf_val, cdf_from_histogram_val)
+ self._compare_cdf(weights)
+
+ def testFactorizedAvgPool(self):
+ weights = variable_scope.get_variable("weights", shape=[1024, 2048])
+ pooling_kwargs = {
+ "window_shape": [2, 4],
+ "pooling_type": "AVG",
+ "strides": [2, 4],
+ "padding": "SAME"
+ }
+ self._compare_pooling_methods(weights, pooling_kwargs)
+
+ def testFactorizedMaxPool(self):
+ weights = variable_scope.get_variable("weights", shape=[1024, 2048])
+ pooling_kwargs = {
+ "window_shape": [2, 4],
+ "pooling_type": "MAX",
+ "strides": [2, 4],
+ "padding": "SAME"
+ }
+ self._compare_pooling_methods(weights, pooling_kwargs)
if __name__ == "__main__":