diff options
author | Vijay Vasudevan <vrv@google.com> | 2016-08-01 11:08:01 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2016-08-01 12:17:01 -0700 |
commit | 19ad04564a70ae0134c044666460f47714e287f1 (patch) | |
tree | dae711833b192eaee432cc30d790602a119d146c /tensorflow/core/framework/common_shape_fns_test.cc | |
parent | b1d9ef53ad6fbf1d98374471456040aecc0b4799 (diff) |
TensorFlow: Add Conv3D/MaxPool3D/AvgPool3D C++ shape inference functions .
Change: 129011665
Diffstat (limited to 'tensorflow/core/framework/common_shape_fns_test.cc')
-rw-r--r-- | tensorflow/core/framework/common_shape_fns_test.cc | 69 |
1 files changed, 69 insertions, 0 deletions
diff --git a/tensorflow/core/framework/common_shape_fns_test.cc b/tensorflow/core/framework/common_shape_fns_test.cc index eada469b17..6e0dd7f742 100644 --- a/tensorflow/core/framework/common_shape_fns_test.cc +++ b/tensorflow/core/framework/common_shape_fns_test.cc @@ -419,6 +419,55 @@ TEST(CommonShapeFnsTest, Conv2DShapeTest) { INFER_OK(op, "[1,4,4,1];[2,2,1,1]", "[d0_0,4,4,d1_3]"); } +TEST(CommonShapeFnsTest, Conv3DShapeTest) { + ShapeInferenceTestOp op("Conv3D"); + auto set_op = [&op](const std::vector<int32>& strides, + const string& padding) { + TF_CHECK_OK(NodeDefBuilder("test", "Conv3D") + .Input("input", 0, DT_FLOAT) + .Input("filter", 0, DT_FLOAT) + .Attr("strides", strides) + .Attr("padding", padding) + .Finalize(&op.node_def)); + }; + + // 1x1x1 filter + set_op({{1, 1, 1, 1, 1}}, "VALID"); + INFER_OK(op, "[1,2,2,2,1];[1,1,1,1,1]", "[d0_0,2,2,2,d1_4]"); + + // Invalid rank for input + INFER_ERROR("must be rank 5", op, "[4,4];[2,1,1,1]"); + // Invalid rank for filter + INFER_ERROR("must be rank 5", op, "[1,4,4,1];[2,1,1]"); + + // No unknown dims in the critical fields. + INFER_ERROR("is not known", op, "[1,?,2,2,1];[1,1,1,1,1]"); + INFER_ERROR("is not known", op, "[1,2,?,2,1];[1,1,1,1,1]"); + INFER_ERROR("is not known", op, "[1,2,2,?,1];[1,1,1,1,1]"); + INFER_ERROR("is not known", op, "[1,2,2,2,1];[?,1,1,1,1]"); + INFER_ERROR("is not known", op, "[1,2,2,2,1];[1,?,1,1,1]"); + + // input depths must match. + INFER_ERROR("Dimensions must be equal, but are 10 and 10000", op, + "[1,2,2,2,10];[1,1,1,10000,20]"); + + // 2x2x2 filter + set_op({{1, 1, 1, 1, 1}}, "VALID"); + INFER_OK(op, "[1,2,2,2,1];[2,2,2,1,1]", "[d0_0,1,1,1,d1_4]"); + + // 3x3 input, 1x1 filter, 2x2 stride + set_op({{1, 2, 2, 2, 1}}, "VALID"); + INFER_OK(op, "[1,3,3,3,1];[1,1,1,1,1]", "[d0_0,2,2,2,d1_4]"); + + // 3x3 input, 1x1 filter, 2x1x1 stride + set_op({{1, 2, 1, 1, 1}}, "VALID"); + INFER_OK(op, "[1,3,3,3,1];[1,1,1,1,1]", "[d0_0,2,3,3,d1_4]"); + + // 4x4 input, 2x2 filter, 1x1 stride + set_op({{1, 1, 1, 1, 1}}, "SAME"); + INFER_OK(op, "[1,4,4,4,1];[2,2,2,1,1]", "[d0_0,4,4,4,d1_4]"); +} + TEST(CommonShapeFnsTest, DepthwiseConv2DShapeTest) { ShapeInferenceTestOp op("DepthwiseConv2dNative"); std::vector<int32> strides = {{1, 1, 1, 1}}; @@ -512,6 +561,26 @@ TEST(CommonShapeFnsTest, MaxPool2DShapeTest) { INFER_OK(op, "[1,7,5,5]", "[d0_0,3,5,5]"); } +TEST(CommonShapeFnsTest, Pool3DShapeTest) { + ShapeInferenceTestOp op("MaxPool3D"); + auto set_op = [&op](const std::vector<int32>& strides, + const std::vector<int32>& ksizes, const string& padding) { + TF_CHECK_OK(NodeDefBuilder("test", "MaxPool3D") + .Input("input", 0, DT_FLOAT) + .Attr("strides", strides) + .Attr("ksize", ksizes) + .Attr("padding", padding) + .Finalize(&op.node_def)); + }; + + // Most of the functionality is tested by conv-like shapes, + // so we check that we handle the extra dimension properly. + + // 2x3x4 stride, 1x1x1 filter. + set_op({1, 2, 3, 4, 1}, {1, 1, 1, 1, 1}, "VALID"); + INFER_OK(op, "[1,24,24,24,1]", "[d0_0,12,8,6,d0_4]"); +} + TEST(CommonShapeFnsTest, UnknownShapeTest) { { // Single output |