diff options
author | Vijay Vasudevan <vrv@google.com> | 2016-08-01 11:08:01 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-08-01 12:17:01 -0700 |
commit | 19ad04564a70ae0134c044666460f47714e287f1 (patch) | |
tree | dae711833b192eaee432cc30d790602a119d146c | |
parent | b1d9ef53ad6fbf1d98374471456040aecc0b4799 (diff) |
TensorFlow: Add Conv3D/MaxPool3D/AvgPool3D C++ shape inference functions.
Change: 129011665
-rw-r--r-- | tensorflow/core/framework/common_shape_fns.cc | 141 | ||||
-rw-r--r-- | tensorflow/core/framework/common_shape_fns.h | 6 | ||||
-rw-r--r-- | tensorflow/core/framework/common_shape_fns_test.cc | 69 | ||||
-rw-r--r-- | tensorflow/core/ops/nn_ops.cc | 3 |
4 files changed, 219 insertions, 0 deletions
diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc index eea3112b3f..65cfb1a90e 100644 --- a/tensorflow/core/framework/common_shape_fns.cc +++ b/tensorflow/core/framework/common_shape_fns.cc @@ -263,6 +263,75 @@ Status Conv2DShape(shape_inference::InferenceContext* c) { return Status::OK(); } +Status Conv3DShape(shape_inference::InferenceContext* c) { + const Shape* input_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape)); + const Shape* filter_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 5, &filter_shape)); + + std::vector<int32> strides; + TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); + if (strides.size() != 5) { + return errors::InvalidArgument( + "Conv3D requires the stride attribute to contain 5 values, but got: ", + strides.size()); + } + + int32 stride_planes = strides[1]; + int32 stride_rows = strides[2]; + int32 stride_cols = strides[3]; + + const Dimension* batch_size_dim = c->Dim(input_shape, 0); + const Dimension* in_planes_dim = c->Dim(input_shape, 1); + const Dimension* in_rows_dim = c->Dim(input_shape, 2); + const Dimension* in_cols_dim = c->Dim(input_shape, 3); + + const Dimension* filter_planes_dim = c->Dim(filter_shape, 0); + const Dimension* filter_rows_dim = c->Dim(filter_shape, 1); + const Dimension* filter_cols_dim = c->Dim(filter_shape, 2); + const Dimension* output_depth_dim = c->Dim(filter_shape, 4); + + // At the moment we need to know the values of several fields. 
+ TF_RETURN_IF_ERROR(CheckKnownDim(c, in_planes_dim, "in_planes")); + TF_RETURN_IF_ERROR(CheckKnownDim(c, in_rows_dim, "in_rows")); + TF_RETURN_IF_ERROR(CheckKnownDim(c, in_cols_dim, "in_cols")); + TF_RETURN_IF_ERROR(CheckKnownDim(c, filter_planes_dim, "filter_planes")); + TF_RETURN_IF_ERROR(CheckKnownDim(c, filter_rows_dim, "filter_rows")); + TF_RETURN_IF_ERROR(CheckKnownDim(c, filter_cols_dim, "filter_cols")); + + auto in_planes = c->Value(in_planes_dim); + auto in_rows = c->Value(in_rows_dim); + auto in_cols = c->Value(in_cols_dim); + auto filter_planes = c->Value(filter_planes_dim); + auto filter_rows = c->Value(filter_rows_dim); + auto filter_cols = c->Value(filter_cols_dim); + + const Dimension* unused; + TF_RETURN_IF_ERROR( + c->Merge(c->Dim(input_shape, 4), c->Dim(filter_shape, 3), &unused)); + + Padding padding; + TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding)); + + int64 output_planes, output_rows, output_cols; + int64 padding_before, padding_after; + TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( + in_planes, filter_planes, stride_planes, padding, &output_planes, + &padding_before, &padding_after)); + TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( + in_rows, filter_rows, stride_rows, padding, &output_rows, &padding_before, + &padding_after)); + TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( + in_cols, filter_cols, stride_cols, padding, &output_cols, &padding_before, + &padding_after)); + + const Shape* output_shape = + c->MakeShape({batch_size_dim, output_planes, output_rows, output_cols, + output_depth_dim}); + c->set_output(0, output_shape); + return Status::OK(); +} + Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c) { const Shape* input_shape; TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape)); @@ -507,6 +576,78 @@ Status MaxPoolShape(shape_inference::InferenceContext* c) { return Status::OK(); } +Status Pool3DShape(shape_inference::InferenceContext* c) { + const Shape* input_shape; + 
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape)); + + std::vector<int32> strides; + TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); + if (strides.size() != 5) { + return errors::InvalidArgument( + "Pool3D ops require the stride attribute to contain 5 values, but " + "got: ", + strides.size()); + } + + std::vector<int32> kernel_sizes; + TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes)); + if (kernel_sizes.size() != 5) { + return errors::InvalidArgument( + "Pool3D requires the ksize attribute to contain 5 values, but got: ", + kernel_sizes.size()); + } + + int32 stride_planes, stride_rows, stride_cols; + int32 kernel_planes, kernel_rows, kernel_cols; + + stride_planes = strides[1]; + stride_rows = strides[2]; + stride_cols = strides[3]; + kernel_planes = kernel_sizes[1]; + kernel_rows = kernel_sizes[2]; + kernel_cols = kernel_sizes[3]; + + const Dimension* batch_size_dim = c->Dim(input_shape, 0); + const Dimension* in_planes_dim = c->Dim(input_shape, 1); + const Dimension* in_rows_dim = c->Dim(input_shape, 2); + const Dimension* in_cols_dim = c->Dim(input_shape, 3); + const Dimension* output_depth_dim = c->Dim(input_shape, 4); + + // At the moment we need to know the values of several fields. + TF_RETURN_IF_ERROR(CheckKnownDim(c, in_planes_dim, "in_planes")); + TF_RETURN_IF_ERROR(CheckKnownDim(c, in_rows_dim, "in_rows")); + TF_RETURN_IF_ERROR(CheckKnownDim(c, in_cols_dim, "in_cols")); + + Padding padding; + TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding)); + + // TODO(mrry,shlens): Raise an error if the stride would cause + // information in the input to be ignored. This will require a change + // in the kernel implementation. 
+ auto in_planes = c->Value(in_planes_dim); + auto in_rows = c->Value(in_rows_dim); + auto in_cols = c->Value(in_cols_dim); + + int64 output_planes, output_rows, output_cols; + int64 padding_before, padding_after; + TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( + in_planes, kernel_planes, stride_planes, padding, &output_planes, + &padding_before, &padding_after)); + TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( + in_rows, kernel_rows, stride_rows, padding, &output_rows, &padding_before, + &padding_after)); + TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose( + in_cols, kernel_cols, stride_cols, padding, &output_cols, &padding_before, + &padding_after)); + + const Shape* output_shape = + c->MakeShape({batch_size_dim, output_planes, output_rows, output_cols, + output_depth_dim}); + + c->set_output(0, output_shape); + return Status::OK(); +} + Status UnknownShape(shape_inference::InferenceContext* c) { for (int i = 0; i < c->num_outputs(); ++i) { c->set_output(i, c->UnknownShape()); diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h index f1bdd5ee8d..0ca6499036 100644 --- a/tensorflow/core/framework/common_shape_fns.h +++ b/tensorflow/core/framework/common_shape_fns.h @@ -157,6 +157,9 @@ Status BiasAddGradShape(shape_inference::InferenceContext* c); // Shape function for Conv2D-like operations. Status Conv2DShape(shape_inference::InferenceContext* c); +// Shape function for Conv3D-like operations. +Status Conv3DShape(shape_inference::InferenceContext* c); + // Shape function for DepthwiseConv2D-like operations. Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c); @@ -166,6 +169,9 @@ Status AvgPoolShape(shape_inference::InferenceContext* c); // Shape function for MaxPool-like operations. Status MaxPoolShape(shape_inference::InferenceContext* c); +// Shape function for 3D Pooling operations. 
+Status Pool3DShape(shape_inference::InferenceContext* c); + // Shape function for use with ops whose output shapes are unknown. Status UnknownShape(shape_inference::InferenceContext* c); diff --git a/tensorflow/core/framework/common_shape_fns_test.cc b/tensorflow/core/framework/common_shape_fns_test.cc index eada469b17..6e0dd7f742 100644 --- a/tensorflow/core/framework/common_shape_fns_test.cc +++ b/tensorflow/core/framework/common_shape_fns_test.cc @@ -419,6 +419,55 @@ TEST(CommonShapeFnsTest, Conv2DShapeTest) { INFER_OK(op, "[1,4,4,1];[2,2,1,1]", "[d0_0,4,4,d1_3]"); } +TEST(CommonShapeFnsTest, Conv3DShapeTest) { + ShapeInferenceTestOp op("Conv3D"); + auto set_op = [&op](const std::vector<int32>& strides, + const string& padding) { + TF_CHECK_OK(NodeDefBuilder("test", "Conv3D") + .Input("input", 0, DT_FLOAT) + .Input("filter", 0, DT_FLOAT) + .Attr("strides", strides) + .Attr("padding", padding) + .Finalize(&op.node_def)); + }; + + // 1x1x1 filter + set_op({{1, 1, 1, 1, 1}}, "VALID"); + INFER_OK(op, "[1,2,2,2,1];[1,1,1,1,1]", "[d0_0,2,2,2,d1_4]"); + + // Invalid rank for input + INFER_ERROR("must be rank 5", op, "[4,4];[2,1,1,1]"); + // Invalid rank for filter + INFER_ERROR("must be rank 5", op, "[1,4,4,1];[2,1,1]"); + + // No unknown dims in the critical fields. + INFER_ERROR("is not known", op, "[1,?,2,2,1];[1,1,1,1,1]"); + INFER_ERROR("is not known", op, "[1,2,?,2,1];[1,1,1,1,1]"); + INFER_ERROR("is not known", op, "[1,2,2,?,1];[1,1,1,1,1]"); + INFER_ERROR("is not known", op, "[1,2,2,2,1];[?,1,1,1,1]"); + INFER_ERROR("is not known", op, "[1,2,2,2,1];[1,?,1,1,1]"); + + // input depths must match. 
+ INFER_ERROR("Dimensions must be equal, but are 10 and 10000", op, + "[1,2,2,2,10];[1,1,1,10000,20]"); + + // 2x2x2 filter + set_op({{1, 1, 1, 1, 1}}, "VALID"); + INFER_OK(op, "[1,2,2,2,1];[2,2,2,1,1]", "[d0_0,1,1,1,d1_4]"); + + // 3x3 input, 1x1 filter, 2x2 stride + set_op({{1, 2, 2, 2, 1}}, "VALID"); + INFER_OK(op, "[1,3,3,3,1];[1,1,1,1,1]", "[d0_0,2,2,2,d1_4]"); + + // 3x3 input, 1x1 filter, 2x1x1 stride + set_op({{1, 2, 1, 1, 1}}, "VALID"); + INFER_OK(op, "[1,3,3,3,1];[1,1,1,1,1]", "[d0_0,2,3,3,d1_4]"); + + // 4x4 input, 2x2 filter, 1x1 stride + set_op({{1, 1, 1, 1, 1}}, "SAME"); + INFER_OK(op, "[1,4,4,4,1];[2,2,2,1,1]", "[d0_0,4,4,4,d1_4]"); +} + TEST(CommonShapeFnsTest, DepthwiseConv2DShapeTest) { ShapeInferenceTestOp op("DepthwiseConv2dNative"); std::vector<int32> strides = {{1, 1, 1, 1}}; @@ -512,6 +561,26 @@ TEST(CommonShapeFnsTest, MaxPool2DShapeTest) { INFER_OK(op, "[1,7,5,5]", "[d0_0,3,5,5]"); } +TEST(CommonShapeFnsTest, Pool3DShapeTest) { + ShapeInferenceTestOp op("MaxPool3D"); + auto set_op = [&op](const std::vector<int32>& strides, + const std::vector<int32>& ksizes, const string& padding) { + TF_CHECK_OK(NodeDefBuilder("test", "MaxPool3D") + .Input("input", 0, DT_FLOAT) + .Attr("strides", strides) + .Attr("ksize", ksizes) + .Attr("padding", padding) + .Finalize(&op.node_def)); + }; + + // Most of the functionality is tested by conv-like shapes, + // so we check that we handle the extra dimension properly. + + // 2x3x4 stride, 1x1x1 filter. 
+ set_op({1, 2, 3, 4, 1}, {1, 1, 1, 1, 1}, "VALID"); + INFER_OK(op, "[1,24,24,24,1]", "[d0_0,12,8,6,d0_4]"); +} + TEST(CommonShapeFnsTest, UnknownShapeTest) { { // Single output diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc index 03ada87511..3a2c02bd85 100644 --- a/tensorflow/core/ops/nn_ops.cc +++ b/tensorflow/core/ops/nn_ops.cc @@ -533,6 +533,7 @@ REGISTER_OP("Conv3D") .Attr("T: numbertype") .Attr("strides: list(int) >= 5") .Attr(GetPaddingAttrString()) + .SetShapeFn(shape_inference::Conv3DShape) .Doc(R"doc( Computes a 3-D convolution given 5-D `input` and `filter` tensors. @@ -677,6 +678,7 @@ REGISTER_OP("AvgPool3D") .Attr("strides: list(int) >= 5") .Attr(GetPaddingAttrString()) .Attr("T: numbertype") + .SetShapeFn(shape_inference::Pool3DShape) .Doc(R"doc( Performs 3D average pooling on the input. @@ -726,6 +728,7 @@ REGISTER_OP("MaxPool3D") .Attr("strides: list(int) >= 5") .Attr(GetPaddingAttrString()) .Attr("T: numbertype") + .SetShapeFn(shape_inference::Pool3DShape) .Doc(R"doc( Performs 3D max pooling on the input. |