aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/layers/convolutional.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/layers/convolutional.py')
-rw-r--r--tensorflow/python/layers/convolutional.py274
1 files changed, 255 insertions, 19 deletions
diff --git a/tensorflow/python/layers/convolutional.py b/tensorflow/python/layers/convolutional.py
index b2fe9feb44..938161f426 100644
--- a/tensorflow/python/layers/convolutional.py
+++ b/tensorflow/python/layers/convolutional.py
@@ -972,7 +972,7 @@ def separable_conv2d(inputs,
class Conv2DTranspose(Conv2D):
- """Transposed convolution layer (sometimes called Deconvolution).
+ """Transposed 2D convolution layer (sometimes called 2D Deconvolution).
The need for transposed convolutions generally arises
from the desire to use a transformation going in the opposite direction
@@ -1086,19 +1086,9 @@ class Conv2DTranspose(Conv2D):
kernel_h, kernel_w = self.kernel_size
stride_h, stride_w = self.strides
- def get_deconv_dim(dim_size, stride_size, kernel_size, padding):
- if isinstance(dim_size, ops.Tensor):
- dim_size = math_ops.multiply(dim_size, stride_size)
- elif dim_size is not None:
- dim_size *= stride_size
-
- if padding == 'valid' and dim_size is not None:
- dim_size += max(kernel_size - stride_size, 0)
- return dim_size
-
# Infer the dynamic output shape:
- out_height = get_deconv_dim(height, stride_h, kernel_h, self.padding)
- out_width = get_deconv_dim(width, stride_w, kernel_w, self.padding)
+ out_height = utils.get_deconv_dim(height, stride_h, kernel_h, self.padding)
+ out_width = utils.get_deconv_dim(width, stride_w, kernel_w, self.padding)
if self.data_format == 'channels_first':
output_shape = (batch_size, self.filters, out_height, out_width)
@@ -1119,10 +1109,10 @@ class Conv2DTranspose(Conv2D):
# Infer the static output shape:
out_shape = inputs.get_shape().as_list()
out_shape[c_axis] = self.filters
- out_shape[h_axis] = get_deconv_dim(
- out_shape[h_axis], stride_h, kernel_h, self.padding)
- out_shape[w_axis] = get_deconv_dim(
- out_shape[w_axis], stride_w, kernel_w, self.padding)
+ out_shape[h_axis] = utils.get_deconv_dim(out_shape[h_axis], stride_h,
+ kernel_h, self.padding)
+ out_shape[w_axis] = utils.get_deconv_dim(out_shape[w_axis], stride_w,
+ kernel_w, self.padding)
outputs.set_shape(out_shape)
if self.bias:
@@ -1152,7 +1142,7 @@ def conv2d_transpose(inputs,
trainable=True,
name=None,
reuse=None):
- """Transposed convolution layer (sometimes called Deconvolution).
+ """Functional interface for transposed 2D convolution layer.
The need for transposed convolutions generally arises
from the desire to use a transformation going in the opposite direction
@@ -1177,6 +1167,250 @@ def conv2d_transpose(inputs,
`channels_last` corresponds to inputs with shape
`(batch, height, width, channels)` while `channels_first` corresponds to
inputs with shape `(batch, channels, height, width)`.
+ activation: Activation function. Set it to `None` to maintain a
+ linear activation.
+ use_bias: Boolean, whether the layer uses a bias.
+ kernel_initializer: An initializer for the convolution kernel.
+ bias_initializer: An initializer for the bias vector. If `None`, then no
+ bias will be applied.
+ kernel_regularizer: Optional regularizer for the convolution kernel.
+ bias_regularizer: Optional regularizer for the bias vector.
+ activity_regularizer: Regularizer function for the output.
+ trainable: Boolean, if `True` also add variables to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+ name: A string, the name of the layer.
+ reuse: Boolean, whether to reuse the weights of a previous layer
+ by the same name.
+
+ Returns:
+ Output tensor.
+ """
+ layer = Conv2DTranspose(
+ filters=filters,
+ kernel_size=kernel_size,
+ strides=strides,
+ padding=padding,
+ data_format=data_format,
+ activation=activation,
+ use_bias=use_bias,
+ kernel_initializer=kernel_initializer,
+ bias_initializer=bias_initializer,
+ kernel_regularizer=kernel_regularizer,
+ bias_regularizer=bias_regularizer,
+ activity_regularizer=activity_regularizer,
+ trainable=trainable,
+ name=name,
+ _reuse=reuse,
+ _scope=name)
+ return layer.apply(inputs)
+
+
+class Conv3DTranspose(Conv3D):
+ """Transposed 3D convolution layer (sometimes called 3D Deconvolution).
+
+ Arguments:
+ filters: Integer, the dimensionality of the output space (i.e. the number
+ of filters in the convolution).
+ kernel_size: An integer or tuple/list of 3 integers, specifying the
+ depth, height and width of the 3D convolution window.
+ Can be a single integer to specify the same value for all spatial
+ dimensions.
+ strides: An integer or tuple/list of 3 integers, specifying the strides
+ of the convolution along the depth, height and width.
+ Can be a single integer to specify the same value for all spatial
+ dimensions.
+ padding: One of `"valid"` or `"same"` (case-insensitive).
+ data_format: A string, one of `channels_last` (default) or `channels_first`.
+ The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape
+ `(batch, depth, height, width, channels)` while `channels_first`
+ corresponds to inputs with shape
+ `(batch, channels, depth, height, width)`.
+ activation: Activation function. Set it to `None` to maintain a
+ linear activation.
+ use_bias: Boolean, whether the layer uses a bias.
+ kernel_initializer: An initializer for the convolution kernel.
+ bias_initializer: An initializer for the bias vector. If `None`, then no
+ bias will be applied.
+ kernel_regularizer: Optional regularizer for the convolution kernel.
+ bias_regularizer: Optional regularizer for the bias vector.
+ activity_regularizer: Regularizer function for the output.
+ trainable: Boolean, if `True` also add variables to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+ name: A string, the name of the layer.
+ """
+
+ def __init__(self,
+ filters,
+ kernel_size,
+ strides=(1, 1, 1),
+ padding='valid',
+ data_format='channels_last',
+ activation=None,
+ use_bias=True,
+ kernel_initializer=None,
+ bias_initializer=init_ops.zeros_initializer(),
+ kernel_regularizer=None,
+ bias_regularizer=None,
+ activity_regularizer=None,
+ trainable=True,
+ name=None,
+ **kwargs):
+ super(Conv3DTranspose, self).__init__(
+ filters=filters,
+ kernel_size=kernel_size,
+ strides=strides,
+ padding=padding,
+ data_format=data_format,
+ activation=activation,
+ use_bias=use_bias,
+ kernel_initializer=kernel_initializer,
+ bias_initializer=bias_initializer,
+ kernel_regularizer=kernel_regularizer,
+ bias_regularizer=bias_regularizer,
+ activity_regularizer=activity_regularizer,
+ trainable=trainable,
+ name=name,
+ **kwargs)
+
+ def build(self, input_shape):
+ if len(input_shape) != 5:
+ raise ValueError('Inputs should have rank 5, received input shape:',
+ str(input_shape))
+ if self.data_format == 'channels_first':
+ channel_axis = 1
+ else:
+ channel_axis = -1
+ if input_shape[channel_axis] is None:
+ raise ValueError('The channel dimension of the inputs '
+ 'should be defined, found None: ' + str(input_shape))
+ input_dim = input_shape[channel_axis]
+ kernel_shape = self.kernel_size + (self.filters, input_dim)
+
+ self.kernel = self.add_variable(
+ 'kernel',
+ shape=kernel_shape,
+ initializer=self.kernel_initializer,
+ regularizer=self.kernel_regularizer,
+ trainable=True,
+ dtype=self.dtype)
+ if self.use_bias:
+ self.bias = self.add_variable(
+ 'bias',
+ shape=(self.filters,),
+ initializer=self.bias_initializer,
+ regularizer=self.bias_regularizer,
+ trainable=True,
+ dtype=self.dtype)
+ else:
+ self.bias = None
+
+ def call(self, inputs):
+ inputs_shape = array_ops.shape(inputs)
+ batch_size = inputs_shape[0]
+ if self.data_format == 'channels_first':
+ c_axis, d_axis, h_axis, w_axis = 1, 2, 3, 4
+ else:
+ c_axis, d_axis, h_axis, w_axis = 4, 1, 2, 3
+
+ depth = inputs_shape[d_axis]
+ height = inputs_shape[h_axis]
+ width = inputs_shape[w_axis]
+
+ kernel_d, kernel_h, kernel_w = self.kernel_size
+ stride_d, stride_h, stride_w = self.strides
+
+ # Infer the dynamic output shape:
+ out_depth = utils.get_deconv_dim(depth, stride_d, kernel_d, self.padding)
+ out_height = utils.get_deconv_dim(height, stride_h, kernel_h, self.padding)
+ out_width = utils.get_deconv_dim(width, stride_w, kernel_w, self.padding)
+
+ if self.data_format == 'channels_first':
+ output_shape = (batch_size, self.filters, out_depth, out_height,
+ out_width)
+ strides = (1, 1, stride_d, stride_h, stride_w)
+ else:
+ output_shape = (batch_size, out_depth, out_height, out_width,
+ self.filters)
+ strides = (1, stride_d, stride_h, stride_w, 1)
+
+ output_shape_tensor = array_ops.stack(output_shape)
+ outputs = nn.conv3d_transpose(
+ inputs,
+ self.kernel,
+ output_shape_tensor,
+ strides,
+ data_format=utils.convert_data_format(self.data_format, ndim=5),
+ padding=self.padding.upper())
+
+ # Infer the static output shape:
+ out_shape = inputs.get_shape().as_list()
+ out_shape[c_axis] = self.filters
+ out_shape[d_axis] = utils.get_deconv_dim(out_shape[d_axis], stride_d,
+ kernel_d, self.padding)
+ out_shape[h_axis] = utils.get_deconv_dim(out_shape[h_axis], stride_h,
+ kernel_h, self.padding)
+ out_shape[w_axis] = utils.get_deconv_dim(out_shape[w_axis], stride_w,
+ kernel_w, self.padding)
+ outputs.set_shape(out_shape)
+
+ if self.bias:
+ outputs_shape = outputs.shape.as_list()
+ if self.data_format == 'channels_first':
+ outputs_4d = array_ops.reshape(outputs, [
+ outputs_shape[0], outputs_shape[1],
+ outputs_shape[2] * outputs_shape[3], outputs_shape[4]
+ ])
+ else:
+ outputs_4d = array_ops.reshape(outputs, [
+ outputs_shape[0], outputs_shape[1] * outputs_shape[2],
+ outputs_shape[3], outputs_shape[4]
+ ])
+ outputs_4d = nn.bias_add(
+ outputs_4d,
+ self.bias,
+ data_format=utils.convert_data_format(self.data_format, ndim=4))
+ outputs = array_ops.reshape(outputs_4d, outputs_shape)
+
+ if self.activation is not None:
+ return self.activation(outputs)
+ return outputs
+
+
+def conv3d_transpose(inputs,
+ filters,
+ kernel_size,
+ strides=(1, 1, 1),
+ padding='valid',
+ data_format='channels_last',
+ activation=None,
+ use_bias=True,
+ kernel_initializer=None,
+ bias_initializer=init_ops.zeros_initializer(),
+ kernel_regularizer=None,
+ bias_regularizer=None,
+ activity_regularizer=None,
+ trainable=True,
+ name=None,
+ reuse=None):
+ """Functional interface for transposed 3D convolution layer.
+
+ Arguments:
+ inputs: Input tensor.
+ filters: Integer, the dimensionality of the output space (i.e. the number
+ of filters in the convolution).
+ kernel_size: A tuple or list of 3 positive integers specifying the spatial
+ dimensions of of the filters. Can be a single integer to specify the same
+ value for all spatial dimensions.
+ strides: A tuple or list of 3 positive integers specifying the strides
+ of the convolution. Can be a single integer to specify the same value for
+ all spatial dimensions.
+ padding: one of `"valid"` or `"same"` (case-insensitive).
+ data_format: A string, one of `channels_last` (default) or `channels_first`.
+ The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape
+ `(batch, height, width, channels)` while `channels_first` corresponds to
+ inputs with shape `(batch, channels, height, width)`.
activation: Activation function. Set it to None to maintain a
linear activation.
use_bias: Boolean, whether the layer uses a bias.
@@ -1195,7 +1429,7 @@ def conv2d_transpose(inputs,
Returns:
Output tensor.
"""
- layer = Conv2DTranspose(
+ layer = Conv3DTranspose(
filters=filters,
kernel_size=kernel_size,
strides=strides,
@@ -1222,8 +1456,10 @@ Convolution2D = Conv2D
Convolution3D = Conv3D
SeparableConvolution2D = SeparableConv2D
Convolution2DTranspose = Deconvolution2D = Deconv2D = Conv2DTranspose
+Convolution3DTranspose = Deconvolution3D = Deconv3D = Conv3DTranspose
convolution1d = conv1d
convolution2d = conv2d
convolution3d = conv3d
separable_convolution2d = separable_conv2d
convolution2d_transpose = deconvolution2d = deconv2d = conv2d_transpose
+convolution3d_transpose = deconvolution3d = deconv3d = conv3d_transpose