diff options
Diffstat (limited to 'tensorflow/python/keras/utils/conv_utils.py')
-rw-r--r-- | tensorflow/python/keras/utils/conv_utils.py | 45 |
1 files changed, 34 insertions, 11 deletions
diff --git a/tensorflow/python/keras/utils/conv_utils.py b/tensorflow/python/keras/utils/conv_utils.py index 8ebca1418d..f486e631e5 100644 --- a/tensorflow/python/keras/utils/conv_utils.py +++ b/tensorflow/python/keras/utils/conv_utils.py @@ -137,26 +137,49 @@ def conv_input_length(output_length, filter_size, padding, stride): return (output_length - 1) * stride - 2 * pad + filter_size -def deconv_output_length(input_length, filter_size, padding, stride): +def deconv_output_length(input_length, filter_size, padding, + output_padding=None, stride=0, dilation=1): """Determines output length of a transposed convolution given input length. Arguments: - input_length: integer. - filter_size: integer. - padding: one of "same", "valid", "full". - stride: integer. + input_length: Integer. + filter_size: Integer. + padding: one of `"same"`, `"valid"`, `"full"`. + output_padding: Integer, amount of padding along the output dimension. + Can be set to `None` in which case the output length is inferred. + stride: Integer. + dilation: Integer. Returns: The output length (integer). """ + assert padding in {'same', 'valid', 'full'} if input_length is None: return None - input_length *= stride - if padding == 'valid': - input_length += max(filter_size - stride, 0) - elif padding == 'full': - input_length -= (stride + filter_size - 2) - return input_length + + # Get the dilated kernel size + filter_size = filter_size + (filter_size - 1) * (dilation - 1) + + # Infer length if output padding is None, else compute the exact length + if output_padding is None: + if padding == 'valid': + length = input_length * stride + max(filter_size - stride, 0) + elif padding == 'full': + length = input_length * stride - (stride + filter_size - 2) + elif padding == 'same': + length = input_length * stride + + else: + if padding == 'same': + pad = filter_size // 2 + elif padding == 'valid': + pad = 0 + elif padding == 'full': + pad = filter_size - 1 + + length = ((input_length - 1) * stride + filter_size - 2 * pad + + output_padding) + return length def normalize_data_format(value): |