aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Vijay Vasudevan <vrv@google.com>2016-11-23 09:49:33 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-11-23 10:07:59 -0800
commit51ae55c4e782c7577e10e51ee54721593204f396 (patch)
treea961f60a3f2e468b8b9a27b3ce1cbbcc2ada7ab7
parent542c0d6fd433d50f6558bfcc1ed2cab3cc52a4e5 (diff)
Relax Pool3D shape function when some dimensions are unknown.
Added test. Fixes #5807. Change: 140042698
-rw-r--r--tensorflow/core/framework/common_shape_fns.cc11
-rw-r--r--tensorflow/core/framework/common_shape_fns_test.cc4
2 files changed, 4 insertions, 11 deletions
diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc
index 794f9c37cf..dcf0ae40d5 100644
--- a/tensorflow/core/framework/common_shape_fns.cc
+++ b/tensorflow/core/framework/common_shape_fns.cc
@@ -557,17 +557,6 @@ Status Pool3DShape(shape_inference::InferenceContext* c) {
DimensionHandle in_cols_dim = c->Dim(input_shape, 3);
DimensionHandle output_depth_dim = c->Dim(input_shape, 4);
- // At the moment we need to know the values of several fields.
- if (!c->ValueKnown(in_planes_dim) || !c->ValueKnown(in_rows_dim) ||
- !c->ValueKnown(in_cols_dim)) {
- ShapeHandle output_shape =
- c->MakeShape({batch_size_dim, InferenceContext::kUnknownDim,
- InferenceContext::kUnknownDim,
- InferenceContext::kUnknownDim, output_depth_dim});
- c->set_output(0, output_shape);
- return Status::OK();
- }
-
Padding padding;
TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
diff --git a/tensorflow/core/framework/common_shape_fns_test.cc b/tensorflow/core/framework/common_shape_fns_test.cc
index 2be771e3a9..972d05dd5d 100644
--- a/tensorflow/core/framework/common_shape_fns_test.cc
+++ b/tensorflow/core/framework/common_shape_fns_test.cc
@@ -699,6 +699,10 @@ TEST(CommonShapeFnsTest, Pool3DShapeTest) {
// 2x3x4 stride, 1x1x1 filter.
set_op({1, 2, 3, 4, 1}, {1, 1, 1, 1, 1}, "VALID");
INFER_OK(op, "[1,24,24,24,1]", "[d0_0,12,8,6,d0_4]");
+
+ // Test partially known dimensions
+ set_op({1, 1, 3, 4, 1}, {1, 1, 1, 1, 1}, "VALID");
+ INFER_OK(op, "[1,?,24,24,1]", "[d0_0,?,8,6,d0_4]");
}
TEST(CommonShapeFnsTest, UnknownShapeTest) {