diff options
author | 2017-09-06 09:42:57 -0700 | |
---|---|---|
committer | 2017-09-06 09:54:53 -0700 | |
commit | 10594900c5df1b84cd0336d0fb5bd0d8454bfe08 (patch) | |
tree | 7445e120556edba062cf82f3b0ade447e8cb8362 /tensorflow/core/framework | |
parent | b71c1bb6f5edd77f92a10e99d011a00de572aa68 (diff) |
Update MaxPoolV2Shape to support NCHW_VECT_C.
PiperOrigin-RevId: 167732437
Diffstat (limited to 'tensorflow/core/framework')
-rw-r--r-- | tensorflow/core/framework/common_shape_fns.cc | 88 | ||||
-rw-r--r-- | tensorflow/core/framework/common_shape_fns_test.cc | 45 |
2 files changed, 85 insertions, 48 deletions
diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc index 0e3ea2ddfb..2d44480053 100644 --- a/tensorflow/core/framework/common_shape_fns.cc +++ b/tensorflow/core/framework/common_shape_fns.cc @@ -559,9 +559,6 @@ Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c) { } Status AvgPoolShape(shape_inference::InferenceContext* c) { - ShapeHandle input_shape; - TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 4, &input_shape)); - string data_format_str; TensorFormat data_format; Status s = c->GetAttr("data_format", &data_format_str); @@ -571,6 +568,10 @@ Status AvgPoolShape(shape_inference::InferenceContext* c) { data_format = FORMAT_NHWC; } + const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4; + ShapeHandle input_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape)); + TF_RETURN_IF_ERROR( CheckFormatConstraintsOnShape(data_format, input_shape, "input", c)); @@ -627,9 +628,6 @@ Status AvgPoolShape(shape_inference::InferenceContext* c) { } Status MaxPoolShape(shape_inference::InferenceContext* c) { - ShapeHandle input_shape; - TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 4, &input_shape)); - string data_format_str; TensorFormat data_format; Status s = c->GetAttr("data_format", &data_format_str); @@ -639,6 +637,10 @@ Status MaxPoolShape(shape_inference::InferenceContext* c) { data_format = FORMAT_NHWC; } + const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 
5 : 4; + ShapeHandle input_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape)); + TF_RETURN_IF_ERROR( CheckFormatConstraintsOnShape(data_format, input_shape, "input", c)); @@ -696,11 +698,21 @@ Status MaxPoolShape(shape_inference::InferenceContext* c) { } Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs) { + string data_format_str; + TensorFormat data_format; + Status s = c->GetAttr("data_format", &data_format_str); + if (s.ok()) { + FormatFromString(data_format_str, &data_format); + } else { + data_format = FORMAT_NHWC; + } + + const int rank = (data_format == FORMAT_NCHW_VECT_C) ? 5 : 4; ShapeHandle input_shape; - TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape)); + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &input_shape)); - string data_format; - Status s = c->GetAttr("data_format", &data_format); + TF_RETURN_IF_ERROR( + CheckFormatConstraintsOnShape(data_format, input_shape, "input", c)); std::vector<int32> kernel_sizes; std::vector<int32> strides; @@ -725,7 +737,8 @@ Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs) { } kernel_sizes.resize(kernel_sizes_tensor->shape().num_elements()); auto kernel_sizes_vec = kernel_sizes_tensor->flat<int32>(); - std::copy_n(&kernel_sizes_vec(0), kernel_sizes.size(), kernel_sizes.begin()); + std::copy_n(&kernel_sizes_vec(0), kernel_sizes.size(), + kernel_sizes.begin()); const Tensor* strides_tensor = c->input_tensor(c->num_inputs() - 1); if (strides_tensor == nullptr) { @@ -749,35 +762,22 @@ Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs) { kernel_sizes.size()); } - int32 stride_rows, stride_cols, stride_depth; - int32 kernel_rows, kernel_cols, kernel_depth; - - if (s.ok() && data_format == "NCHW") { - // Canonicalize input shape to NHWC so the shape inference code below can - // process it. 
- auto dim = [&](char dimension) { - return c->Dim(input_shape, GetTensorDimIndex<2>(FORMAT_NCHW, dimension)); - }; - input_shape = c->MakeShape({{dim('N'), dim('0'), dim('1'), dim('C')}}); - stride_depth = strides[1]; - stride_rows = strides[2]; - stride_cols = strides[3]; - kernel_depth = kernel_sizes[1]; - kernel_rows = kernel_sizes[2]; - kernel_cols = kernel_sizes[3]; - } else { - stride_rows = strides[1]; - stride_cols = strides[2]; - stride_depth = strides[3]; - kernel_rows = kernel_sizes[1]; - kernel_cols = kernel_sizes[2]; - kernel_depth = kernel_sizes[3]; - } + int32 stride_depth = GetTensorDim(strides, data_format, 'C'); + int32 stride_rows = GetTensorDim(strides, data_format, 'H'); + int32 stride_cols = GetTensorDim(strides, data_format, 'W'); + int32 kernel_depth = GetTensorDim(kernel_sizes, data_format, 'C'); + int32 kernel_rows = GetTensorDim(kernel_sizes, data_format, 'H'); + int32 kernel_cols = GetTensorDim(kernel_sizes, data_format, 'W'); - DimensionHandle batch_size_dim = c->Dim(input_shape, 0); - DimensionHandle in_rows_dim = c->Dim(input_shape, 1); - DimensionHandle in_cols_dim = c->Dim(input_shape, 2); - DimensionHandle in_depth_dim = c->Dim(input_shape, 3); + constexpr int num_spatial_dims = 2; + DimensionHandle batch_size_dim = c->Dim( + input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'N')); + DimensionHandle in_rows_dim = c->Dim( + input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'H')); + DimensionHandle in_cols_dim = c->Dim( + input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'W')); + DimensionHandle in_depth_dim = c->Dim( + input_shape, GetTensorDimIndex<num_spatial_dims>(data_format, 'C')); Padding padding; TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding)); @@ -791,15 +791,9 @@ Status MaxPoolV2Shape(shape_inference::InferenceContext* c, int num_inputs) { TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims( c, in_depth_dim, kernel_depth, stride_depth, padding, &output_depth)); - output_shape = - 
c->MakeShape({batch_size_dim, output_rows, output_cols, output_depth}); - if (data_format == "NCHW") { - // Convert output shape back to expected NCHW data format. - auto dim = [&](char dimension) { - return c->Dim(output_shape, GetTensorDimIndex<2>(FORMAT_NHWC, dimension)); - }; - output_shape = c->MakeShape({{dim('N'), dim('C'), dim('0'), dim('1')}}); - } + TF_RETURN_IF_ERROR(MakeShapeFromFormat(data_format, batch_size_dim, + {output_rows, output_cols}, + output_depth, &output_shape, c)); c->set_output(0, output_shape); return Status::OK(); diff --git a/tensorflow/core/framework/common_shape_fns_test.cc b/tensorflow/core/framework/common_shape_fns_test.cc index 14f6c1bb45..90a48f14d4 100644 --- a/tensorflow/core/framework/common_shape_fns_test.cc +++ b/tensorflow/core/framework/common_shape_fns_test.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/fake_input.h" #include "tensorflow/core/framework/node_def_builder.h" #include "tensorflow/core/framework/op_def_builder.h" #include "tensorflow/core/framework/shape_inference_testutil.h" @@ -704,7 +705,7 @@ TEST(CommonShapeFnsTest, AvgPool2DShapeTest) { INFER_ERROR("Dimension must be 4 but is 3", op, "[2,5,7,11,3]"); // Invalid rank for input - INFER_ERROR("must be at least rank 4", op, "[4,4]"); + INFER_ERROR("Shape must be rank", op, "[4,4]"); } TEST(CommonShapeFnsTest, MaxPool2DShapeTest) { @@ -741,6 +742,48 @@ TEST(CommonShapeFnsTest, MaxPool2DShapeTest) { INFER_ERROR("Dimension must be 4 but is 8", op, "[2,3,5,7,8]"); } +TEST(CommonShapeFnsTest, MaxPoolV22DShapeTest) { + ShapeInferenceTestOp op("MaxPoolV2"); + Tensor ksizes_tensor, strides_tensor; + auto set_op = [&op, &ksizes_tensor, &strides_tensor]( + const std::vector<int32>& strides, + const std::vector<int32>& ksizes, const string& padding, + const string& data_format) { + 
TF_CHECK_OK(NodeDefBuilder("test", "MaxPoolV2") + .Input("input", 0, DT_FLOAT) + .Input("ksize", 1, DT_INT32) + .Input("strides", 2, DT_INT32) + .Attr("padding", padding) + .Attr("data_format", data_format) + .Finalize(&op.node_def)); + ksizes_tensor = test::AsTensor<int32>(ksizes); + op.input_tensors.resize(3); + op.input_tensors[0] = nullptr; + op.input_tensors[1] = &ksizes_tensor; + strides_tensor = test::AsTensor<int32>(strides); + op.input_tensors[2] = &strides_tensor; + }; + + // Most of the functionality is tested by conv-like shapes, + // so we check the very-specific maxpooling features here, + // namely depthwise kernel and striding. + + // all 1 strides, depth 2 filter + set_op({1, 1, 1, 1}, {1, 1, 1, 2}, "VALID", "NHWC"); + INFER_OK(op, "[1,2,2,2];[4];[4]", "[d0_0,2,2,1]"); + + // depth 3 stride, 1x1x1 filter, NCHW + set_op({1, 3, 1, 1}, {1, 1, 1, 1}, "VALID", "NCHW"); + INFER_OK(op, "[1,7,5,5];[4];[4]", "[d0_0,3,5,5]"); + + // 5x7 input, 2x2 ksize, 1x1 stride, NCHW_VECT_C tests + set_op({{1, 1, 1, 1}}, {1, 1, 2, 2}, "SAME", "NCHW_VECT_C"); + INFER_OK(op, "[2,3,5,7,4];[4];[4]", "[d0_0,d0_1,d0_2,d0_3,4]"); + INFER_OK(op, "[5,7,?,?,4];[4];[4]", "[d0_0,d0_1,d0_2,d0_3,4]"); + INFER_OK(op, "[?,?,?,?,4];[4];[4]", "[d0_0,d0_1,d0_2,d0_3,4]"); + INFER_ERROR("Dimension must be 4 but is 8", op, "[2,3,5,7,8];[4];[4]"); +} + TEST(CommonShapeFnsTest, Pool3DShapeTest) { ShapeInferenceTestOp op("MaxPool3D"); auto set_op = [&op](const std::vector<int32>& strides, |