aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/atrous_conv2d_test.py
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2016-12-01 11:59:22 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-12-01 12:05:34 -0800
commit434794582e79e2d98d984a00a5779a712a34e885 (patch)
tree10a729db3739470f48d827ec43b94a18274f7af8 /tensorflow/python/kernel_tests/atrous_conv2d_test.py
parent431164534d382dda73581f22fb2a699cdba0b54f (diff)
Implement tf.nn.atrous_conv2d_transpose. Close bugs #4668 and #5300.
Change: 140759688
Diffstat (limited to 'tensorflow/python/kernel_tests/atrous_conv2d_test.py')
-rw-r--r--tensorflow/python/kernel_tests/atrous_conv2d_test.py92
1 files changed, 64 insertions, 28 deletions
diff --git a/tensorflow/python/kernel_tests/atrous_conv2d_test.py b/tensorflow/python/kernel_tests/atrous_conv2d_test.py
index 1dff6a9f72..162bebf8d6 100644
--- a/tensorflow/python/kernel_tests/atrous_conv2d_test.py
+++ b/tensorflow/python/kernel_tests/atrous_conv2d_test.py
@@ -22,33 +22,33 @@ import numpy as np
import tensorflow as tf
-class AtrousConv2DTest(tf.test.TestCase):
-
- def _upsample_filters(self, filters, rate):
- """Upsamples the filters by a factor of rate along the spatial dimensions.
+def _upsample_filters(filters, rate):
+ """Upsamples the filters by a factor of rate along the spatial dimensions.
+
+ Args:
+ filters: [h, w, in_depth, out_depth]. Original filters.
+ rate: An int, specifying the upsampling rate.
+
+ Returns:
+ filters_up: [h_up, w_up, in_depth, out_depth]. Upsampled filters with
+ h_up = h + (h - 1) * (rate - 1)
+ w_up = w + (w - 1) * (rate - 1)
+ containing (rate - 1) zeros between consecutive filter values along
+ the filters' spatial dimensions.
+ """
+ if rate == 1:
+ return filters
+ # [h, w, in_depth, out_depth] -> [in_depth, out_depth, h, w]
+ filters_up = np.transpose(filters, [2, 3, 0, 1])
+ ker = np.zeros([rate, rate], dtype=np.float32)
+ ker[0, 0] = 1
+ filters_up = np.kron(filters_up, ker)[:, :, :-(rate-1), :-(rate-1)]
+ # [in_depth, out_depth, h_up, w_up] -> [h_up, w_up, in_depth, out_depth]
+ filters_up = np.transpose(filters_up, [2, 3, 0, 1])
+ return filters_up
- Args:
- filters: [h, w, in_depth, out_depth]. Original filters.
- rate: An int, specifying the upsampling rate.
- Returns:
- filters_up: [h_up, w_up, in_depth, out_depth]. Upsampled filters with
- h_up = h + (h - 1) * (rate - 1)
- w_up = w + (w - 1) * (rate - 1)
- containing (rate - 1) zeros between consecutive filter values along
- the filters' spatial dimensions.
- """
- if rate == 1:
- return filters
- # [h, w, in_depth, out_depth] -> [in_depth, out_depth, h, w]
- filters_up = np.transpose(filters, [2, 3, 0, 1])
- ker = np.zeros([rate, rate])
- ker[0, 0] = 1
- filters_up = np.kron(filters_up, ker)[:, :, :-(rate-1), :-(rate-1)]
- # [in_depth, out_depth, h_up, w_up] -> [h_up, w_up, in_depth, out_depth]
- filters_up = np.transpose(filters_up, [2, 3, 0, 1])
- self.assertEqual(np.sum(filters), np.sum(filters_up))
- return filters_up
+class AtrousConv2DTest(tf.test.TestCase):
def testAtrousConv2DForward(self):
with self.test_session(use_gpu=True):
@@ -65,14 +65,13 @@ class AtrousConv2DTest(tf.test.TestCase):
f = np.arange(np.prod(f_shape), dtype=np.float32).reshape(f_shape)
for rate in range(1, 4):
- f_up = self._upsample_filters(f, rate)
+ f_up = _upsample_filters(f, rate)
for padding in ["SAME", "VALID"]:
y1 = tf.nn.atrous_conv2d(x, f, rate, padding=padding)
y2 = tf.nn.conv2d(x, f_up, strides=[1, 1, 1, 1],
padding=padding)
- self.assertAllClose(y1.eval(), y2.eval(), rtol=1e-2,
- atol=1e-2)
+ self.assertAllClose(y1.eval(), y2.eval(), rtol=1e-3, atol=1e-3)
def testAtrousSequence(self):
"""Tests optimization of sequence of atrous convolutions.
@@ -150,5 +149,42 @@ class AtrousConv2DTest(tf.test.TestCase):
self.assertLess(err, err_tolerance)
+class AtrousConv2DTransposeTest(tf.test.TestCase):
+
+ def testAtrousConv2DTransposeForward(self):
+ with self.test_session(use_gpu=True):
+ # Input: [batch, height, width, input_depth]
+ height = 9
+ for width in [9, 10]: # Test both odd and even width.
+ x_shape = [2, height, width, 2]
+ x = np.arange(np.prod(x_shape), dtype=np.float32).reshape(x_shape)
+
+ # Filter: [kernel_height, kernel_width, input_depth, output_depth]
+ for kernel_height in range(1, 4):
+ for kernel_width in range(1, 4):
+ f_shape = [kernel_height, kernel_width, 2, 2]
+ f = np.arange(np.prod(f_shape), dtype=np.float32).reshape(f_shape)
+
+ for rate in range(1, 4):
+ f_up = _upsample_filters(f, rate)
+ kernel_height_up = (kernel_height +
+ (kernel_height - 1) * (rate - 1))
+ kernel_width_up = kernel_width + (kernel_width - 1) * (rate - 1)
+
+ for padding in ["SAME", "VALID"]:
+ if padding == "SAME":
+ y_shape = [2, height, width, 2]
+ else:
+ y_shape = [2,
+ height + kernel_height_up - 1,
+ width + kernel_width_up - 1,
+ 2]
+
+ y1 = tf.nn.atrous_conv2d_transpose(x, f, y_shape, rate, padding)
+ y2 = tf.nn.conv2d_transpose(
+ x, f_up, y_shape, strides=[1, 1, 1, 1], padding=padding)
+ self.assertAllClose(y1.eval(), y2.eval(), rtol=1e-3, atol=1e-3)
+
+
if __name__ == "__main__":
tf.test.main()