aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/keras/utils/conv_utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/keras/utils/conv_utils.py')
-rw-r--r--tensorflow/python/keras/utils/conv_utils.py45
1 files changed, 34 insertions, 11 deletions
diff --git a/tensorflow/python/keras/utils/conv_utils.py b/tensorflow/python/keras/utils/conv_utils.py
index 8ebca1418d..f486e631e5 100644
--- a/tensorflow/python/keras/utils/conv_utils.py
+++ b/tensorflow/python/keras/utils/conv_utils.py
@@ -137,26 +137,49 @@ def conv_input_length(output_length, filter_size, padding, stride):
return (output_length - 1) * stride - 2 * pad + filter_size
-def deconv_output_length(input_length, filter_size, padding, stride):
+def deconv_output_length(input_length, filter_size, padding,
+ output_padding=None, stride=0, dilation=1):
"""Determines output length of a transposed convolution given input length.
Arguments:
- input_length: integer.
- filter_size: integer.
- padding: one of "same", "valid", "full".
- stride: integer.
+ input_length: Integer.
+ filter_size: Integer.
+ padding: one of `"same"`, `"valid"`, `"full"`.
+ output_padding: Integer, amount of padding along the output dimension.
+ Can be set to `None` in which case the output length is inferred.
+ stride: Integer.
+ dilation: Integer.
Returns:
The output length (integer).
"""
+ assert padding in {'same', 'valid', 'full'}
if input_length is None:
return None
- input_length *= stride
- if padding == 'valid':
- input_length += max(filter_size - stride, 0)
- elif padding == 'full':
- input_length -= (stride + filter_size - 2)
- return input_length
+
+ # Get the dilated kernel size
+ filter_size = filter_size + (filter_size - 1) * (dilation - 1)
+
+ # Infer length if output padding is None, else compute the exact length
+ if output_padding is None:
+ if padding == 'valid':
+ length = input_length * stride + max(filter_size - stride, 0)
+ elif padding == 'full':
+ length = input_length * stride - (stride + filter_size - 2)
+ elif padding == 'same':
+ length = input_length * stride
+
+ else:
+ if padding == 'same':
+ pad = filter_size // 2
+ elif padding == 'valid':
+ pad = 0
+ elif padding == 'full':
+ pad = filter_size - 1
+
+ length = ((input_length - 1) * stride + filter_size - 2 * pad +
+ output_padding)
+ return length
def normalize_data_format(value):