author    A. Unique TensorFlower <gardener@tensorflow.org>  2016-12-01 11:59:22 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>   2016-12-01 12:05:34 -0800
commit    434794582e79e2d98d984a00a5779a712a34e885 (patch)
tree      10a729db3739470f48d827ec43b94a18274f7af8
parent    431164534d382dda73581f22fb2a699cdba0b54f (diff)
Implement tf.nn.atrous_conv2d_transpose. Close bugs #4668 and #5300.
Change: 140759688
-rw-r--r--  tensorflow/python/kernel_tests/atrous_conv2d_test.py   92
-rw-r--r--  tensorflow/python/ops/nn.py                              1
-rw-r--r--  tensorflow/python/ops/nn_ops.py                        145
3 files changed, 210 insertions, 28 deletions
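
Before the diff, a short usage sketch of the new op (not part of the patch; it assumes the TF 1.x graph/session API of the time and reuses the NHWC shapes exercised by the test below):

import numpy as np
import tensorflow as tf

# Input: [batch, height, width, in_channels], NHWC as required by the op.
x = tf.constant(np.random.rand(2, 9, 9, 2), dtype=tf.float32)
# Filters: [filter_height, filter_width, out_channels, in_channels].
f = tf.constant(np.random.rand(3, 3, 2, 2), dtype=tf.float32)

# With "SAME" padding the spatial output shape matches the input shape.
y = tf.nn.atrous_conv2d_transpose(x, f, output_shape=[2, 9, 9, 2],
                                  rate=2, padding="SAME")

with tf.Session() as sess:
    print(sess.run(y).shape)  # (2, 9, 9, 2)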
diff --git a/tensorflow/python/kernel_tests/atrous_conv2d_test.py b/tensorflow/python/kernel_tests/atrous_conv2d_test.py
index 1dff6a9f72..162bebf8d6 100644
--- a/tensorflow/python/kernel_tests/atrous_conv2d_test.py
+++ b/tensorflow/python/kernel_tests/atrous_conv2d_test.py
@@ -22,33 +22,33 @@ import numpy as np
import tensorflow as tf
-class AtrousConv2DTest(tf.test.TestCase):
-
- def _upsample_filters(self, filters, rate):
- """Upsamples the filters by a factor of rate along the spatial dimensions.
+def _upsample_filters(filters, rate):
+ """Upsamples the filters by a factor of rate along the spatial dimensions.
+
+ Args:
+ filters: [h, w, in_depth, out_depth]. Original filters.
+ rate: An int, specifying the upsampling rate.
+
+ Returns:
+ filters_up: [h_up, w_up, in_depth, out_depth]. Upsampled filters with
+ h_up = h + (h - 1) * (rate - 1)
+ w_up = w + (w - 1) * (rate - 1)
+ containing (rate - 1) zeros between consecutive filter values along
+ the filters' spatial dimensions.
+ """
+ if rate == 1:
+ return filters
+ # [h, w, in_depth, out_depth] -> [in_depth, out_depth, h, w]
+ filters_up = np.transpose(filters, [2, 3, 0, 1])
+ ker = np.zeros([rate, rate], dtype=np.float32)
+ ker[0, 0] = 1
+ filters_up = np.kron(filters_up, ker)[:, :, :-(rate-1), :-(rate-1)]
+ # [in_depth, out_depth, h_up, w_up] -> [h_up, w_up, in_depth, out_depth]
+ filters_up = np.transpose(filters_up, [2, 3, 0, 1])
+ return filters_up
- Args:
- filters: [h, w, in_depth, out_depth]. Original filters.
- rate: An int, specifying the upsampling rate.
- Returns:
- filters_up: [h_up, w_up, in_depth, out_depth]. Upsampled filters with
- h_up = h + (h - 1) * (rate - 1)
- w_up = w + (w - 1) * (rate - 1)
- containing (rate - 1) zeros between consecutive filter values along
- the filters' spatial dimensions.
- """
- if rate == 1:
- return filters
- # [h, w, in_depth, out_depth] -> [in_depth, out_depth, h, w]
- filters_up = np.transpose(filters, [2, 3, 0, 1])
- ker = np.zeros([rate, rate])
- ker[0, 0] = 1
- filters_up = np.kron(filters_up, ker)[:, :, :-(rate-1), :-(rate-1)]
- # [in_depth, out_depth, h_up, w_up] -> [h_up, w_up, in_depth, out_depth]
- filters_up = np.transpose(filters_up, [2, 3, 0, 1])
- self.assertEqual(np.sum(filters), np.sum(filters_up))
- return filters_up
+class AtrousConv2DTest(tf.test.TestCase):
def testAtrousConv2DForward(self):
with self.test_session(use_gpu=True):
@@ -65,14 +65,13 @@ class AtrousConv2DTest(tf.test.TestCase):
f = np.arange(np.prod(f_shape), dtype=np.float32).reshape(f_shape)
for rate in range(1, 4):
- f_up = self._upsample_filters(f, rate)
+ f_up = _upsample_filters(f, rate)
for padding in ["SAME", "VALID"]:
y1 = tf.nn.atrous_conv2d(x, f, rate, padding=padding)
y2 = tf.nn.conv2d(x, f_up, strides=[1, 1, 1, 1],
padding=padding)
- self.assertAllClose(y1.eval(), y2.eval(), rtol=1e-2,
- atol=1e-2)
+ self.assertAllClose(y1.eval(), y2.eval(), rtol=1e-3, atol=1e-3)
def testAtrousSequence(self):
"""Tests optimization of sequence of atrous convolutions.
@@ -150,5 +149,42 @@ class AtrousConv2DTest(tf.test.TestCase):
self.assertLess(err, err_tolerance)
+class AtrousConv2DTransposeTest(tf.test.TestCase):
+
+ def testAtrousConv2DTransposeForward(self):
+ with self.test_session(use_gpu=True):
+ # Input: [batch, height, width, input_depth]
+ height = 9
+ for width in [9, 10]: # Test both odd and even width.
+ x_shape = [2, height, width, 2]
+ x = np.arange(np.prod(x_shape), dtype=np.float32).reshape(x_shape)
+
+ # Filter: [kernel_height, kernel_width, input_depth, output_depth]
+ for kernel_height in range(1, 4):
+ for kernel_width in range(1, 4):
+ f_shape = [kernel_height, kernel_width, 2, 2]
+ f = np.arange(np.prod(f_shape), dtype=np.float32).reshape(f_shape)
+
+ for rate in range(1, 4):
+ f_up = _upsample_filters(f, rate)
+ kernel_height_up = (kernel_height +
+ (kernel_height - 1) * (rate - 1))
+ kernel_width_up = kernel_width + (kernel_width - 1) * (rate - 1)
+
+ for padding in ["SAME", "VALID"]:
+ if padding == "SAME":
+ y_shape = [2, height, width, 2]
+ else:
+ y_shape = [2,
+ height + kernel_height_up - 1,
+ width + kernel_width_up - 1,
+ 2]
+
+ y1 = tf.nn.atrous_conv2d_transpose(x, f, y_shape, rate, padding)
+ y2 = tf.nn.conv2d_transpose(
+ x, f_up, y_shape, strides=[1, 1, 1, 1], padding=padding)
+ self.assertAllClose(y1.eval(), y2.eval(), rtol=1e-3, atol=1e-3)
+
+
if __name__ == "__main__":
tf.test.main()
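
The test above checks `atrous_conv2d_transpose` against `conv2d_transpose` applied with explicitly upsampled filters. As a standalone illustration of that upsampling (a NumPy sketch, not part of the patch), the Kronecker-product trick used by `_upsample_filters` can be reproduced as follows:

import numpy as np

def upsample_filters(filters, rate):
    # Insert (rate - 1) zeros between consecutive filter taps along height
    # and width by taking a Kronecker product with a one-hot [rate, rate]
    # kernel, then trimming the trailing all-zero rows and columns.
    if rate == 1:
        return filters
    f = np.transpose(filters, [2, 3, 0, 1])        # -> [in, out, h, w]
    ker = np.zeros([rate, rate], dtype=filters.dtype)
    ker[0, 0] = 1
    f = np.kron(f, ker)[:, :, :-(rate - 1), :-(rate - 1)]
    return np.transpose(f, [2, 3, 0, 1])           # -> [h_up, w_up, in, out]

f = np.arange(3 * 3 * 2 * 2, dtype=np.float32).reshape([3, 3, 2, 2])
f_up = upsample_filters(f, rate=2)
print(f_up.shape)             # (5, 5, 2, 2): h_up = 3 + (3 - 1) * (2 - 1)
print(f.sum(), f_up.sum())    # both 630.0: only zeros were inserted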
diff --git a/tensorflow/python/ops/nn.py b/tensorflow/python/ops/nn.py
index 601984799c..da1b880d32 100644
--- a/tensorflow/python/ops/nn.py
+++ b/tensorflow/python/ops/nn.py
@@ -110,6 +110,7 @@ concatenated.
@@depthwise_conv2d_native
@@separable_conv2d
@@atrous_conv2d
+@@atrous_conv2d_transpose
@@conv2d_transpose
@@conv1d
@@conv3d
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 35610cc554..b2c6cf7138 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -1078,6 +1078,151 @@ def conv2d_transpose(value,
name=name)
+def atrous_conv2d_transpose(value,
+ filters,
+ output_shape,
+ rate,
+ padding,
+ name=None):
+ """The transpose of `atrous_conv2d`.
+
+ This operation is sometimes called "deconvolution" after [Deconvolutional
+ Networks](http://www.matthewzeiler.com/pubs/cvpr2010/cvpr2010.pdf), but is
+ actually the transpose (gradient) of `atrous_conv2d` rather than an actual
+ deconvolution.
+
+ Args:
+ value: A 4-D `Tensor` of type `float`. It needs to be in the default `NHWC`
+ format. Its shape is `[batch, in_height, in_width, in_channels]`.
+ filters: A 4-D `Tensor` with the same type as `value` and shape
+ `[filter_height, filter_width, out_channels, in_channels]`. `filters`'
+ `in_channels` dimension must match that of `value`. Atrous convolution is
+ equivalent to standard convolution with upsampled filters with effective
+ height `filter_height + (filter_height - 1) * (rate - 1)` and effective
+ width `filter_width + (filter_width - 1) * (rate - 1)`, produced by
+ inserting `rate - 1` zeros between consecutive filter values along the
+ `filters`' spatial dimensions.
+ output_shape: A 1-D `Tensor` representing the output shape of the
+ deconvolution op.
+ rate: A positive int32. The stride with which we sample input values across
+ the `height` and `width` dimensions. Equivalently, the rate by which we
+ upsample the filter values by inserting zeros across the `height` and
+ `width` dimensions. In the literature, the same parameter is sometimes
+ called `input stride` or `dilation`.
+ padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm.
+ name: Optional name for the returned tensor.
+
+ Returns:
+ A `Tensor` with the same type as `value`.
+
+ Raises:
+ ValueError: If input/output depth does not match `filters`' shape, or if
+ padding is other than `'VALID'` or `'SAME'`, or if the `rate` is less
+ than one, or if the output_shape is not a tensor with 4 elements.
+ """
+ with ops.name_scope(name, "atrous_conv2d_transpose",
+ [value, filters, output_shape]) as name:
+ value = ops.convert_to_tensor(value, name="value")
+ filters = ops.convert_to_tensor(filters, name="filters")
+ if not value.get_shape()[3].is_compatible_with(filters.get_shape()[3]):
+ raise ValueError(
+ "value's input channels does not match filters' input channels, "
+ "{} != {}".format(value.get_shape()[3], filters.get_shape()[3]))
+ if rate < 1:
+ raise ValueError("rate {} cannot be less than one".format(rate))
+
+ if rate == 1:
+ return conv2d_transpose(value,
+ filters,
+ output_shape,
+ strides=[1, 1, 1, 1],
+ padding=padding,
+ data_format="NHWC")
+
+ output_shape_ = ops.convert_to_tensor(output_shape, name="output_shape")
+ if not output_shape_.get_shape().is_compatible_with(tensor_shape.vector(4)):
+ raise ValueError("output_shape must have shape (4,), got {}"
+ .format(output_shape_.get_shape()))
+
+ if isinstance(output_shape, (list, np.ndarray)):
+ # output_shape's shape should be == [4] if reached this point.
+ if not filters.get_shape()[2].is_compatible_with(output_shape[3]):
+ raise ValueError(
+ "output_shape does not match filter's output channels, "
+ "{} != {}".format(output_shape[3], filters.get_shape()[2]))
+
+ # We have two padding contributions. The first is used for converting "SAME"
+ # to "VALID". The second is required so that the height and width of the
+ # zero-padded value tensor are multiples of rate.
+
+ # Padding required to reduce to "VALID" convolution
+ if padding == "SAME":
+ # Handle filters whose shape is unknown during graph creation.
+ if filters.get_shape().is_fully_defined():
+ filter_shape = filters.get_shape().as_list()
+ else:
+ filter_shape = array_ops.shape(filters)
+ filter_height, filter_width = filter_shape[0], filter_shape[1]
+
+ # Spatial dimensions of the filters and the upsampled filters in which we
+ # introduce (rate - 1) zeros between consecutive filter values.
+ filter_height_up = filter_height + (filter_height - 1) * (rate - 1)
+ filter_width_up = filter_width + (filter_width - 1) * (rate - 1)
+
+ pad_height = filter_height_up - 1
+ pad_width = filter_width_up - 1
+
+ # When pad_height (pad_width) is odd, we pad more to bottom (right),
+ # following the same convention as conv2d().
+ pad_top = pad_height // 2
+ pad_bottom = pad_height - pad_top
+ pad_left = pad_width // 2
+ pad_right = pad_width - pad_left
+ elif padding == "VALID":
+ pad_top = 0
+ pad_bottom = 0
+ pad_left = 0
+ pad_right = 0
+ else:
+ raise ValueError("padding must be either VALID or SAME:"
+ " {}".format(padding))
+
+ in_height = output_shape[1] + pad_top + pad_bottom
+ in_width = output_shape[2] + pad_left + pad_right
+
+ # More padding so that rate divides the height and width of the input.
+ pad_bottom_extra = (rate - in_height % rate) % rate
+ pad_right_extra = (rate - in_width % rate) % rate
+
+ # The paddings argument to space_to_batch is just the extra padding
+ # component.
+ space_to_batch_pad = [[0, pad_bottom_extra], [0, pad_right_extra]]
+
+ value = array_ops.space_to_batch(input=value,
+ paddings=space_to_batch_pad,
+ block_size=rate)
+
+ input_sizes = [rate * rate * output_shape[0],
+ (in_height + pad_bottom_extra) // rate,
+ (in_width + pad_right_extra) // rate,
+ output_shape[3]]
+
+ value = gen_nn_ops.conv2d_backprop_input(input_sizes=input_sizes,
+ filter=filters,
+ out_backprop=value,
+ strides=[1, 1, 1, 1],
+ padding="VALID",
+ data_format="NHWC")
+
+ # The crops argument to batch_to_space includes both padding components.
+ batch_to_space_crop = [[pad_top, pad_bottom + pad_bottom_extra],
+ [pad_left, pad_right + pad_right_extra]]
+
+ return array_ops.batch_to_space(input=value,
+ crops=batch_to_space_crop,
+ block_size=rate)
+
+
def conv3d_transpose(value,
filter,
output_shape,
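
To make the two padding contributions in `atrous_conv2d_transpose` concrete, here is a small arithmetic sketch (illustrative numbers only, not part of the patch) tracing the "SAME" path of the code above for a 3x3 filter, rate 2, and output height 9:

# Effective filter height after inserting (rate - 1) zeros between taps.
filter_height, rate, out_height = 3, 2, 9
filter_height_up = filter_height + (filter_height - 1) * (rate - 1)    # 5

# Padding that reduces "SAME" to a "VALID" convolution (bottom gets the
# extra row when the total is odd, matching conv2d()).
pad_height = filter_height_up - 1                                      # 4
pad_top = pad_height // 2                                              # 2
pad_bottom = pad_height - pad_top                                      # 2

# Extra padding so that rate divides the padded height before space_to_batch.
in_height = out_height + pad_top + pad_bottom                          # 13
pad_bottom_extra = (rate - in_height % rate) % rate                    # 1

# Height of each spatial block handed to conv2d_backprop_input.
block_height = (in_height + pad_bottom_extra) // rate                  # 7
print(filter_height_up, pad_top, pad_bottom, pad_bottom_extra, block_height)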