diff options
Diffstat (limited to 'tensorflow/python/ops/nn_ops.py')
-rw-r--r-- | tensorflow/python/ops/nn_ops.py | 107 |
1 files changed, 16 insertions, 91 deletions
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index ccce9402c7..61fda3a798 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -840,6 +840,11 @@ def pool(input, # pylint: disable=redefined-builtin def atrous_conv2d(value, filters, rate, padding, name=None): """Atrous convolution (a.k.a. convolution with holes or dilated convolution). + This function is a simpler wrapper around the more general + @{tf.nn.convolution}, and exists only for backwards compatibility. You can + use @{tf.nn.convolution} to perform 1-D, 2-D, or 3-D atrous convolution. + + Computes a 2-D atrous convolution, also known as convolution with holes or dilated convolution, given 4-D `value` and `filters` tensors. If the `rate` parameter is equal to one, it performs regular 2-D convolution. If the `rate` @@ -959,93 +964,12 @@ def atrous_conv2d(value, filters, rate, padding, name=None): ValueError: If input/output depth does not match `filters`' shape, or if padding is other than `'VALID'` or `'SAME'`. """ - with ops.name_scope(name, "atrous_conv2d", [value, filters]) as name: - value = ops.convert_to_tensor(value, name="value") - filters = ops.convert_to_tensor(filters, name="filters") - if not value.get_shape()[3].is_compatible_with(filters.get_shape()[2]): - raise ValueError( - "value's input channels does not match filters' input channels, " - "{} != {}".format(value.get_shape()[3], filters.get_shape()[2])) - if rate < 1: - raise ValueError("rate {} cannot be less than one".format(rate)) - - if rate == 1: - value = gen_nn_ops.conv2d(input=value, - filter=filters, - strides=[1, 1, 1, 1], - padding=padding) - return value - - # We have two padding contributions. The first is used for converting "SAME" - # to "VALID". The second is required so that the height and width of the - # zero-padded value tensor are multiples of rate. - - # Padding required to reduce to "VALID" convolution - if padding == "SAME": - # Handle filters whose shape is unknown during graph creation. - if filters.get_shape().is_fully_defined(): - filter_shape = filters.get_shape().as_list() - else: - filter_shape = array_ops.shape(filters) - filter_height, filter_width = filter_shape[0], filter_shape[1] - - # Spatial dimensions of the filters and the upsampled filters in which we - # introduce (rate - 1) zeros between consecutive filter values. - filter_height_up = filter_height + (filter_height - 1) * (rate - 1) - filter_width_up = filter_width + (filter_width - 1) * (rate - 1) - - pad_height = filter_height_up - 1 - pad_width = filter_width_up - 1 - - # When pad_height (pad_width) is odd, we pad more to bottom (right), - # following the same convention as conv2d(). - pad_top = pad_height // 2 - pad_bottom = pad_height - pad_top - pad_left = pad_width // 2 - pad_right = pad_width - pad_left - elif padding == "VALID": - pad_top = 0 - pad_bottom = 0 - pad_left = 0 - pad_right = 0 - else: - raise ValueError("Invalid padding") - - # Handle input whose shape is unknown during graph creation. - if value.get_shape().is_fully_defined(): - value_shape = value.get_shape().as_list() - else: - value_shape = array_ops.shape(value) - - in_height = value_shape[1] + pad_top + pad_bottom - in_width = value_shape[2] + pad_left + pad_right - - # More padding so that rate divides the height and width of the input. - pad_bottom_extra = (rate - in_height % rate) % rate - pad_right_extra = (rate - in_width % rate) % rate - - # The paddings argument to space_to_batch includes both padding components. - space_to_batch_pad = [[pad_top, pad_bottom + pad_bottom_extra], - [pad_left, pad_right + pad_right_extra]] - - value = array_ops.space_to_batch(input=value, - paddings=space_to_batch_pad, - block_size=rate) - - value = gen_nn_ops.conv2d(input=value, - filter=filters, - strides=[1, 1, 1, 1], - padding="VALID", - name=name) - - # The crops argument to batch_to_space is just the extra padding component. - batch_to_space_crop = [[0, pad_bottom_extra], [0, pad_right_extra]] - - value = array_ops.batch_to_space(input=value, - crops=batch_to_space_crop, - block_size=rate) - - return value + return convolution( + input=value, + filter=filters, + padding=padding, + dilation_rate=np.broadcast_to(rate, (2,)), + name=name) def conv2d_transpose(value, @@ -1272,7 +1196,7 @@ def conv3d_transpose(value, output_shape, strides, padding="SAME", - data_format=None, + data_format="NDHWC", name=None): """The transpose of `conv3d`. @@ -1308,10 +1232,11 @@ def conv3d_transpose(value, [value, filter, output_shape]) as name: value = ops.convert_to_tensor(value, name="value") filter = ops.convert_to_tensor(filter, name="filter") - if not value.get_shape()[4].is_compatible_with(filter.get_shape()[4]): + axis = 1 if data_format == "NCDHW" else 4 + if not value.get_shape()[axis].is_compatible_with(filter.get_shape()[4]): raise ValueError("input channels does not match filter's input channels, " - "{} != {}".format(value.get_shape()[4], filter.get_shape( - )[4])) + "{} != {}".format(value.get_shape()[axis], + filter.get_shape()[4])) output_shape_ = ops.convert_to_tensor(output_shape, name="output_shape") if not output_shape_.get_shape().is_compatible_with(tensor_shape.vector(5)): |