aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/utils/conv_utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/keras/utils/conv_utils.py')
-rw-r--r--tensorflow/python/keras/utils/conv_utils.py166
1 files changed, 166 insertions, 0 deletions
diff --git a/tensorflow/python/keras/utils/conv_utils.py b/tensorflow/python/keras/utils/conv_utils.py
index 5419e7ae05..3a176c3316 100644
--- a/tensorflow/python/keras/utils/conv_utils.py
+++ b/tensorflow/python/keras/utils/conv_utils.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import itertools
import numpy as np
from six.moves import range # pylint: disable=redefined-builtin
@@ -199,3 +200,168 @@ def convert_kernel(kernel):
no_flip = (slice(None, None), slice(None, None))
slices[-2:] = no_flip
return np.copy(kernel[slices])
+
+
def conv_kernel_mask(input_shape, kernel_shape, strides, padding):
  """Compute a mask representing the connectivity of a convolution operation.

  Assume a convolution with given parameters is applied to an input having N
  spatial dimensions with `input_shape = (d_in1, ..., d_inN)` to produce an
  output with shape `(d_out1, ..., d_outN)`. This method returns a boolean array
  of shape `(d_in1, ..., d_inN, d_out1, ..., d_outN)` with `True` entries
  indicating pairs of input and output locations that are connected by a weight.

  Example:
  ```python
      >>> input_shape = (4,)
      >>> kernel_shape = (2,)
      >>> strides = (1,)
      >>> padding = "valid"
      >>> conv_kernel_mask(input_shape, kernel_shape, strides, padding)
      array([[ True, False, False],
             [ True,  True, False],
             [False,  True,  True],
             [False, False,  True]])
  ```
  where rows and columns correspond to inputs and outputs respectively.

  Args:
    input_shape: sequence of size N: `(d_in1, ..., d_inN)`,
                 spatial shape of the input.
    kernel_shape: tuple of size N, spatial shape of the convolutional kernel
                  / receptive field. A single `int` is broadcast to all N
                  dimensions.
    strides: tuple of size N, strides along each spatial dimension. A single
             `int` is broadcast to all N dimensions.
    padding: type of padding, string `"same"` or `"valid"`.

  Returns:
    A boolean 2N-D `np.ndarray` of shape
    `(d_in1, ..., d_inN, d_out1, ..., d_outN)`, where `(d_out1, ..., d_outN)`
    is the spatial shape of the output. `True` entries in the mask represent
    pairs of input-output locations that are connected by a weight.

  Raises:
    ValueError: if `input_shape`, `kernel_shape` and `strides` don't have the
        same number of dimensions.
    NotImplementedError: if `padding` is not in {`"same"`, `"valid"`}.
  """
  if padding not in {'same', 'valid'}:
    # Wrap in a 1-tuple so that a tuple-valued `padding` cannot break the
    # %-formatting of the error message itself.
    raise NotImplementedError('Padding type %s not supported. '
                              'Only "valid" and "same" '
                              'are implemented.' % (padding,))

  in_dims = len(input_shape)
  # Scalars are shorthand for "the same value along every spatial dimension".
  if isinstance(kernel_shape, int):
    kernel_shape = (kernel_shape,) * in_dims
  if isinstance(strides, int):
    strides = (strides,) * in_dims

  kernel_dims = len(kernel_shape)
  stride_dims = len(strides)
  if kernel_dims != in_dims or stride_dims != in_dims:
    raise ValueError('Number of strides, input and kernel dimensions must all '
                     'match. Received: %d, %d, %d.' %
                     (stride_dims, in_dims, kernel_dims))

  output_shape = conv_output_shape(input_shape, kernel_shape, strides, padding)

  # Coerce to tuples so that list-valued shapes concatenate cleanly.
  mask_shape = tuple(input_shape) + tuple(output_shape)
  # Use the builtin `bool`: the `np.bool` alias was deprecated in NumPy 1.20
  # and removed in NumPy 1.24.
  mask = np.zeros(mask_shape, bool)

  # Visit every output location and mark the input locations in its
  # receptive field.
  output_axes_ticks = [range(dim) for dim in output_shape]
  for output_position in itertools.product(*output_axes_ticks):
    input_axes_ticks = conv_connected_inputs(input_shape,
                                             kernel_shape,
                                             output_position,
                                             strides,
                                             padding)
    for input_position in itertools.product(*input_axes_ticks):
      # Both positions are tuples; their concatenation is a full 2N-D index.
      mask[input_position + output_position] = True

  return mask
+
+
def conv_connected_inputs(input_shape,
                          kernel_shape,
                          output_position,
                          strides,
                          padding):
  """Return locations of the input connected to an output position.

  Assume a convolution with given parameters is applied to an input having N
  spatial dimensions with `input_shape = (d_in1, ..., d_inN)`. This method
  returns N ranges specifying the input region that was convolved with the
  kernel to produce the output at position
  `output_position = (p_out1, ..., p_outN)`.

  Example:
  ```python
      >>> input_shape = (4, 4)
      >>> kernel_shape = (2, 1)
      >>> output_position = (1, 1)
      >>> strides = (1, 1)
      >>> padding = "valid"
      >>> conv_connected_inputs(input_shape, kernel_shape, output_position,
      ...                       strides, padding)
      [range(1, 3), range(1, 2)]
  ```

  Args:
    input_shape: tuple of size N: `(d_in1, ..., d_inN)`,
                 spatial shape of the input.
    kernel_shape: tuple of size N, spatial shape of the convolutional kernel
                  / receptive field.
    output_position: tuple of size N: `(p_out1, ..., p_outN)`,
                     a single position in the output of the convolution.
    strides: tuple of size N, strides along each spatial dimension.
    padding: type of padding, string `"same"` or `"valid"`. Any value other
             than `"valid"` is handled like `"same"` by this function.

  Returns:
    A list of N ranges `[[p_in_left1, ..., p_in_right1], ...,
                         [p_in_leftN, ..., p_in_rightN]]` specifying the region
    in the input connected to output_position. Each range is clipped to
    `[0, d_in)` along its dimension.
  """
  ranges = []

  ndims = len(input_shape)
  for d in range(ndims):
    # Integer floor division: exact for ints, and avoids the float round-trip
    # that `int(k / 2)` incurs under true division.
    left_shift = kernel_shape[d] // 2
    right_shift = kernel_shape[d] - left_shift

    center = output_position[d] * strides[d]

    # With "valid" padding the window starts at `output_position * stride`,
    # so its center sits `left_shift` further in.
    if padding == 'valid':
      center += left_shift

    # Clip the window to the input bounds; positions that fall in the padded
    # border have no corresponding input location.
    start = max(0, center - left_shift)
    end = min(input_shape[d], center + right_shift)

    ranges.append(range(start, end))

  return ranges
+
+
def conv_output_shape(input_shape, kernel_shape, strides, padding):
  """Compute the spatial output shape of an N-D convolution.

  Every dimension is sized with `conv_output_length`; a dimension whose input
  size is 0 is then forced to stay 0 in the result.

  Args:
    input_shape: tuple of size N: `(d_in1, ..., d_inN)`,
                 spatial shape of the input.
    kernel_shape: tuple of size N, spatial shape of the convolutional kernel
                  / receptive field.
    strides: tuple of size N, strides along each spatial dimension.
    padding: type of padding, string `"same"` or `"valid"`.

  Returns:
    tuple of size N: `(d_out1, ..., d_outN)`, spatial shape of the output.
  """
  sizes = []
  # The kernel's rank determines how many dimensions are processed.
  for axis in range(len(kernel_shape)):
    in_size = input_shape[axis]
    # Always evaluate the length helper, then override it for empty inputs.
    out_size = conv_output_length(in_size,
                                  kernel_shape[axis],
                                  padding,
                                  strides[axis])
    if in_size == 0:
      sizes.append(0)
    else:
      sizes.append(out_size)
  return tuple(sizes)