aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Reed Wanderman-Milne <reedwm@google.com>2018-04-03 17:26:16 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-03 17:29:06 -0700
commit4a20926c555a2c8da47f4138f436417bfe7db6d0 (patch)
tree89a491cfabc107d7c4daffbadb9a13343d5990d0
parent467f195a2dd87257e3719576637774ebcf7a4590 (diff)
Fix Conv3D shape inference.
Before, the stride rows and columns were mixed up, causing shape inference to output the wrong shape. PiperOrigin-RevId: 191525254
-rw-r--r--tensorflow/core/framework/common_shape_fns.cc4
-rw-r--r--tensorflow/python/kernel_tests/conv_ops_3d_test.py3
2 files changed, 5 insertions, 2 deletions
diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc
index 2fb17c2b02..72eeda7a43 100644
--- a/tensorflow/core/framework/common_shape_fns.cc
+++ b/tensorflow/core/framework/common_shape_fns.cc
@@ -504,8 +504,8 @@ Status Conv3DShape(shape_inference::InferenceContext* c) {
input_shape =
c->MakeShape({{dim('N'), dim('0'), dim('1'), dim('2'), dim('C')}});
stride_planes = strides[2];
- stride_cols = strides[3];
- stride_rows = strides[4];
+ stride_rows = strides[3];
+ stride_cols = strides[4];
} else {
stride_planes = strides[1];
stride_rows = strides[2];
diff --git a/tensorflow/python/kernel_tests/conv_ops_3d_test.py b/tensorflow/python/kernel_tests/conv_ops_3d_test.py
index ec8ac74163..f4616fd661 100644
--- a/tensorflow/python/kernel_tests/conv_ops_3d_test.py
+++ b/tensorflow/python/kernel_tests/conv_ops_3d_test.py
@@ -25,6 +25,7 @@ import numpy as np
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import nn_ops
@@ -344,6 +345,8 @@ class Conv3DTest(test.TestCase):
if data_format == "NCDHW":
conv = test_util.NCHWToNHWC(conv)
+ self.assertEqual(conv.shape, tensor_shape.TensorShape(output_shape))
+
if test_input:
jacob_t, jacob_n = gradient_checker.compute_gradient(
orig_input_tensor, input_shape, conv, output_shape)