diff options
Diffstat (limited to 'tensorflow/python/keras/utils/conv_utils.py')
-rw-r--r-- | tensorflow/python/keras/utils/conv_utils.py | 166 |
1 files changed, 166 insertions, 0 deletions
# Convolution-connectivity helpers added to
# tensorflow/python/keras/utils/conv_utils.py.
#
# NOTE(review): `conv_output_length` is called below but is not defined in this
# excerpt; it is presumably defined earlier in the same module -- confirm.
import itertools

import numpy as np


def conv_kernel_mask(input_shape, kernel_shape, strides, padding):
  """Compute a mask representing the connectivity of a convolution operation.

  Assume a convolution with given parameters is applied to an input having N
  spatial dimensions with `input_shape = (d_in1, ..., d_inN)` to produce an
  output with shape `(d_out1, ..., d_outN)`. This method returns a boolean array
  of shape `(d_in1, ..., d_inN, d_out1, ..., d_outN)` with `True` entries
  indicating pairs of input and output locations that are connected by a weight.

  Example:
    ```python
        >>> input_shape = (4,)
        >>> kernel_shape = (2,)
        >>> strides = (1,)
        >>> padding = "valid"
        >>> conv_kernel_mask(input_shape, kernel_shape, strides, padding)
        array([[ True, False, False],
               [ True,  True, False],
               [False,  True,  True],
               [False, False,  True]])
    ```
    where rows and columns correspond to inputs and outputs respectively.

  Args:
    input_shape: tuple of size N: `(d_in1, ..., d_inN)`,
                 spatial shape of the input.
    kernel_shape: tuple of size N, spatial shape of the convolutional kernel
                  / receptive field. A single int is broadcast to all N
                  dimensions.
    strides: tuple of size N, strides along each spatial dimension. A single
             int is broadcast to all N dimensions.
    padding: type of padding, string `"same"` or `"valid"`.

  Returns:
    A boolean 2N-D `np.ndarray` of shape
    `(d_in1, ..., d_inN, d_out1, ..., d_outN)`, where `(d_out1, ..., d_outN)`
    is the spatial shape of the output. `True` entries in the mask represent
    pairs of input-output locations that are connected by a weight.

  Raises:
    ValueError: if `input_shape`, `kernel_shape` and `strides` don't have the
        same number of dimensions.
    NotImplementedError: if `padding` is not in {`"same"`, `"valid"`}.
  """
  if padding not in {'same', 'valid'}:
    raise NotImplementedError('Padding type %s not supported. '
                              'Only "valid" and "same" '
                              'are implemented.' % padding)

  in_dims = len(input_shape)
  if isinstance(kernel_shape, int):
    kernel_shape = (kernel_shape,) * in_dims
  if isinstance(strides, int):
    strides = (strides,) * in_dims

  kernel_dims = len(kernel_shape)
  stride_dims = len(strides)
  if kernel_dims != in_dims or stride_dims != in_dims:
    raise ValueError('Number of strides, input and kernel dimensions must all '
                     'match. Received: %d, %d, %d.' %
                     (stride_dims, in_dims, kernel_dims))

  output_shape = conv_output_shape(input_shape, kernel_shape, strides, padding)

  # Coerce to tuples so that list-valued shapes concatenate correctly.
  mask_shape = tuple(input_shape) + tuple(output_shape)
  # Use the builtin `bool` dtype: the `np.bool` alias was deprecated in
  # NumPy 1.20 and removed in NumPy 1.24.
  mask = np.zeros(mask_shape, dtype=bool)

  output_axes_ticks = [range(dim) for dim in output_shape]
  for output_position in itertools.product(*output_axes_ticks):
    input_axes_ticks = conv_connected_inputs(input_shape,
                                             kernel_shape,
                                             output_position,
                                             strides,
                                             padding)
    # Mark every input location inside the receptive field of this output.
    for input_position in itertools.product(*input_axes_ticks):
      mask[input_position + output_position] = True

  return mask


def conv_connected_inputs(input_shape,
                          kernel_shape,
                          output_position,
                          strides,
                          padding):
  """Return locations of the input connected to an output position.

  Assume a convolution with given parameters is applied to an input having N
  spatial dimensions with `input_shape = (d_in1, ..., d_inN)`. This method
  returns N ranges specifying the input region that was convolved with the
  kernel to produce the output at position
  `output_position = (p_out1, ..., p_outN)`.

  Example:
    ```python
        >>> input_shape = (4, 4)
        >>> kernel_shape = (2, 1)
        >>> output_position = (1, 1)
        >>> strides = (1, 1)
        >>> padding = "valid"
        >>> conv_connected_inputs(input_shape, kernel_shape, output_position,
        ...                       strides, padding)
        [range(1, 3), range(1, 2)]
    ```

  Args:
    input_shape: tuple of size N: `(d_in1, ..., d_inN)`,
                 spatial shape of the input.
    kernel_shape: tuple of size N, spatial shape of the convolutional kernel
                  / receptive field.
    output_position: tuple of size N: `(p_out1, ..., p_outN)`,
                     a single position in the output of the convolution.
    strides: tuple of size N, strides along each spatial dimension.
    padding: type of padding, string `"same"` or `"valid"`.

  Returns:
    N ranges `[[p_in_left1, ..., p_in_right1], ...,
              [p_in_leftN, ..., p_in_rightN]]` specifying the region in the
    input connected to output_position.
  """
  ranges = []

  for d in range(len(input_shape)):
    # Split the kernel extent around its center; integer floor division is
    # exact and does not depend on `from __future__ import division`.
    left_shift = kernel_shape[d] // 2
    right_shift = kernel_shape[d] - left_shift

    center = output_position[d] * strides[d]

    # With "valid" padding the first kernel window starts at input index 0,
    # so its center sits `left_shift` positions in.
    if padding == 'valid':
      center += left_shift

    # Clip the receptive field to the bounds of the input.
    start = max(0, center - left_shift)
    end = min(input_shape[d], center + right_shift)

    ranges.append(range(start, end))

  return ranges


def conv_output_shape(input_shape, kernel_shape, strides, padding):
  """Return the output shape of an N-D convolution.

  Forces dimensions where input is empty (size 0) to remain empty.

  Args:
    input_shape: tuple of size N: `(d_in1, ..., d_inN)`,
                 spatial shape of the input.
    kernel_shape: tuple of size N, spatial shape of the convolutional kernel
                  / receptive field.
    strides: tuple of size N, strides along each spatial dimension.
    padding: type of padding, string `"same"` or `"valid"`.

  Returns:
    tuple of size N: `(d_out1, ..., d_outN)`, spatial shape of the output.
  """
  dims = range(len(kernel_shape))
  # Compute each output length first, in order, then force dimensions whose
  # input is empty to stay empty.
  output_shape = [conv_output_length(input_shape[d],
                                     kernel_shape[d],
                                     padding,
                                     strides[d])
                  for d in dims]
  return tuple(0 if input_shape[d] == 0 else output_shape[d] for d in dims)