diff options
author | Francois Chollet <fchollet@google.com> | 2018-10-08 10:43:01 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-08 10:49:56 -0700 |
commit | 8ef3e7c8c053cb6dad530e13c478bbd406ea2c95 (patch) | |
tree | 74f36c8bd9293854ce0ee1f8a9bac04a863bfe99 /tensorflow/python/keras/layers/convolutional.py | |
parent | 153decedefc8da1fbd0717f4223b4b053e7aa517 (diff) |
Part 1/3 of the feature sync to the Keras 2.2.4 API.
PiperOrigin-RevId: 216211279
Diffstat (limited to 'tensorflow/python/keras/layers/convolutional.py')
-rw-r--r-- | tensorflow/python/keras/layers/convolutional.py | 177 |
1 file changed, 126 insertions, 51 deletions
diff --git a/tensorflow/python/keras/layers/convolutional.py b/tensorflow/python/keras/layers/convolutional.py index d00def07bb..8f5872385c 100644 --- a/tensorflow/python/keras/layers/convolutional.py +++ b/tensorflow/python/keras/layers/convolutional.py @@ -645,6 +645,14 @@ class Conv2DTranspose(Conv2D): Specifying any stride value != 1 is incompatible with specifying any `dilation_rate` value != 1. padding: one of `"valid"` or `"same"` (case-insensitive). + output_padding: An integer or tuple/list of 2 integers, + specifying the amount of padding along the height and width + of the output tensor. + Can be a single integer to specify the same value for all + spatial dimensions. + The amount of output padding along a given dimension must be + lower than the stride along that same dimension. + If set to `None` (default), the output shape is inferred. data_format: A string, one of `channels_last` (default) or `channels_first`. The ordering of the dimensions in the inputs. @@ -700,7 +708,9 @@ class Conv2DTranspose(Conv2D): kernel_size, strides=(1, 1), padding='valid', + output_padding=None, data_format=None, + dilation_rate=(1, 1), activation=None, use_bias=True, kernel_initializer='glorot_uniform', @@ -717,6 +727,7 @@ class Conv2DTranspose(Conv2D): strides=strides, padding=padding, data_format=data_format, + dilation_rate=dilation_rate, activation=activations.get(activation), use_bias=use_bias, kernel_initializer=initializers.get(kernel_initializer), @@ -728,6 +739,16 @@ class Conv2DTranspose(Conv2D): bias_constraint=constraints.get(bias_constraint), **kwargs) + self.output_padding = output_padding + if self.output_padding is not None: + self.output_padding = conv_utils.normalize_tuple( + self.output_padding, 2, 'output_padding') + for stride, out_pad in zip(self.strides, self.output_padding): + if out_pad >= stride: + raise ValueError('Stride ' + str(self.strides) + ' must be ' + 'greater than output padding ' + + str(self.output_padding)) + def build(self, 
input_shape): input_shape = tensor_shape.TensorShape(input_shape) if len(input_shape) != 4: @@ -769,51 +790,50 @@ class Conv2DTranspose(Conv2D): inputs_shape = array_ops.shape(inputs) batch_size = inputs_shape[0] if self.data_format == 'channels_first': - c_axis, h_axis, w_axis = 1, 2, 3 + h_axis, w_axis = 2, 3 else: - c_axis, h_axis, w_axis = 3, 1, 2 + h_axis, w_axis = 1, 2 height, width = inputs_shape[h_axis], inputs_shape[w_axis] kernel_h, kernel_w = self.kernel_size stride_h, stride_w = self.strides + if self.output_padding is None: + out_pad_h = out_pad_w = None + else: + out_pad_h, out_pad_w = self.output_padding + # Infer the dynamic output shape: out_height = conv_utils.deconv_output_length(height, kernel_h, - self.padding, - stride_h) + padding=self.padding, + output_padding=out_pad_h, + stride=stride_h, + dilation=self.dilation_rate[0]) out_width = conv_utils.deconv_output_length(width, kernel_w, - self.padding, - stride_w) + padding=self.padding, + output_padding=out_pad_w, + stride=stride_w, + dilation=self.dilation_rate[1]) if self.data_format == 'channels_first': output_shape = (batch_size, self.filters, out_height, out_width) - strides = (1, 1, stride_h, stride_w) else: output_shape = (batch_size, out_height, out_width, self.filters) - strides = (1, stride_h, stride_w, 1) output_shape_tensor = array_ops.stack(output_shape) - outputs = nn.conv2d_transpose( + outputs = backend.conv2d_transpose( inputs, self.kernel, output_shape_tensor, - strides, - padding=self.padding.upper(), - data_format=conv_utils.convert_data_format(self.data_format, ndim=4)) + strides=self.strides, + padding=self.padding, + data_format=self.data_format, + dilation_rate=self.dilation_rate) if not context.executing_eagerly(): # Infer the static output shape: - out_shape = inputs.get_shape().as_list() - out_shape[c_axis] = self.filters - out_shape[h_axis] = conv_utils.deconv_output_length(out_shape[h_axis], - kernel_h, - self.padding, - stride_h) - out_shape[w_axis] = 
conv_utils.deconv_output_length(out_shape[w_axis], - kernel_w, - self.padding, - stride_w) + out_shape = self.compute_output_shape(inputs.shape) outputs.set_shape(out_shape) if self.use_bias: @@ -837,13 +857,33 @@ class Conv2DTranspose(Conv2D): kernel_h, kernel_w = self.kernel_size stride_h, stride_w = self.strides + if self.output_padding is None: + out_pad_h = out_pad_w = None + else: + out_pad_h, out_pad_w = self.output_padding + output_shape[c_axis] = self.filters output_shape[h_axis] = conv_utils.deconv_output_length( - output_shape[h_axis], kernel_h, self.padding, stride_h) + output_shape[h_axis], + kernel_h, + padding=self.padding, + output_padding=out_pad_h, + stride=stride_h, + dilation=self.dilation_rate[0]) output_shape[w_axis] = conv_utils.deconv_output_length( - output_shape[w_axis], kernel_w, self.padding, stride_w) + output_shape[w_axis], + kernel_w, + padding=self.padding, + output_padding=out_pad_w, + stride=stride_w, + dilation=self.dilation_rate[1]) return tensor_shape.TensorShape(output_shape) + def get_config(self): + config = super(Conv2DTranspose, self).get_config() + config['output_padding'] = self.output_padding + return config + @tf_export('keras.layers.Conv3DTranspose', 'keras.layers.Convolution3DTranspose') @@ -878,6 +918,14 @@ class Conv3DTranspose(Conv3D): Specifying any stride value != 1 is incompatible with specifying any `dilation_rate` value != 1. padding: one of `"valid"` or `"same"` (case-insensitive). + output_padding: An integer or tuple/list of 3 integers, + specifying the amount of padding along the depth, height, and + width. + Can be a single integer to specify the same value for all + spatial dimensions. + The amount of output padding along a given dimension must be + lower than the stride along that same dimension. + If set to `None` (default), the output shape is inferred. data_format: A string, one of `channels_last` (default) or `channels_first`. The ordering of the dimensions in the inputs. 
@@ -943,6 +991,7 @@ class Conv3DTranspose(Conv3D): kernel_size, strides=(1, 1, 1), padding='valid', + output_padding=None, data_format=None, activation=None, use_bias=True, @@ -971,6 +1020,16 @@ class Conv3DTranspose(Conv3D): bias_constraint=constraints.get(bias_constraint), **kwargs) + self.output_padding = output_padding + if self.output_padding is not None: + self.output_padding = conv_utils.normalize_tuple( + self.output_padding, 3, 'output_padding') + for stride, out_pad in zip(self.strides, self.output_padding): + if out_pad >= stride: + raise ValueError('Stride ' + str(self.strides) + ' must be ' + 'greater than output padding ' + + str(self.output_padding)) + def build(self, input_shape): input_shape = tensor_shape.TensorShape(input_shape) if len(input_shape) != 5: @@ -1012,11 +1071,9 @@ class Conv3DTranspose(Conv3D): inputs_shape = array_ops.shape(inputs) batch_size = inputs_shape[0] if self.data_format == 'channels_first': - c_axis, d_axis, h_axis, w_axis = 1, 2, 3, 4 + d_axis, h_axis, w_axis = 2, 3, 4 else: - c_axis, d_axis, h_axis, w_axis = 4, 1, 2, 3 - - self.input_spec = InputSpec(ndim=5, axes={c_axis: inputs_shape[c_axis]}) + d_axis, h_axis, w_axis = 1, 2, 3 depth = inputs_shape[d_axis] height = inputs_shape[h_axis] @@ -1025,19 +1082,27 @@ class Conv3DTranspose(Conv3D): kernel_d, kernel_h, kernel_w = self.kernel_size stride_d, stride_h, stride_w = self.strides + if self.output_padding is None: + out_pad_d = out_pad_h = out_pad_w = None + else: + out_pad_d, out_pad_h, out_pad_w = self.output_padding + # Infer the dynamic output shape: out_depth = conv_utils.deconv_output_length(depth, kernel_d, - self.padding, - stride_d) + padding=self.padding, + output_padding=out_pad_d, + stride=stride_d) out_height = conv_utils.deconv_output_length(height, kernel_h, - self.padding, - stride_h) + padding=self.padding, + output_padding=out_pad_h, + stride=stride_h) out_width = conv_utils.deconv_output_length(width, kernel_w, - self.padding, - stride_w) + 
padding=self.padding, + output_padding=out_pad_w, + stride=stride_w) if self.data_format == 'channels_first': output_shape = (batch_size, self.filters, out_depth, out_height, out_width) @@ -1058,20 +1123,7 @@ class Conv3DTranspose(Conv3D): if not context.executing_eagerly(): # Infer the static output shape: - out_shape = inputs.get_shape().as_list() - out_shape[c_axis] = self.filters - out_shape[d_axis] = conv_utils.deconv_output_length(out_shape[d_axis], - kernel_d, - self.padding, - stride_d) - out_shape[h_axis] = conv_utils.deconv_output_length(out_shape[h_axis], - kernel_h, - self.padding, - stride_h) - out_shape[w_axis] = conv_utils.deconv_output_length(out_shape[w_axis], - kernel_w, - self.padding, - stride_w) + out_shape = self.compute_output_shape(inputs.shape) outputs.set_shape(out_shape) if self.use_bias: @@ -1109,15 +1161,38 @@ class Conv3DTranspose(Conv3D): kernel_d, kernel_h, kernel_w = self.kernel_size stride_d, stride_h, stride_w = self.strides + if self.output_padding is None: + out_pad_d = out_pad_h = out_pad_w = None + else: + out_pad_d, out_pad_h, out_pad_w = self.output_padding + output_shape[c_axis] = self.filters output_shape[d_axis] = conv_utils.deconv_output_length( - output_shape[d_axis], kernel_d, self.padding, stride_d) + output_shape[d_axis], + kernel_d, + padding=self.padding, + output_padding=out_pad_d, + stride=stride_d) output_shape[h_axis] = conv_utils.deconv_output_length( - output_shape[h_axis], kernel_h, self.padding, stride_h) + output_shape[h_axis], + kernel_h, + padding=self.padding, + output_padding=out_pad_h, + stride=stride_h) output_shape[w_axis] = conv_utils.deconv_output_length( - output_shape[w_axis], kernel_w, self.padding, stride_w) + output_shape[w_axis], + kernel_w, + padding=self.padding, + output_padding=out_pad_w, + stride=stride_w) return tensor_shape.TensorShape(output_shape) + def get_config(self): + config = super(Conv3DTranspose, self).get_config() + config.pop('dilation_rate') + config['output_padding'] = 
self.output_padding + return config + class SeparableConv(Conv): """Abstract base layer for separable nD convolution. |