aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/framework/common_shape_fns_test.cc
diff options
context:
space:
mode:
authorGravatar Vijay Vasudevan <vrv@google.com>2016-08-01 11:08:01 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2016-08-01 12:17:01 -0700
commit19ad04564a70ae0134c044666460f47714e287f1 (patch)
treedae711833b192eaee432cc30d790602a119d146c /tensorflow/core/framework/common_shape_fns_test.cc
parentb1d9ef53ad6fbf1d98374471456040aecc0b4799 (diff)
TensorFlow: Add Conv3D/MaxPool3D/AvgPool3D C++ shape inference functions .
Change: 129011665
Diffstat (limited to 'tensorflow/core/framework/common_shape_fns_test.cc')
-rw-r--r--tensorflow/core/framework/common_shape_fns_test.cc69
1 files changed, 69 insertions, 0 deletions
diff --git a/tensorflow/core/framework/common_shape_fns_test.cc b/tensorflow/core/framework/common_shape_fns_test.cc
index eada469b17..6e0dd7f742 100644
--- a/tensorflow/core/framework/common_shape_fns_test.cc
+++ b/tensorflow/core/framework/common_shape_fns_test.cc
@@ -419,6 +419,55 @@ TEST(CommonShapeFnsTest, Conv2DShapeTest) {
INFER_OK(op, "[1,4,4,1];[2,2,1,1]", "[d0_0,4,4,d1_3]");
}
+TEST(CommonShapeFnsTest, Conv3DShapeTest) {
+ ShapeInferenceTestOp op("Conv3D");
+ auto set_op = [&op](const std::vector<int32>& strides,
+ const string& padding) {
+ TF_CHECK_OK(NodeDefBuilder("test", "Conv3D")
+ .Input("input", 0, DT_FLOAT)
+ .Input("filter", 0, DT_FLOAT)
+ .Attr("strides", strides)
+ .Attr("padding", padding)
+ .Finalize(&op.node_def));
+ };
+
+ // 1x1x1 filter
+ set_op({{1, 1, 1, 1, 1}}, "VALID");
+ INFER_OK(op, "[1,2,2,2,1];[1,1,1,1,1]", "[d0_0,2,2,2,d1_4]");
+
+ // Invalid rank for input
+ INFER_ERROR("must be rank 5", op, "[4,4];[2,1,1,1]");
+ // Invalid rank for filter
+ INFER_ERROR("must be rank 5", op, "[1,4,4,1];[2,1,1]");
+
+ // No unknown dims in the critical fields.
+ INFER_ERROR("is not known", op, "[1,?,2,2,1];[1,1,1,1,1]");
+ INFER_ERROR("is not known", op, "[1,2,?,2,1];[1,1,1,1,1]");
+ INFER_ERROR("is not known", op, "[1,2,2,?,1];[1,1,1,1,1]");
+ INFER_ERROR("is not known", op, "[1,2,2,2,1];[?,1,1,1,1]");
+ INFER_ERROR("is not known", op, "[1,2,2,2,1];[1,?,1,1,1]");
+
+ // input depths must match.
+ INFER_ERROR("Dimensions must be equal, but are 10 and 10000", op,
+ "[1,2,2,2,10];[1,1,1,10000,20]");
+
+ // 2x2x2 filter
+ set_op({{1, 1, 1, 1, 1}}, "VALID");
+ INFER_OK(op, "[1,2,2,2,1];[2,2,2,1,1]", "[d0_0,1,1,1,d1_4]");
+
+ // 3x3 input, 1x1 filter, 2x2 stride
+ set_op({{1, 2, 2, 2, 1}}, "VALID");
+ INFER_OK(op, "[1,3,3,3,1];[1,1,1,1,1]", "[d0_0,2,2,2,d1_4]");
+
+ // 3x3 input, 1x1 filter, 2x1x1 stride
+ set_op({{1, 2, 1, 1, 1}}, "VALID");
+ INFER_OK(op, "[1,3,3,3,1];[1,1,1,1,1]", "[d0_0,2,3,3,d1_4]");
+
+ // 4x4 input, 2x2 filter, 1x1 stride
+ set_op({{1, 1, 1, 1, 1}}, "SAME");
+ INFER_OK(op, "[1,4,4,4,1];[2,2,2,1,1]", "[d0_0,4,4,4,d1_4]");
+}
+
TEST(CommonShapeFnsTest, DepthwiseConv2DShapeTest) {
ShapeInferenceTestOp op("DepthwiseConv2dNative");
std::vector<int32> strides = {{1, 1, 1, 1}};
@@ -512,6 +561,26 @@ TEST(CommonShapeFnsTest, MaxPool2DShapeTest) {
INFER_OK(op, "[1,7,5,5]", "[d0_0,3,5,5]");
}
+TEST(CommonShapeFnsTest, Pool3DShapeTest) {
+ ShapeInferenceTestOp op("MaxPool3D");
+ auto set_op = [&op](const std::vector<int32>& strides,
+ const std::vector<int32>& ksizes, const string& padding) {
+ TF_CHECK_OK(NodeDefBuilder("test", "MaxPool3D")
+ .Input("input", 0, DT_FLOAT)
+ .Attr("strides", strides)
+ .Attr("ksize", ksizes)
+ .Attr("padding", padding)
+ .Finalize(&op.node_def));
+ };
+
+ // Most of the functionality is tested by conv-like shapes,
+ // so we check that we handle the extra dimension properly.
+
+ // 2x3x4 stride, 1x1x1 filter.
+ set_op({1, 2, 3, 4, 1}, {1, 1, 1, 1, 1}, "VALID");
+ INFER_OK(op, "[1,24,24,24,1]", "[d0_0,12,8,6,d0_4]");
+}
+
TEST(CommonShapeFnsTest, UnknownShapeTest) {
{
// Single output