diff options
Diffstat (limited to 'tensorflow/core/kernels/conv_ops_test.cc')
-rw-r--r-- | tensorflow/core/kernels/conv_ops_test.cc | 79 |
1 files changed, 70 insertions, 9 deletions
diff --git a/tensorflow/core/kernels/conv_ops_test.cc b/tensorflow/core/kernels/conv_ops_test.cc index 228f2d5def..f955e6a6b6 100644 --- a/tensorflow/core/kernels/conv_ops_test.cc +++ b/tensorflow/core/kernels/conv_ops_test.cc @@ -121,22 +121,15 @@ class FusedResizePadConvOpTest : public OpsTestBase { auto root = tensorflow::Scope::NewRootScope(); using namespace ::tensorflow::ops; // NOLINT(build/namespaces) - const size_t input_data_size = input_height * input_width * input_depth; Tensor input_data(DT_FLOAT, TensorShape({1, input_height, input_width, input_depth})); - for (int i = 0; i < input_data_size; ++i) { - input_data.flat<float>()(i) = i + 1.0f; - } + test::FillIota<float>(&input_data, 1.0f); Output input = Const(root.WithOpName("input"), Input::Initializer(input_data)); - const size_t filter_data_size = - filter_size * filter_size * filter_count * input_depth; Tensor filter_data(DT_FLOAT, TensorShape({filter_size, filter_size, input_depth, filter_count})); - for (int i = 0; i < filter_data_size; ++i) { - filter_data.flat<float>()(i) = i + 1.0f; - } + test::FillIota<float>(&filter_data, 1.0f); Output filter = Const(root.WithOpName("filter"), Input::Initializer(filter_data)); @@ -173,6 +166,54 @@ class FusedResizePadConvOpTest : public OpsTestBase { test::ExpectTensorNear<float>(unfused_tensors[0], fused_tensors[0], 1e-5); } + + void CompareFusedPadOnlyAndSeparate(int input_width, int input_height, + int input_depth, int y_padding, + int x_padding, int filter_size, + int filter_count, string pad_mode, + int stride, string padding) { + auto root = tensorflow::Scope::NewRootScope(); + using namespace ::tensorflow::ops; // NOLINT(build/namespaces) + + Tensor input_data(DT_FLOAT, + TensorShape({1, input_height, input_width, input_depth})); + test::FillIota<float>(&input_data, 1.0f); + Output input = + Const(root.WithOpName("input"), Input::Initializer(input_data)); + + Tensor filter_data(DT_FLOAT, TensorShape({filter_size, filter_size, + input_depth, filter_count})); + test::FillIota<float>(&filter_data, 1.0f); + Output filter = + Const(root.WithOpName("filter"), Input::Initializer(filter_data)); + + Output paddings = + Const(root.WithOpName("paddings"), + {{0, 0}, {y_padding, y_padding}, {x_padding, x_padding}, {0, 0}}); + Output mirror_pad = + MirrorPad(root.WithOpName("mirror_pad"), input, paddings, pad_mode); + Output conv = Conv2D(root.WithOpName("conv"), mirror_pad, filter, + {1, stride, stride, 1}, padding); + + Output fused_conv = + FusedPadConv2D(root.WithOpName("fused_conv"), input, paddings, filter, + pad_mode, {1, stride, stride, 1}, padding); + + tensorflow::GraphDef graph; + TF_ASSERT_OK(root.ToGraphDef(&graph)); + + std::unique_ptr<tensorflow::Session> session( + tensorflow::NewSession(tensorflow::SessionOptions())); + TF_ASSERT_OK(session->Create(graph)); + + std::vector<Tensor> unfused_tensors; + TF_ASSERT_OK(session->Run({}, {"conv"}, {}, &unfused_tensors)); + + std::vector<Tensor> fused_tensors; + TF_ASSERT_OK(session->Run({}, {"fused_conv"}, {}, &fused_tensors)); + + test::ExpectTensorNear<float>(unfused_tensors[0], fused_tensors[0], 1e-5); + } }; TEST_F(FusedResizePadConvOpTest, HandwrittenConv) { HandwrittenConv(); } @@ -237,4 +278,24 @@ TEST_F(FusedResizePadConvOpTest, ResizeAndPadSymmetricComparative) { "SAME"); } +TEST_F(FusedResizePadConvOpTest, NoResizeIdentityComparative) { + CompareFusedPadOnlyAndSeparate(10, 10, 1, 0, 0, 1, 1, "REFLECT", 1, "SAME"); +} + +TEST_F(FusedResizePadConvOpTest, NoResizeConvOnlyComparative) { + CompareFusedPadOnlyAndSeparate(10, 10, 3, 0, 0, 4, 4, "REFLECT", 1, "SAME"); +} + +TEST_F(FusedResizePadConvOpTest, NoResizePadOnlyComparative) { + CompareFusedPadOnlyAndSeparate(4, 4, 1, 2, 2, 1, 1, "REFLECT", 1, "SAME"); +} + +TEST_F(FusedResizePadConvOpTest, NoResizePadOnlyWithChannelsComparative) { + CompareFusedPadOnlyAndSeparate(4, 4, 3, 2, 2, 1, 1, "REFLECT", 1, "SAME"); +} + +TEST_F(FusedResizePadConvOpTest, NoResizePadOnlySymmetricComparative) { + CompareFusedPadOnlyAndSeparate(4, 4, 1, 2, 2, 1, 1, "SYMMETRIC", 1, "SAME"); +} + } // namespace tensorflow |