path: root/tensorflow/python/ops/nn_ops.py
Diffstat (limited to 'tensorflow/python/ops/nn_ops.py')
-rw-r--r--  tensorflow/python/ops/nn_ops.py | 107
1 file changed, 16 insertions(+), 91 deletions(-)
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index ccce9402c7..61fda3a798 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -840,6 +840,11 @@ def pool(input, # pylint: disable=redefined-builtin
def atrous_conv2d(value, filters, rate, padding, name=None):
"""Atrous convolution (a.k.a. convolution with holes or dilated convolution).
+ This function is a simpler wrapper around the more general
+ @{tf.nn.convolution}, and exists only for backwards compatibility. You can
+ use @{tf.nn.convolution} to perform 1-D, 2-D, or 3-D atrous convolution.
+
+
Computes a 2-D atrous convolution, also known as convolution with holes or
dilated convolution, given 4-D `value` and `filters` tensors. If the `rate`
parameter is equal to one, it performs regular 2-D convolution. If the `rate`
@@ -959,93 +964,12 @@ def atrous_conv2d(value, filters, rate, padding, name=None):
ValueError: If input/output depth does not match `filters`' shape, or if
padding is other than `'VALID'` or `'SAME'`.
"""
- with ops.name_scope(name, "atrous_conv2d", [value, filters]) as name:
- value = ops.convert_to_tensor(value, name="value")
- filters = ops.convert_to_tensor(filters, name="filters")
- if not value.get_shape()[3].is_compatible_with(filters.get_shape()[2]):
- raise ValueError(
- "value's input channels does not match filters' input channels, "
- "{} != {}".format(value.get_shape()[3], filters.get_shape()[2]))
- if rate < 1:
- raise ValueError("rate {} cannot be less than one".format(rate))
-
- if rate == 1:
- value = gen_nn_ops.conv2d(input=value,
- filter=filters,
- strides=[1, 1, 1, 1],
- padding=padding)
- return value
-
- # We have two padding contributions. The first is used for converting "SAME"
- # to "VALID". The second is required so that the height and width of the
- # zero-padded value tensor are multiples of rate.
-
- # Padding required to reduce to "VALID" convolution
- if padding == "SAME":
- # Handle filters whose shape is unknown during graph creation.
- if filters.get_shape().is_fully_defined():
- filter_shape = filters.get_shape().as_list()
- else:
- filter_shape = array_ops.shape(filters)
- filter_height, filter_width = filter_shape[0], filter_shape[1]
-
- # Spatial dimensions of the filters and the upsampled filters in which we
- # introduce (rate - 1) zeros between consecutive filter values.
- filter_height_up = filter_height + (filter_height - 1) * (rate - 1)
- filter_width_up = filter_width + (filter_width - 1) * (rate - 1)
-
- pad_height = filter_height_up - 1
- pad_width = filter_width_up - 1
-
- # When pad_height (pad_width) is odd, we pad more to bottom (right),
- # following the same convention as conv2d().
- pad_top = pad_height // 2
- pad_bottom = pad_height - pad_top
- pad_left = pad_width // 2
- pad_right = pad_width - pad_left
- elif padding == "VALID":
- pad_top = 0
- pad_bottom = 0
- pad_left = 0
- pad_right = 0
- else:
- raise ValueError("Invalid padding")
-
- # Handle input whose shape is unknown during graph creation.
- if value.get_shape().is_fully_defined():
- value_shape = value.get_shape().as_list()
- else:
- value_shape = array_ops.shape(value)
-
- in_height = value_shape[1] + pad_top + pad_bottom
- in_width = value_shape[2] + pad_left + pad_right
-
- # More padding so that rate divides the height and width of the input.
- pad_bottom_extra = (rate - in_height % rate) % rate
- pad_right_extra = (rate - in_width % rate) % rate
-
- # The paddings argument to space_to_batch includes both padding components.
- space_to_batch_pad = [[pad_top, pad_bottom + pad_bottom_extra],
- [pad_left, pad_right + pad_right_extra]]
-
- value = array_ops.space_to_batch(input=value,
- paddings=space_to_batch_pad,
- block_size=rate)
-
- value = gen_nn_ops.conv2d(input=value,
- filter=filters,
- strides=[1, 1, 1, 1],
- padding="VALID",
- name=name)
-
- # The crops argument to batch_to_space is just the extra padding component.
- batch_to_space_crop = [[0, pad_bottom_extra], [0, pad_right_extra]]
-
- value = array_ops.batch_to_space(input=value,
- crops=batch_to_space_crop,
- block_size=rate)
-
- return value
+ return convolution(
+ input=value,
+ filter=filters,
+ padding=padding,
+ dilation_rate=np.broadcast_to(rate, (2,)),
+ name=name)
def conv2d_transpose(value,
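
The hunk above drops the hand-rolled space_to_batch / conv2d / batch_to_space pipeline and forwards atrous_conv2d to the more general convolution op, broadcasting the scalar rate to both spatial dimensions. A minimal sketch of the equivalence this relies on, assuming a TensorFlow 1.x-style API; the placeholder shapes are illustrative and not taken from the patch:

import numpy as np
import tensorflow as tf

# Illustrative shapes only: NHWC input, HWIO filter.
value = tf.placeholder(tf.float32, [1, 64, 64, 3])
filters = tf.placeholder(tf.float32, [3, 3, 3, 8])

# The old entry point, kept for backwards compatibility.
out_atrous = tf.nn.atrous_conv2d(value, filters, rate=2, padding="SAME")

# What atrous_conv2d now forwards to: the scalar rate is broadcast to a
# length-2 dilation_rate, one entry per spatial dimension.
out_conv = tf.nn.convolution(input=value,
                             filter=filters,
                             padding="SAME",
                             dilation_rate=np.broadcast_to(2, (2,)))

# Both ops should produce tensors of the same shape and, for the same
# filter values, the same results.
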
@@ -1272,7 +1196,7 @@ def conv3d_transpose(value,
output_shape,
strides,
padding="SAME",
- data_format=None,
+ data_format="NDHWC",
name=None):
"""The transpose of `conv3d`.
@@ -1308,10 +1232,11 @@ def conv3d_transpose(value,
[value, filter, output_shape]) as name:
value = ops.convert_to_tensor(value, name="value")
filter = ops.convert_to_tensor(filter, name="filter")
- if not value.get_shape()[4].is_compatible_with(filter.get_shape()[4]):
+ axis = 1 if data_format == "NCDHW" else 4
+ if not value.get_shape()[axis].is_compatible_with(filter.get_shape()[4]):
raise ValueError("input channels does not match filter's input channels, "
- "{} != {}".format(value.get_shape()[4], filter.get_shape(
- )[4]))
+ "{} != {}".format(value.get_shape()[axis],
+ filter.get_shape()[4]))
output_shape_ = ops.convert_to_tensor(output_shape, name="output_shape")
if not output_shape_.get_shape().is_compatible_with(tensor_shape.vector(5)):
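
The conv3d_transpose hunks above make the data_format argument explicit (defaulting to "NDHWC") and select the channel axis accordingly (axis 1 for "NCDHW", axis 4 otherwise). A minimal usage sketch, assuming a TensorFlow 1.x-style API; the shapes, strides, and output_shape values below are illustrative assumptions, not taken from the patch:

import tensorflow as tf

# conv3d_transpose filters are [depth, height, width, out_channels, in_channels],
# so the input's channel axis must match filter.shape[4].
filt = tf.placeholder(tf.float32, [3, 3, 3, 32, 16])

# Default data_format="NDHWC": channels live on axis 4 of the input.
val_ndhwc = tf.placeholder(tf.float32, [1, 8, 8, 8, 16])
out_ndhwc = tf.nn.conv3d_transpose(val_ndhwc, filt,
                                   output_shape=[1, 16, 16, 16, 32],
                                   strides=[1, 2, 2, 2, 1],
                                   padding="SAME")

# data_format="NCDHW": channels live on axis 1, which is what the new
# `axis = 1 if data_format == "NCDHW" else 4` check accounts for.
val_ncdhw = tf.placeholder(tf.float32, [1, 16, 8, 8, 8])
out_ncdhw = tf.nn.conv3d_transpose(val_ncdhw, filt,
                                   output_shape=[1, 32, 16, 16, 16],
                                   strides=[1, 1, 2, 2, 2],
                                   padding="SAME",
                                   data_format="NCDHW")
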