diff options
author | Vijay Vasudevan <vrv@google.com> | 2016-11-23 09:49:33 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-11-23 10:07:59 -0800 |
commit | 51ae55c4e782c7577e10e51ee54721593204f396 (patch) | |
tree | a961f60a3f2e468b8b9a27b3ce1cbbcc2ada7ab7 | |
parent | 542c0d6fd433d50f6558bfcc1ed2cab3cc52a4e5 (diff) |
Relax Pool3D shape function when some dimensions are unknown.
Added test.
Fixes #5807.
Change: 140042698
-rw-r--r-- | tensorflow/core/framework/common_shape_fns.cc | 11 | ||||
-rw-r--r-- | tensorflow/core/framework/common_shape_fns_test.cc | 4 |
2 files changed, 4 insertions, 11 deletions
diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc index 794f9c37cf..dcf0ae40d5 100644 --- a/tensorflow/core/framework/common_shape_fns.cc +++ b/tensorflow/core/framework/common_shape_fns.cc @@ -557,17 +557,6 @@ Status Pool3DShape(shape_inference::InferenceContext* c) { DimensionHandle in_cols_dim = c->Dim(input_shape, 3); DimensionHandle output_depth_dim = c->Dim(input_shape, 4); - // At the moment we need to know the values of several fields. - if (!c->ValueKnown(in_planes_dim) || !c->ValueKnown(in_rows_dim) || - !c->ValueKnown(in_cols_dim)) { - ShapeHandle output_shape = - c->MakeShape({batch_size_dim, InferenceContext::kUnknownDim, - InferenceContext::kUnknownDim, - InferenceContext::kUnknownDim, output_depth_dim}); - c->set_output(0, output_shape); - return Status::OK(); - } - Padding padding; TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding)); diff --git a/tensorflow/core/framework/common_shape_fns_test.cc b/tensorflow/core/framework/common_shape_fns_test.cc index 2be771e3a9..972d05dd5d 100644 --- a/tensorflow/core/framework/common_shape_fns_test.cc +++ b/tensorflow/core/framework/common_shape_fns_test.cc @@ -699,6 +699,10 @@ TEST(CommonShapeFnsTest, Pool3DShapeTest) { // 2x3x4 stride, 1x1x1 filter. set_op({1, 2, 3, 4, 1}, {1, 1, 1, 1, 1}, "VALID"); INFER_OK(op, "[1,24,24,24,1]", "[d0_0,12,8,6,d0_4]"); + + // Test partially known dimensions + set_op({1, 1, 3, 4, 1}, {1, 1, 1, 1, 1}, "VALID"); + INFER_OK(op, "[1,?,24,24,1]", "[d0_0,?,8,6,d0_4]"); } TEST(CommonShapeFnsTest, UnknownShapeTest) { |