aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/core/kernels/conv_ops_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/core/kernels/conv_ops_test.cc')
-rw-r--r--tensorflow/core/kernels/conv_ops_test.cc79
1 files changed, 70 insertions, 9 deletions
diff --git a/tensorflow/core/kernels/conv_ops_test.cc b/tensorflow/core/kernels/conv_ops_test.cc
index 228f2d5def..f955e6a6b6 100644
--- a/tensorflow/core/kernels/conv_ops_test.cc
+++ b/tensorflow/core/kernels/conv_ops_test.cc
@@ -121,22 +121,15 @@ class FusedResizePadConvOpTest : public OpsTestBase {
auto root = tensorflow::Scope::NewRootScope();
using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
- const size_t input_data_size = input_height * input_width * input_depth;
Tensor input_data(DT_FLOAT,
TensorShape({1, input_height, input_width, input_depth}));
- for (int i = 0; i < input_data_size; ++i) {
- input_data.flat<float>()(i) = i + 1.0f;
- }
+ test::FillIota<float>(&input_data, 1.0f);
Output input =
Const(root.WithOpName("input"), Input::Initializer(input_data));
- const size_t filter_data_size =
- filter_size * filter_size * filter_count * input_depth;
Tensor filter_data(DT_FLOAT, TensorShape({filter_size, filter_size,
input_depth, filter_count}));
- for (int i = 0; i < filter_data_size; ++i) {
- filter_data.flat<float>()(i) = i + 1.0f;
- }
+ test::FillIota<float>(&filter_data, 1.0f);
Output filter =
Const(root.WithOpName("filter"), Input::Initializer(filter_data));
@@ -173,6 +166,54 @@ class FusedResizePadConvOpTest : public OpsTestBase {
test::ExpectTensorNear<float>(unfused_tensors[0], fused_tensors[0], 1e-5);
}
+
+ void CompareFusedPadOnlyAndSeparate(int input_width, int input_height,
+ int input_depth, int y_padding,
+ int x_padding, int filter_size,
+ int filter_count, string pad_mode,
+ int stride, string padding) {
+ auto root = tensorflow::Scope::NewRootScope();
+ using namespace ::tensorflow::ops; // NOLINT(build/namespaces)
+
+ Tensor input_data(DT_FLOAT,
+ TensorShape({1, input_height, input_width, input_depth}));
+ test::FillIota<float>(&input_data, 1.0f);
+ Output input =
+ Const(root.WithOpName("input"), Input::Initializer(input_data));
+
+ Tensor filter_data(DT_FLOAT, TensorShape({filter_size, filter_size,
+ input_depth, filter_count}));
+ test::FillIota<float>(&filter_data, 1.0f);
+ Output filter =
+ Const(root.WithOpName("filter"), Input::Initializer(filter_data));
+
+ Output paddings =
+ Const(root.WithOpName("paddings"),
+ {{0, 0}, {y_padding, y_padding}, {x_padding, x_padding}, {0, 0}});
+ Output mirror_pad =
+ MirrorPad(root.WithOpName("mirror_pad"), input, paddings, pad_mode);
+ Output conv = Conv2D(root.WithOpName("conv"), mirror_pad, filter,
+ {1, stride, stride, 1}, padding);
+
+ Output fused_conv =
+ FusedPadConv2D(root.WithOpName("fused_conv"), input, paddings, filter,
+ pad_mode, {1, stride, stride, 1}, padding);
+
+ tensorflow::GraphDef graph;
+ TF_ASSERT_OK(root.ToGraphDef(&graph));
+
+ std::unique_ptr<tensorflow::Session> session(
+ tensorflow::NewSession(tensorflow::SessionOptions()));
+ TF_ASSERT_OK(session->Create(graph));
+
+ std::vector<Tensor> unfused_tensors;
+ TF_ASSERT_OK(session->Run({}, {"conv"}, {}, &unfused_tensors));
+
+ std::vector<Tensor> fused_tensors;
+ TF_ASSERT_OK(session->Run({}, {"fused_conv"}, {}, &fused_tensors));
+
+ test::ExpectTensorNear<float>(unfused_tensors[0], fused_tensors[0], 1e-5);
+ }
};
TEST_F(FusedResizePadConvOpTest, HandwrittenConv) { HandwrittenConv(); }
@@ -237,4 +278,24 @@ TEST_F(FusedResizePadConvOpTest, ResizeAndPadSymmetricComparative) {
"SAME");
}
+TEST_F(FusedResizePadConvOpTest, NoResizeIdentityComparative) {
+ CompareFusedPadOnlyAndSeparate(10, 10, 1, 0, 0, 1, 1, "REFLECT", 1, "SAME");
+}
+
+TEST_F(FusedResizePadConvOpTest, NoResizeConvOnlyComparative) {
+ CompareFusedPadOnlyAndSeparate(10, 10, 3, 0, 0, 4, 4, "REFLECT", 1, "SAME");
+}
+
+TEST_F(FusedResizePadConvOpTest, NoResizePadOnlyComparative) {
+ CompareFusedPadOnlyAndSeparate(4, 4, 1, 2, 2, 1, 1, "REFLECT", 1, "SAME");
+}
+
+TEST_F(FusedResizePadConvOpTest, NoResizePadOnlyWithChannelsComparative) {
+ CompareFusedPadOnlyAndSeparate(4, 4, 3, 2, 2, 1, 1, "REFLECT", 1, "SAME");
+}
+
+TEST_F(FusedResizePadConvOpTest, NoResizePadOnlySymmetricComparative) {
+ CompareFusedPadOnlyAndSeparate(4, 4, 1, 2, 2, 1, 1, "SYMMETRIC", 1, "SAME");
+}
+
} // namespace tensorflow