diff options
4 files changed, 114 insertions, 23 deletions
diff --git a/tensorflow/contrib/model_pruning/BUILD b/tensorflow/contrib/model_pruning/BUILD index 16ddc38f5a..e662b11be8 100644 --- a/tensorflow/contrib/model_pruning/BUILD +++ b/tensorflow/contrib/model_pruning/BUILD @@ -119,6 +119,7 @@ py_test( deps = [ ":pruning_utils", "//tensorflow/python:client_testlib", + "@absl_py//absl/testing:parameterized", ], ) diff --git a/tensorflow/contrib/model_pruning/python/pruning.py b/tensorflow/contrib/model_pruning/python/pruning.py index cd58526ed3..a81abac2fa 100644 --- a/tensorflow/contrib/model_pruning/python/pruning.py +++ b/tensorflow/contrib/model_pruning/python/pruning.py @@ -476,8 +476,8 @@ class Pruning(object): smoothed_threshold, new_mask = self._update_mask(pooled_weights, threshold) - updated_mask = pruning_utils.kronecker_product( - new_mask, array_ops.ones(self._block_dim)) + + updated_mask = pruning_utils.expand_tensor(new_mask, self._block_dim) sliced_mask = array_ops.slice( updated_mask, [0, 0], [squeezed_weights.get_shape()[0], diff --git a/tensorflow/contrib/model_pruning/python/pruning_utils.py b/tensorflow/contrib/model_pruning/python/pruning_utils.py index ef6c6a3f5d..b50a372e9d 100644 --- a/tensorflow/contrib/model_pruning/python/pruning_utils.py +++ b/tensorflow/contrib/model_pruning/python/pruning_utils.py @@ -69,7 +69,7 @@ def weight_threshold_variable(var, scope): scope: The variable scope of the variable var Returns: - a scalar threshold variable initialized to 0. + A scalar threshold variable initialized to 0. """ with variable_scope.variable_scope(scope): threshold = variable_scope.get_variable( @@ -97,6 +97,74 @@ def kronecker_product(mat1, mat2): return array_ops.reshape(mat1_rsh * mat2_rsh, [m1 * m2, n1 * n2]) +def expand_tensor(tensor, block_dims): + """Expands a 2D tensor by replicating the tensor values. + + This is equivalent to the kronecker product of the tensor and a matrix of + ones of size block_dims. 
+ + Example: + + tensor = [[1,2] + [3,4]] + block_dims = [2,2] + + result = [[1 1 2 2] + [1 1 2 2] + [3 3 4 4] + [3 3 4 4]] + + Args: + tensor: A 2D tensor that needs to be expanded. + block_dims: List of integers specifying the expansion factor. + + Returns: + The expanded tensor + + Raises: + ValueError: if tensor is not rank-2 or block_dims does not have 2 + elements. + """ + if tensor.get_shape().ndims != 2: + raise ValueError('Input tensor must be rank 2') + + if len(block_dims) != 2: + raise ValueError('block_dims must have 2 elements') + + block_height, block_width = block_dims + + def _tile_rows(tensor, multiple): + """Create a new tensor by tiling the tensor along rows.""" + return array_ops.tile(tensor, [multiple, 1]) + + def _generate_indices(num_rows, block_dim): + indices = np.zeros(shape=[num_rows * block_dim, 1], dtype=np.int32) + for k in range(block_dim): + for r in range(num_rows): + indices[k * num_rows + r] = r * block_dim + k + return indices + + def _replicate_rows(tensor, multiple): + tensor_shape = tensor.shape.as_list() + expanded_shape = [tensor_shape[0] * multiple, tensor_shape[1]] + indices = constant_op.constant(_generate_indices(tensor_shape[0], multiple)) + return array_ops.scatter_nd(indices, _tile_rows(tensor, multiple), + expanded_shape) + + expanded_tensor = tensor + + # Expand rows by factor block_height. + if block_height > 1: + expanded_tensor = _replicate_rows(tensor, block_height) + + # Transpose and expand by factor block_width. Transpose the result. + if block_width > 1: + expanded_tensor = array_ops.transpose( + _replicate_rows(array_ops.transpose(expanded_tensor), block_width)) + + return expanded_tensor + + def _histogram(values, value_range, nbins=100, dtype=dtypes.int32, name=None): """Return histogram of values. 
diff --git a/tensorflow/contrib/model_pruning/python/pruning_utils_test.py b/tensorflow/contrib/model_pruning/python/pruning_utils_test.py index ccde5b4e8a..06d7f97437 100644 --- a/tensorflow/contrib/model_pruning/python/pruning_utils_test.py +++ b/tensorflow/contrib/model_pruning/python/pruning_utils_test.py @@ -18,6 +18,7 @@ from __future__ import absolute_import from __future__ import division from __future__ import print_function +from absl.testing import parameterized import numpy as np from tensorflow.contrib.model_pruning.python import pruning_utils @@ -26,6 +27,7 @@ from tensorflow.python.ops import array_ops from tensorflow.python.ops import init_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import nn_ops +from tensorflow.python.ops import random_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables from tensorflow.python.platform import test @@ -43,20 +45,6 @@ class PruningUtilsTest(test.TestCase): cdf = pruning_utils.compute_cdf(abs_values, [0.0, max_value]) self.assertAllEqual(cdf.eval(), cdf_from_histogram.eval()) - def _compare_pooling_methods(self, weights, pooling_kwargs): - with self.test_session(): - variables.global_variables_initializer().run() - pooled_weights_tf = array_ops.squeeze( - nn_ops.pool( - array_ops.reshape( - weights, - [1, weights.get_shape()[0], - weights.get_shape()[1], 1]), **pooling_kwargs)) - pooled_weights_factorized_pool = pruning_utils.factorized_pool( - weights, **pooling_kwargs) - self.assertAllClose(pooled_weights_tf.eval(), - pooled_weights_factorized_pool.eval()) - def testHistogram(self): width = 10 height = 10 @@ -95,26 +83,60 @@ class PruningUtilsTest(test.TestCase): weights = variable_scope.get_variable("weights", shape=[5, 5, 128, 128]) self._compare_cdf(weights) - def testFactorizedAvgPool(self): + +@parameterized.named_parameters( + ("1x1", [1, 1]), ("4x4", [4, 4]), ("6x6", [6, 6]), ("1x4", [1, 4]), + ("4x1", [4, 1]), ("1x8", [1, 8]), ("8x1", 
[8, 1])) +class PruningUtilsParameterizedTest(test.TestCase, parameterized.TestCase): + + def _compare_pooling_methods(self, weights, pooling_kwargs): + with self.test_session(): + variables.global_variables_initializer().run() + pooled_weights_tf = array_ops.squeeze( + nn_ops.pool( + array_ops.reshape( + weights, + [1, weights.get_shape()[0], + weights.get_shape()[1], 1]), **pooling_kwargs)) + pooled_weights_factorized_pool = pruning_utils.factorized_pool( + weights, **pooling_kwargs) + self.assertAllClose(pooled_weights_tf.eval(), + pooled_weights_factorized_pool.eval()) + + def _compare_expand_tensor_with_kronecker_product(self, tensor, block_dim): + with self.test_session() as session: + variables.global_variables_initializer().run() + expanded_tensor = pruning_utils.expand_tensor(tensor, block_dim) + kronecker_product = pruning_utils.kronecker_product( + tensor, array_ops.ones(block_dim)) + expanded_tensor_val, kronecker_product_val = session.run( + [expanded_tensor, kronecker_product]) + self.assertAllEqual(expanded_tensor_val, kronecker_product_val) + + def testFactorizedAvgPool(self, window_shape): weights = variable_scope.get_variable("weights", shape=[1024, 2048]) pooling_kwargs = { - "window_shape": [2, 4], + "window_shape": window_shape, "pooling_type": "AVG", - "strides": [2, 4], + "strides": window_shape, "padding": "SAME" } self._compare_pooling_methods(weights, pooling_kwargs) - def testFactorizedMaxPool(self): + def testFactorizedMaxPool(self, window_shape): weights = variable_scope.get_variable("weights", shape=[1024, 2048]) pooling_kwargs = { - "window_shape": [2, 4], + "window_shape": window_shape, "pooling_type": "MAX", - "strides": [2, 4], + "strides": window_shape, "padding": "SAME" } self._compare_pooling_methods(weights, pooling_kwargs) + def testExpandTensor(self, block_dim): + weights = random_ops.random_normal(shape=[1024, 512]) + self._compare_expand_tensor_with_kronecker_product(weights, block_dim) + if __name__ == "__main__": 
test.main() |