diff options
author | 2016-12-01 11:59:22 -0800 | |
---|---|---|
committer | 2016-12-01 12:05:34 -0800 | |
commit | 434794582e79e2d98d984a00a5779a712a34e885 (patch) | |
tree | 10a729db3739470f48d827ec43b94a18274f7af8 | |
parent | 431164534d382dda73581f22fb2a699cdba0b54f (diff) |
Implement tf.nn.atrous_conv2d_transpose. Close bugs #4668 and #5300.
Change: 140759688
-rw-r--r-- | tensorflow/python/kernel_tests/atrous_conv2d_test.py | 92 | ||||
-rw-r--r-- | tensorflow/python/ops/nn.py | 1 | ||||
-rw-r--r-- | tensorflow/python/ops/nn_ops.py | 145 |
3 files changed, 210 insertions, 28 deletions
diff --git a/tensorflow/python/kernel_tests/atrous_conv2d_test.py b/tensorflow/python/kernel_tests/atrous_conv2d_test.py index 1dff6a9f72..162bebf8d6 100644 --- a/tensorflow/python/kernel_tests/atrous_conv2d_test.py +++ b/tensorflow/python/kernel_tests/atrous_conv2d_test.py @@ -22,33 +22,33 @@ import numpy as np import tensorflow as tf -class AtrousConv2DTest(tf.test.TestCase): - - def _upsample_filters(self, filters, rate): - """Upsamples the filters by a factor of rate along the spatial dimensions. +def _upsample_filters(filters, rate): + """Upsamples the filters by a factor of rate along the spatial dimensions. + + Args: + filters: [h, w, in_depth, out_depth]. Original filters. + rate: An int, specifying the upsampling rate. + + Returns: + filters_up: [h_up, w_up, in_depth, out_depth]. Upsampled filters with + h_up = h + (h - 1) * (rate - 1) + w_up = w + (w - 1) * (rate - 1) + containing (rate - 1) zeros between consecutive filter values along + the filters' spatial dimensions. + """ + if rate == 1: + return filters + # [h, w, in_depth, out_depth] -> [in_depth, out_depth, h, w] + filters_up = np.transpose(filters, [2, 3, 0, 1]) + ker = np.zeros([rate, rate], dtype=np.float32) + ker[0, 0] = 1 + filters_up = np.kron(filters_up, ker)[:, :, :-(rate-1), :-(rate-1)] + # [in_depth, out_depth, h_up, w_up] -> [h_up, w_up, in_depth, out_depth] + filters_up = np.transpose(filters_up, [2, 3, 0, 1]) + return filters_up - Args: - filters: [h, w, in_depth, out_depth]. Original filters. - rate: An int, specifying the upsampling rate. - Returns: - filters_up: [h_up, w_up, in_depth, out_depth]. Upsampled filters with - h_up = h + (h - 1) * (rate - 1) - w_up = w + (w - 1) * (rate - 1) - containing (rate - 1) zeros between consecutive filter values along - the filters' spatial dimensions. - """ - if rate == 1: - return filters - # [h, w, in_depth, out_depth] -> [in_depth, out_depth, h, w] - filters_up = np.transpose(filters, [2, 3, 0, 1]) - ker = np.zeros([rate, rate]) - ker[0, 0] = 1 - filters_up = np.kron(filters_up, ker)[:, :, :-(rate-1), :-(rate-1)] - # [in_depth, out_depth, h_up, w_up] -> [h_up, w_up, in_depth, out_depth] - filters_up = np.transpose(filters_up, [2, 3, 0, 1]) - self.assertEqual(np.sum(filters), np.sum(filters_up)) - return filters_up +class AtrousConv2DTest(tf.test.TestCase): def testAtrousConv2DForward(self): with self.test_session(use_gpu=True): @@ -65,14 +65,13 @@ class AtrousConv2DTest(tf.test.TestCase): f = np.arange(np.prod(f_shape), dtype=np.float32).reshape(f_shape) for rate in range(1, 4): - f_up = self._upsample_filters(f, rate) + f_up = _upsample_filters(f, rate) for padding in ["SAME", "VALID"]: y1 = tf.nn.atrous_conv2d(x, f, rate, padding=padding) y2 = tf.nn.conv2d(x, f_up, strides=[1, 1, 1, 1], padding=padding) - self.assertAllClose(y1.eval(), y2.eval(), rtol=1e-2, - atol=1e-2) + self.assertAllClose(y1.eval(), y2.eval(), rtol=1e-3, atol=1e-3) def testAtrousSequence(self): """Tests optimization of sequence of atrous convolutions. @@ -150,5 +149,42 @@ class AtrousConv2DTest(tf.test.TestCase): self.assertLess(err, err_tolerance) +class AtrousConv2DTransposeTest(tf.test.TestCase): + + def testAtrousConv2DTransposeForward(self): + with self.test_session(use_gpu=True): + # Input: [batch, height, width, input_depth] + height = 9 + for width in [9, 10]: # Test both odd and even width. + x_shape = [2, height, width, 2] + x = np.arange(np.prod(x_shape), dtype=np.float32).reshape(x_shape) + + # Filter: [kernel_height, kernel_width, input_depth, output_depth] + for kernel_height in range(1, 4): + for kernel_width in range(1, 4): + f_shape = [kernel_height, kernel_width, 2, 2] + f = np.arange(np.prod(f_shape), dtype=np.float32).reshape(f_shape) + + for rate in range(1, 4): + f_up = _upsample_filters(f, rate) + kernel_height_up = (kernel_height + + (kernel_height - 1) * (rate - 1)) + kernel_width_up = kernel_width + (kernel_width - 1) * (rate - 1) + + for padding in ["SAME", "VALID"]: + if padding == "SAME": + y_shape = [2, height, width, 2] + else: + y_shape = [2, + height + kernel_height_up - 1, + width + kernel_width_up - 1, + 2] + + y1 = tf.nn.atrous_conv2d_transpose(x, f, y_shape, rate, padding) + y2 = tf.nn.conv2d_transpose( + x, f_up, y_shape, strides=[1, 1, 1, 1], padding=padding) + self.assertAllClose(y1.eval(), y2.eval(), rtol=1e-3, atol=1e-3) + + if __name__ == "__main__": tf.test.main() diff --git a/tensorflow/python/ops/nn.py b/tensorflow/python/ops/nn.py index 601984799c..da1b880d32 100644 --- a/tensorflow/python/ops/nn.py +++ b/tensorflow/python/ops/nn.py @@ -110,6 +110,7 @@ concatenated. @@depthwise_conv2d_native @@separable_conv2d @@atrous_conv2d +@@atrous_conv2d_transpose @@conv2d_transpose @@conv1d @@conv3d diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py index 35610cc554..b2c6cf7138 100644 --- a/tensorflow/python/ops/nn_ops.py +++ b/tensorflow/python/ops/nn_ops.py @@ -1078,6 +1078,151 @@ def conv2d_transpose(value, name=name) +def atrous_conv2d_transpose(value, + filters, + output_shape, + rate, + padding, + name=None): + """The transpose of `atrous_conv2d`. + + This operation is sometimes called "deconvolution" after [Deconvolutional + Networks](http://www.matthewzeiler.com/pubs/cvpr2010/cvpr2010.pdf), but is + actually the transpose (gradient) of `atrous_conv2d` rather than an actual + deconvolution. + + Args: + value: A 4-D `Tensor` of type `float`. It needs to be in the default `NHWC` + format. Its shape is `[batch, in_height, in_width, in_channels]`. + filters: A 4-D `Tensor` with the same type as `value` and shape + `[filter_height, filter_width, out_channels, in_channels]`. `filters`' + `in_channels` dimension must match that of `value`. Atrous convolution is + equivalent to standard convolution with upsampled filters with effective + height `filter_height + (filter_height - 1) * (rate - 1)` and effective + width `filter_width + (filter_width - 1) * (rate - 1)`, produced by + inserting `rate - 1` zeros along consecutive elements across the + `filters`' spatial dimensions. + output_shape: A 1-D `Tensor` of shape representing the output shape of the + deconvolution op. + rate: A positive int32. The stride with which we sample input values across + the `height` and `width` dimensions. Equivalently, the rate by which we + upsample the filter values by inserting zeros across the `height` and + `width` dimensions. In the literature, the same parameter is sometimes + called `input stride` or `dilation`. + padding: A string, either `'VALID'` or `'SAME'`. The padding algorithm. + name: Optional name for the returned tensor. + + Returns: + A `Tensor` with the same type as `value`. + + Raises: + ValueError: If input/output depth does not match `filters`' shape, or if + padding is other than `'VALID'` or `'SAME'`, or if the `rate` is less + than one, or if the output_shape is not a tensor with 4 elements. + """ + with ops.name_scope(name, "atrous_conv2d_transpose", + [value, filters, output_shape]) as name: + value = ops.convert_to_tensor(value, name="value") + filters = ops.convert_to_tensor(filters, name="filters") + if not value.get_shape()[3].is_compatible_with(filters.get_shape()[3]): + raise ValueError( + "value's input channels does not match filters' input channels, " + "{} != {}".format(value.get_shape()[3], filters.get_shape()[3])) + if rate < 1: + raise ValueError("rate {} cannot be less than one".format(rate)) + + if rate == 1: + return conv2d_transpose(value, + filters, + output_shape, + strides=[1, 1, 1, 1], + padding=padding, + data_format="NHWC") + + output_shape_ = ops.convert_to_tensor(output_shape, name="output_shape") + if not output_shape_.get_shape().is_compatible_with(tensor_shape.vector(4)): + raise ValueError("output_shape must have shape (4,), got {}" + .format(output_shape_.get_shape())) + + if isinstance(output_shape, (list, np.ndarray)): + # output_shape's shape should be == [4] if reached this point. + if not filters.get_shape()[2].is_compatible_with(output_shape[3]): + raise ValueError( + "output_shape does not match filter's output channels, " + "{} != {}".format(output_shape[3], filters.get_shape()[2])) + + # We have two padding contributions. The first is used for converting "SAME" + # to "VALID". The second is required so that the height and width of the + # zero-padded value tensor are multiples of rate. + + # Padding required to reduce to "VALID" convolution + if padding == "SAME": + # Handle filters whose shape is unknown during graph creation. + if filters.get_shape().is_fully_defined(): + filter_shape = filters.get_shape().as_list() + else: + filter_shape = array_ops.shape(filters) + filter_height, filter_width = filter_shape[0], filter_shape[1] + + # Spatial dimensions of the filters and the upsampled filters in which we + # introduce (rate - 1) zeros between consecutive filter values. + filter_height_up = filter_height + (filter_height - 1) * (rate - 1) + filter_width_up = filter_width + (filter_width - 1) * (rate - 1) + + pad_height = filter_height_up - 1 + pad_width = filter_width_up - 1 + + # When pad_height (pad_width) is odd, we pad more to bottom (right), + # following the same convention as conv2d(). + pad_top = pad_height // 2 + pad_bottom = pad_height - pad_top + pad_left = pad_width // 2 + pad_right = pad_width - pad_left + elif padding == "VALID": + pad_top = 0 + pad_bottom = 0 + pad_left = 0 + pad_right = 0 + else: + raise ValueError("padding must be either VALID or SAME:" + " {}".format(padding)) + + in_height = output_shape[1] + pad_top + pad_bottom + in_width = output_shape[2] + pad_left + pad_right + + # More padding so that rate divides the height and width of the input. + pad_bottom_extra = (rate - in_height % rate) % rate + pad_right_extra = (rate - in_width % rate) % rate + + # The paddings argument to space_to_batch is just the extra padding + # component. + space_to_batch_pad = [[0, pad_bottom_extra], [0, pad_right_extra]] + + value = array_ops.space_to_batch(input=value, + paddings=space_to_batch_pad, + block_size=rate) + + input_sizes = [rate * rate * output_shape[0], + (in_height + pad_bottom_extra) // rate, + (in_width + pad_right_extra) // rate, + output_shape[3]] + + value = gen_nn_ops.conv2d_backprop_input(input_sizes=input_sizes, + filter=filters, + out_backprop=value, + strides=[1, 1, 1, 1], + padding="VALID", + data_format="NHWC") + + # The crops argument to batch_to_space includes both padding components. + batch_to_space_crop = [[pad_top, pad_bottom + pad_bottom_extra], + [pad_left, pad_right + pad_right_extra]] + + return array_ops.batch_to_space(input=value, + crops=batch_to_space_crop, + block_size=rate) + + def conv3d_transpose(value, filter, output_shape, |