diff options
Diffstat (limited to 'tensorflow/python/kernel_tests/conv1d_test.py')
-rw-r--r-- | tensorflow/python/kernel_tests/conv1d_test.py | 44 |
1 files changed, 23 insertions, 21 deletions
diff --git a/tensorflow/python/kernel_tests/conv1d_test.py b/tensorflow/python/kernel_tests/conv1d_test.py index d92797a7d3..e2e6205911 100644 --- a/tensorflow/python/kernel_tests/conv1d_test.py +++ b/tensorflow/python/kernel_tests/conv1d_test.py @@ -30,27 +30,29 @@ from tensorflow.python.platform import test class Conv1DTest(test.TestCase): def testBasic(self): - """Test that argument passing to conv2d is handled properly.""" - - x = constant_op.constant([1, 2, 3, 4], dtype=dtypes.float32) - x = array_ops.expand_dims(x, 0) # Add batch dimension - x = array_ops.expand_dims(x, 2) # And depth dimension - filters = constant_op.constant([2, 1], dtype=dtypes.float32) - filters = array_ops.expand_dims(filters, 1) # in_channels - filters = array_ops.expand_dims(filters, 2) # out_channels - # Filters is 2x1x1 - for stride in [1, 2]: - with self.test_session(use_gpu=test.is_gpu_available()): - c = nn_ops.conv1d(x, filters, stride, padding="VALID") - reduced = array_ops.squeeze(c) - output = reduced.eval() - if stride == 1: - self.assertEqual(len(output), 3) - self.assertAllClose(output, - [2 * 1 + 1 * 2, 2 * 2 + 1 * 3, 2 * 3 + 1 * 4]) - else: - self.assertEqual(len(output), 2) - self.assertAllClose(output, [2 * 1 + 1 * 2, 2 * 3 + 1 * 4]) + """Test that argument passing to conv1d is handled properly.""" + # TODO(yongtang): dtypes.float64 can only be enabled once conv2d support + # dtypes.float64, as conv1d implicitly calls conv2d after expand_dims. + for dtype in [dtypes.float16, dtypes.float32]: + x = constant_op.constant([1, 2, 3, 4], dtype=dtype) + x = array_ops.expand_dims(x, 0) # Add batch dimension + x = array_ops.expand_dims(x, 2) # And depth dimension + filters = constant_op.constant([2, 1], dtype=dtype) + filters = array_ops.expand_dims(filters, 1) # in_channels + filters = array_ops.expand_dims(filters, 2) # out_channels + # Filters is 2x1x1 + for stride in [1, 2]: + with self.test_session(use_gpu=test.is_gpu_available()): + c = nn_ops.conv1d(x, filters, stride, padding="VALID") + reduced = array_ops.squeeze(c) + output = reduced.eval() + if stride == 1: + self.assertEqual(len(output), 3) + self.assertAllClose(output, + [2 * 1 + 1 * 2, 2 * 2 + 1 * 3, 2 * 3 + 1 * 4]) + else: + self.assertEqual(len(output), 2) + self.assertAllClose(output, [2 * 1 + 1 * 2, 2 * 3 + 1 * 4]) def testConv1DTranspose(self): with self.test_session(): |