aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Vijay Vasudevan <vrv@google.com>2016-08-01 11:08:01 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-08-01 12:17:01 -0700
commit19ad04564a70ae0134c044666460f47714e287f1 (patch)
treedae711833b192eaee432cc30d790602a119d146c
parentb1d9ef53ad6fbf1d98374471456040aecc0b4799 (diff)
TensorFlow: Add Conv3D/MaxPool3D/AvgPool3D C++ shape inference functions .
Change: 129011665
-rw-r--r--tensorflow/core/framework/common_shape_fns.cc141
-rw-r--r--tensorflow/core/framework/common_shape_fns.h6
-rw-r--r--tensorflow/core/framework/common_shape_fns_test.cc69
-rw-r--r--tensorflow/core/ops/nn_ops.cc3
4 files changed, 219 insertions, 0 deletions
diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc
index eea3112b3f..65cfb1a90e 100644
--- a/tensorflow/core/framework/common_shape_fns.cc
+++ b/tensorflow/core/framework/common_shape_fns.cc
@@ -263,6 +263,75 @@ Status Conv2DShape(shape_inference::InferenceContext* c) {
return Status::OK();
}
+Status Conv3DShape(shape_inference::InferenceContext* c) {
+ const Shape* input_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape));
+ const Shape* filter_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 5, &filter_shape));
+
+ std::vector<int32> strides;
+ TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
+ if (strides.size() != 5) {
+ return errors::InvalidArgument(
+ "Conv3D requires the stride attribute to contain 5 values, but got: ",
+ strides.size());
+ }
+
+ int32 stride_planes = strides[1];
+ int32 stride_rows = strides[2];
+ int32 stride_cols = strides[3];
+
+ const Dimension* batch_size_dim = c->Dim(input_shape, 0);
+ const Dimension* in_planes_dim = c->Dim(input_shape, 1);
+ const Dimension* in_rows_dim = c->Dim(input_shape, 2);
+ const Dimension* in_cols_dim = c->Dim(input_shape, 3);
+
+ const Dimension* filter_planes_dim = c->Dim(filter_shape, 0);
+ const Dimension* filter_rows_dim = c->Dim(filter_shape, 1);
+ const Dimension* filter_cols_dim = c->Dim(filter_shape, 2);
+ const Dimension* output_depth_dim = c->Dim(filter_shape, 4);
+
+ // At the moment we need to know the values of several fields.
+ TF_RETURN_IF_ERROR(CheckKnownDim(c, in_planes_dim, "in_planes"));
+ TF_RETURN_IF_ERROR(CheckKnownDim(c, in_rows_dim, "in_rows"));
+ TF_RETURN_IF_ERROR(CheckKnownDim(c, in_cols_dim, "in_cols"));
+ TF_RETURN_IF_ERROR(CheckKnownDim(c, filter_planes_dim, "filter_planes"));
+ TF_RETURN_IF_ERROR(CheckKnownDim(c, filter_rows_dim, "filter_rows"));
+ TF_RETURN_IF_ERROR(CheckKnownDim(c, filter_cols_dim, "filter_cols"));
+
+ auto in_planes = c->Value(in_planes_dim);
+ auto in_rows = c->Value(in_rows_dim);
+ auto in_cols = c->Value(in_cols_dim);
+ auto filter_planes = c->Value(filter_planes_dim);
+ auto filter_rows = c->Value(filter_rows_dim);
+ auto filter_cols = c->Value(filter_cols_dim);
+
+ const Dimension* unused;
+ TF_RETURN_IF_ERROR(
+ c->Merge(c->Dim(input_shape, 4), c->Dim(filter_shape, 3), &unused));
+
+ Padding padding;
+ TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
+
+ int64 output_planes, output_rows, output_cols;
+ int64 padding_before, padding_after;
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
+ in_planes, filter_planes, stride_planes, padding, &output_planes,
+ &padding_before, &padding_after));
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
+ in_rows, filter_rows, stride_rows, padding, &output_rows, &padding_before,
+ &padding_after));
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
+ in_cols, filter_cols, stride_cols, padding, &output_cols, &padding_before,
+ &padding_after));
+
+ const Shape* output_shape =
+ c->MakeShape({batch_size_dim, output_planes, output_rows, output_cols,
+ output_depth_dim});
+ c->set_output(0, output_shape);
+ return Status::OK();
+}
+
Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c) {
const Shape* input_shape;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
@@ -507,6 +576,78 @@ Status MaxPoolShape(shape_inference::InferenceContext* c) {
return Status::OK();
}
+Status Pool3DShape(shape_inference::InferenceContext* c) {
+ const Shape* input_shape;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 5, &input_shape));
+
+ std::vector<int32> strides;
+ TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
+ if (strides.size() != 5) {
+ return errors::InvalidArgument(
+ "Pool3D ops require the stride attribute to contain 5 values, but "
+ "got: ",
+ strides.size());
+ }
+
+ std::vector<int32> kernel_sizes;
+ TF_RETURN_IF_ERROR(c->GetAttr("ksize", &kernel_sizes));
+ if (kernel_sizes.size() != 5) {
+ return errors::InvalidArgument(
+ "Pool3D requires the ksize attribute to contain 5 values, but got: ",
+ kernel_sizes.size());
+ }
+
+ int32 stride_planes, stride_rows, stride_cols;
+ int32 kernel_planes, kernel_rows, kernel_cols;
+
+ stride_planes = strides[1];
+ stride_rows = strides[2];
+ stride_cols = strides[3];
+ kernel_planes = kernel_sizes[1];
+ kernel_rows = kernel_sizes[2];
+ kernel_cols = kernel_sizes[3];
+
+ const Dimension* batch_size_dim = c->Dim(input_shape, 0);
+ const Dimension* in_planes_dim = c->Dim(input_shape, 1);
+ const Dimension* in_rows_dim = c->Dim(input_shape, 2);
+ const Dimension* in_cols_dim = c->Dim(input_shape, 3);
+ const Dimension* output_depth_dim = c->Dim(input_shape, 4);
+
+ // At the moment we need to know the values of several fields.
+ TF_RETURN_IF_ERROR(CheckKnownDim(c, in_planes_dim, "in_planes"));
+ TF_RETURN_IF_ERROR(CheckKnownDim(c, in_rows_dim, "in_rows"));
+ TF_RETURN_IF_ERROR(CheckKnownDim(c, in_cols_dim, "in_cols"));
+
+ Padding padding;
+ TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
+
+ // TODO(mrry,shlens): Raise an error if the stride would cause
+ // information in the input to be ignored. This will require a change
+ // in the kernel implementation.
+ auto in_planes = c->Value(in_planes_dim);
+ auto in_rows = c->Value(in_rows_dim);
+ auto in_cols = c->Value(in_cols_dim);
+
+ int64 output_planes, output_rows, output_cols;
+ int64 padding_before, padding_after;
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
+ in_planes, kernel_planes, stride_planes, padding, &output_planes,
+ &padding_before, &padding_after));
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
+ in_rows, kernel_rows, stride_rows, padding, &output_rows, &padding_before,
+ &padding_after));
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
+ in_cols, kernel_cols, stride_cols, padding, &output_cols, &padding_before,
+ &padding_after));
+
+ const Shape* output_shape =
+ c->MakeShape({batch_size_dim, output_planes, output_rows, output_cols,
+ output_depth_dim});
+
+ c->set_output(0, output_shape);
+ return Status::OK();
+}
+
Status UnknownShape(shape_inference::InferenceContext* c) {
for (int i = 0; i < c->num_outputs(); ++i) {
c->set_output(i, c->UnknownShape());
diff --git a/tensorflow/core/framework/common_shape_fns.h b/tensorflow/core/framework/common_shape_fns.h
index f1bdd5ee8d..0ca6499036 100644
--- a/tensorflow/core/framework/common_shape_fns.h
+++ b/tensorflow/core/framework/common_shape_fns.h
@@ -157,6 +157,9 @@ Status BiasAddGradShape(shape_inference::InferenceContext* c);
// Shape function for Conv2D-like operations.
Status Conv2DShape(shape_inference::InferenceContext* c);
+// Shape function for Conv3D-like operations.
+Status Conv3DShape(shape_inference::InferenceContext* c);
+
// Shape function for DepthwiseConv2D-like operations.
Status DepthwiseConv2DNativeShape(shape_inference::InferenceContext* c);
@@ -166,6 +169,9 @@ Status AvgPoolShape(shape_inference::InferenceContext* c);
// Shape function for MaxPool-like operations.
Status MaxPoolShape(shape_inference::InferenceContext* c);
+// Shape function for 3D Pooling operations.
+Status Pool3DShape(shape_inference::InferenceContext* c);
+
// Shape function for use with ops whose output shapes are unknown.
Status UnknownShape(shape_inference::InferenceContext* c);
diff --git a/tensorflow/core/framework/common_shape_fns_test.cc b/tensorflow/core/framework/common_shape_fns_test.cc
index eada469b17..6e0dd7f742 100644
--- a/tensorflow/core/framework/common_shape_fns_test.cc
+++ b/tensorflow/core/framework/common_shape_fns_test.cc
@@ -419,6 +419,55 @@ TEST(CommonShapeFnsTest, Conv2DShapeTest) {
INFER_OK(op, "[1,4,4,1];[2,2,1,1]", "[d0_0,4,4,d1_3]");
}
+TEST(CommonShapeFnsTest, Conv3DShapeTest) {
+ ShapeInferenceTestOp op("Conv3D");
+ auto set_op = [&op](const std::vector<int32>& strides,
+ const string& padding) {
+ TF_CHECK_OK(NodeDefBuilder("test", "Conv3D")
+ .Input("input", 0, DT_FLOAT)
+ .Input("filter", 0, DT_FLOAT)
+ .Attr("strides", strides)
+ .Attr("padding", padding)
+ .Finalize(&op.node_def));
+ };
+
+ // 1x1x1 filter
+ set_op({{1, 1, 1, 1, 1}}, "VALID");
+ INFER_OK(op, "[1,2,2,2,1];[1,1,1,1,1]", "[d0_0,2,2,2,d1_4]");
+
+ // Invalid rank for input
+ INFER_ERROR("must be rank 5", op, "[4,4];[2,1,1,1]");
+ // Invalid rank for filter
+ INFER_ERROR("must be rank 5", op, "[1,4,4,1];[2,1,1]");
+
+ // No unknown dims in the critical fields.
+ INFER_ERROR("is not known", op, "[1,?,2,2,1];[1,1,1,1,1]");
+ INFER_ERROR("is not known", op, "[1,2,?,2,1];[1,1,1,1,1]");
+ INFER_ERROR("is not known", op, "[1,2,2,?,1];[1,1,1,1,1]");
+ INFER_ERROR("is not known", op, "[1,2,2,2,1];[?,1,1,1,1]");
+ INFER_ERROR("is not known", op, "[1,2,2,2,1];[1,?,1,1,1]");
+
+ // input depths must match.
+ INFER_ERROR("Dimensions must be equal, but are 10 and 10000", op,
+ "[1,2,2,2,10];[1,1,1,10000,20]");
+
+ // 2x2x2 filter
+ set_op({{1, 1, 1, 1, 1}}, "VALID");
+ INFER_OK(op, "[1,2,2,2,1];[2,2,2,1,1]", "[d0_0,1,1,1,d1_4]");
+
+ // 3x3 input, 1x1 filter, 2x2 stride
+ set_op({{1, 2, 2, 2, 1}}, "VALID");
+ INFER_OK(op, "[1,3,3,3,1];[1,1,1,1,1]", "[d0_0,2,2,2,d1_4]");
+
+ // 3x3 input, 1x1 filter, 2x1x1 stride
+ set_op({{1, 2, 1, 1, 1}}, "VALID");
+ INFER_OK(op, "[1,3,3,3,1];[1,1,1,1,1]", "[d0_0,2,3,3,d1_4]");
+
+ // 4x4 input, 2x2 filter, 1x1 stride
+ set_op({{1, 1, 1, 1, 1}}, "SAME");
+ INFER_OK(op, "[1,4,4,4,1];[2,2,2,1,1]", "[d0_0,4,4,4,d1_4]");
+}
+
TEST(CommonShapeFnsTest, DepthwiseConv2DShapeTest) {
ShapeInferenceTestOp op("DepthwiseConv2dNative");
std::vector<int32> strides = {{1, 1, 1, 1}};
@@ -512,6 +561,26 @@ TEST(CommonShapeFnsTest, MaxPool2DShapeTest) {
INFER_OK(op, "[1,7,5,5]", "[d0_0,3,5,5]");
}
+TEST(CommonShapeFnsTest, Pool3DShapeTest) {
+ ShapeInferenceTestOp op("MaxPool3D");
+ auto set_op = [&op](const std::vector<int32>& strides,
+ const std::vector<int32>& ksizes, const string& padding) {
+ TF_CHECK_OK(NodeDefBuilder("test", "MaxPool3D")
+ .Input("input", 0, DT_FLOAT)
+ .Attr("strides", strides)
+ .Attr("ksize", ksizes)
+ .Attr("padding", padding)
+ .Finalize(&op.node_def));
+ };
+
+ // Most of the functionality is tested by conv-like shapes,
+ // so we check that we handle the extra dimension properly.
+
+ // 2x3x4 stride, 1x1x1 filter.
+ set_op({1, 2, 3, 4, 1}, {1, 1, 1, 1, 1}, "VALID");
+ INFER_OK(op, "[1,24,24,24,1]", "[d0_0,12,8,6,d0_4]");
+}
+
TEST(CommonShapeFnsTest, UnknownShapeTest) {
{
// Single output
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index 03ada87511..3a2c02bd85 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -533,6 +533,7 @@ REGISTER_OP("Conv3D")
.Attr("T: numbertype")
.Attr("strides: list(int) >= 5")
.Attr(GetPaddingAttrString())
+ .SetShapeFn(shape_inference::Conv3DShape)
.Doc(R"doc(
Computes a 3-D convolution given 5-D `input` and `filter` tensors.
@@ -677,6 +678,7 @@ REGISTER_OP("AvgPool3D")
.Attr("strides: list(int) >= 5")
.Attr(GetPaddingAttrString())
.Attr("T: numbertype")
+ .SetShapeFn(shape_inference::Pool3DShape)
.Doc(R"doc(
Performs 3D average pooling on the input.
@@ -726,6 +728,7 @@ REGISTER_OP("MaxPool3D")
.Attr("strides: list(int) >= 5")
.Attr(GetPaddingAttrString())
.Attr("T: numbertype")
+ .SetShapeFn(shape_inference::Pool3DShape)
.Doc(R"doc(
Performs 3D max pooling on the input.