aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/conv3d_transpose_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/conv3d_transpose_test.py')
-rw-r--r--tensorflow/python/kernel_tests/conv3d_transpose_test.py12
1 files changed, 12 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/conv3d_transpose_test.py b/tensorflow/python/kernel_tests/conv3d_transpose_test.py
index a8b3af5096..8973a450fa 100644
--- a/tensorflow/python/kernel_tests/conv3d_transpose_test.py
+++ b/tensorflow/python/kernel_tests/conv3d_transpose_test.py
@@ -119,6 +119,18 @@ class Conv3DTransposeTest(test.TestCase):
target = 3.0
self.assertAllClose(target, value[n, d, h, w, k])
+ def testConv3DTransposeShapeMismatch(self):
+ # Test case for GitHub issue 18460
+ x_shape = [2, 2, 3, 4, 3]
+ f_shape = [3, 3, 3, 2, 2]
+ y_shape = [2, 2, 6, 8, 6]
+ strides = [1, 1, 2, 2, 2]
+ np.random.seed(1)
+ x_value = np.random.random_sample(x_shape).astype(np.float64)
+ f_value = np.random.random_sample(f_shape).astype(np.float64)
+ nn_ops.conv3d_transpose(
+ x_value, f_value, y_shape, strides, data_format='NCDHW')
+
def testConv3DTransposeValid(self):
with self.test_session():
strides = [1, 2, 2, 2, 1]