diff options
-rw-r--r-- | tensorflow/core/framework/common_shape_fns.cc | 11 | ||||
-rw-r--r-- | tensorflow/core/framework/common_shape_fns_test.cc | 4 |
2 files changed, 4 insertions, 11 deletions
diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc index 794f9c37cf..dcf0ae40d5 100644 --- a/tensorflow/core/framework/common_shape_fns.cc +++ b/tensorflow/core/framework/common_shape_fns.cc @@ -557,17 +557,6 @@ Status Pool3DShape(shape_inference::InferenceContext* c) { DimensionHandle in_cols_dim = c->Dim(input_shape, 3); DimensionHandle output_depth_dim = c->Dim(input_shape, 4); - // At the moment we need to know the values of several fields. - if (!c->ValueKnown(in_planes_dim) || !c->ValueKnown(in_rows_dim) || - !c->ValueKnown(in_cols_dim)) { - ShapeHandle output_shape = - c->MakeShape({batch_size_dim, InferenceContext::kUnknownDim, - InferenceContext::kUnknownDim, - InferenceContext::kUnknownDim, output_depth_dim}); - c->set_output(0, output_shape); - return Status::OK(); - } - Padding padding; TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding)); diff --git a/tensorflow/core/framework/common_shape_fns_test.cc b/tensorflow/core/framework/common_shape_fns_test.cc index 2be771e3a9..972d05dd5d 100644 --- a/tensorflow/core/framework/common_shape_fns_test.cc +++ b/tensorflow/core/framework/common_shape_fns_test.cc @@ -699,6 +699,10 @@ TEST(CommonShapeFnsTest, Pool3DShapeTest) { // 2x3x4 stride, 1x1x1 filter. set_op({1, 2, 3, 4, 1}, {1, 1, 1, 1, 1}, "VALID"); INFER_OK(op, "[1,24,24,24,1]", "[d0_0,12,8,6,d0_4]"); + + // Test partially known dimensions + set_op({1, 1, 3, 4, 1}, {1, 1, 1, 1, 1}, "VALID"); + INFER_OK(op, "[1,?,24,24,1]", "[d0_0,?,8,6,d0_4]"); } TEST(CommonShapeFnsTest, UnknownShapeTest) { |