diff options
Diffstat (limited to 'tensorflow/contrib/model_pruning/python/pruning_utils.py')
-rw-r--r-- | tensorflow/contrib/model_pruning/python/pruning_utils.py | 70 |
1 files changed, 69 insertions, 1 deletions
diff --git a/tensorflow/contrib/model_pruning/python/pruning_utils.py b/tensorflow/contrib/model_pruning/python/pruning_utils.py index ef6c6a3f5d..b50a372e9d 100644 --- a/tensorflow/contrib/model_pruning/python/pruning_utils.py +++ b/tensorflow/contrib/model_pruning/python/pruning_utils.py @@ -69,7 +69,7 @@ def weight_threshold_variable(var, scope): scope: The variable scope of the variable var Returns: - a scalar threshold variable initialized to 0. + A scalar threshold variable initialized to 0. """ with variable_scope.variable_scope(scope): threshold = variable_scope.get_variable( @@ -97,6 +97,74 @@ def kronecker_product(mat1, mat2): return array_ops.reshape(mat1_rsh * mat2_rsh, [m1 * m2, n1 * n2]) +def expand_tensor(tensor, block_dims): + """Expands a 2D tensor by replicating the tensor values. + + This is equivalent to the kronecker product of the tensor and a matrix of + ones of size block_dims. + + Example: + + tensor = [[1,2] + [3,4]] + block_dims = [2,2] + + result = [[1 1 2 2] + [1 1 2 2] + [3 3 4 4] + [3 3 4 4]] + + Args: + tensor: A 2D tensor that needs to be expanded. + block_dims: List of integers specifying the expansion factor. + + Returns: + The expanded tensor + + Raises: + ValueError: if tensor is not rank-2 or block_dims is does not have 2 + elements. + """ + if tensor.get_shape().ndims != 2: + raise ValueError('Input tensor must be rank 2') + + if len(block_dims) != 2: + raise ValueError('block_dims must have 2 elements') + + block_height, block_width = block_dims + + def _tile_rows(tensor, multiple): + """Create a new tensor by tiling the tensor along rows.""" + return array_ops.tile(tensor, [multiple, 1]) + + def _generate_indices(num_rows, block_dim): + indices = np.zeros(shape=[num_rows * block_dim, 1], dtype=np.int32) + for k in range(block_dim): + for r in range(num_rows): + indices[k * num_rows + r] = r * block_dim + k + return indices + + def _replicate_rows(tensor, multiple): + tensor_shape = tensor.shape.as_list() + expanded_shape = [tensor_shape[0] * multiple, tensor_shape[1]] + indices = constant_op.constant(_generate_indices(tensor_shape[0], multiple)) + return array_ops.scatter_nd(indices, _tile_rows(tensor, multiple), + expanded_shape) + + expanded_tensor = tensor + + # Expand rows by factor block_height. + if block_height > 1: + expanded_tensor = _replicate_rows(tensor, block_height) + + # Transpose and expand by factor block_width. Transpose the result. + if block_width > 1: + expanded_tensor = array_ops.transpose( + _replicate_rows(array_ops.transpose(expanded_tensor), block_width)) + + return expanded_tensor + + def _histogram(values, value_range, nbins=100, dtype=dtypes.int32, name=None): """Return histogram of values. |