diff options
-rw-r--r-- | tensorflow/core/framework/common_shape_fns.cc | 4 | ||||
-rw-r--r-- | tensorflow/python/kernel_tests/conv_ops_3d_test.py | 3 |
2 files changed, 5 insertions, 2 deletions
diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc index 2fb17c2b02..72eeda7a43 100644 --- a/tensorflow/core/framework/common_shape_fns.cc +++ b/tensorflow/core/framework/common_shape_fns.cc @@ -504,8 +504,8 @@ Status Conv3DShape(shape_inference::InferenceContext* c) { input_shape = c->MakeShape({{dim('N'), dim('0'), dim('1'), dim('2'), dim('C')}}); stride_planes = strides[2]; - stride_cols = strides[3]; - stride_rows = strides[4]; + stride_rows = strides[3]; + stride_cols = strides[4]; } else { stride_planes = strides[1]; stride_rows = strides[2]; diff --git a/tensorflow/python/kernel_tests/conv_ops_3d_test.py b/tensorflow/python/kernel_tests/conv_ops_3d_test.py index ec8ac74163..f4616fd661 100644 --- a/tensorflow/python/kernel_tests/conv_ops_3d_test.py +++ b/tensorflow/python/kernel_tests/conv_ops_3d_test.py @@ -25,6 +25,7 @@ import numpy as np from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import test_util from tensorflow.python.ops import gradient_checker from tensorflow.python.ops import nn_ops @@ -344,6 +345,8 @@ class Conv3DTest(test.TestCase): if data_format == "NCDHW": conv = test_util.NCHWToNHWC(conv) + self.assertEqual(conv.shape, tensor_shape.TensorShape(output_shape)) + if test_input: jacob_t, jacob_n = gradient_checker.compute_gradient( orig_input_tensor, input_shape, conv, output_shape) |