aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/core/framework/common_shape_fns.cc11
-rw-r--r--tensorflow/core/framework/common_shape_fns_test.cc4
2 files changed, 4 insertions, 11 deletions
diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc
index 794f9c37cf..dcf0ae40d5 100644
--- a/tensorflow/core/framework/common_shape_fns.cc
+++ b/tensorflow/core/framework/common_shape_fns.cc
@@ -557,17 +557,6 @@ Status Pool3DShape(shape_inference::InferenceContext* c) {
DimensionHandle in_cols_dim = c->Dim(input_shape, 3);
DimensionHandle output_depth_dim = c->Dim(input_shape, 4);
- // At the moment we need to know the values of several fields.
- if (!c->ValueKnown(in_planes_dim) || !c->ValueKnown(in_rows_dim) ||
- !c->ValueKnown(in_cols_dim)) {
- ShapeHandle output_shape =
- c->MakeShape({batch_size_dim, InferenceContext::kUnknownDim,
- InferenceContext::kUnknownDim,
- InferenceContext::kUnknownDim, output_depth_dim});
- c->set_output(0, output_shape);
- return Status::OK();
- }
-
Padding padding;
TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
diff --git a/tensorflow/core/framework/common_shape_fns_test.cc b/tensorflow/core/framework/common_shape_fns_test.cc
index 2be771e3a9..972d05dd5d 100644
--- a/tensorflow/core/framework/common_shape_fns_test.cc
+++ b/tensorflow/core/framework/common_shape_fns_test.cc
@@ -699,6 +699,10 @@ TEST(CommonShapeFnsTest, Pool3DShapeTest) {
// 2x3x4 stride, 1x1x1 filter.
set_op({1, 2, 3, 4, 1}, {1, 1, 1, 1, 1}, "VALID");
INFER_OK(op, "[1,24,24,24,1]", "[d0_0,12,8,6,d0_4]");
+
+ // Test partially known dimensions
+ set_op({1, 1, 3, 4, 1}, {1, 1, 1, 1, 1}, "VALID");
+ INFER_OK(op, "[1,?,24,24,1]", "[d0_0,?,8,6,d0_4]");
}
TEST(CommonShapeFnsTest, UnknownShapeTest) {