aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/conv1d_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/conv1d_test.py')
-rw-r--r--tensorflow/python/kernel_tests/conv1d_test.py43
1 files changed, 43 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/conv1d_test.py b/tensorflow/python/kernel_tests/conv1d_test.py
index 662c94eea7..7c8d309bbd 100644
--- a/tensorflow/python/kernel_tests/conv1d_test.py
+++ b/tensorflow/python/kernel_tests/conv1d_test.py
@@ -17,6 +17,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
+
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
@@ -50,5 +53,45 @@ class Conv1DTest(test.TestCase):
self.assertAllClose(output, [2 * 1 + 1 * 2, 2 * 3 + 1 * 4])
+ def testConv1DTranspose(self):
+ with self.test_session():
+ stride = 2
+
+ # Input, output: [batch, width, depth]
+ x_shape = [2, 4, 3]
+ y_shape = [2, 9, 2]
+
+ # Filter: [kernel_width, output_depth, input_depth]
+ f_shape = [3, 2, 3]
+
+ x = constant_op.constant(
+ 1.0, shape=x_shape, name="x", dtype=dtypes.float32)
+ f = constant_op.constant(
+ 1.0, shape=f_shape, name="filter", dtype=dtypes.float32)
+ output = nn_ops.conv1d_transpose(
+ x, f, y_shape, stride=stride, padding="VALID")
+ value = output.eval()
+
+ cache_values = np.zeros(y_shape, dtype=np.float32)
+
+ # The amount of padding added
+ pad = 1
+
+ for n in xrange(x_shape[0]):
+ for k in xrange(f_shape[1]):
+ for w in xrange(pad, y_shape[1] - pad):
+ target = 3.0
+ # We add a case for locations divisible by the stride.
+ w_in = w % stride == 0 and w > pad and w < y_shape[1] - 1 - pad
+ if w_in:
+ target += 3.0
+ cache_values[n, w, k] = target
+
+ # copy values in the border
+ cache_values[n, 0, k] = cache_values[n, 1, k]
+ cache_values[n, -1, k] = cache_values[n, -2, k]
+
+ self.assertAllClose(cache_values, value)
+
if __name__ == "__main__":
test.main()