aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python/kernel_tests/conv_ops_3d_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python/kernel_tests/conv_ops_3d_test.py')
-rw-r--r--tensorflow/python/kernel_tests/conv_ops_3d_test.py3
1 files changed, 3 insertions, 0 deletions
diff --git a/tensorflow/python/kernel_tests/conv_ops_3d_test.py b/tensorflow/python/kernel_tests/conv_ops_3d_test.py
index ec8ac74163..f4616fd661 100644
--- a/tensorflow/python/kernel_tests/conv_ops_3d_test.py
+++ b/tensorflow/python/kernel_tests/conv_ops_3d_test.py
@@ -25,6 +25,7 @@ import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import nn_ops
@@ -344,6 +345,8 @@ class Conv3DTest(test.TestCase):
if data_format == "NCDHW":
conv = test_util.NCHWToNHWC(conv)
+ self.assertEqual(conv.shape, tensor_shape.TensorShape(output_shape))
+
if test_input:
jacob_t, jacob_n = gradient_checker.compute_gradient(
orig_input_tensor, input_shape, conv, output_shape)