aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tests/convolution_test.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/tests/convolution_test.cc')
-rw-r--r--tensorflow/compiler/xla/tests/convolution_test.cc160
1 files changed, 105 insertions, 55 deletions
diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc
index 7425f778a6..0cc2e5fb7e 100644
--- a/tensorflow/compiler/xla/tests/convolution_test.cc
+++ b/tensorflow/compiler/xla/tests/convolution_test.cc
@@ -82,127 +82,177 @@ XLA_TEST_F(ConvolutionTest, ForwardPassConvolution_3x3x256_256_OutputZ_Iota) {
ComputationBuilder builder(client_, TestName());
auto lhs = builder.ConstantR4FromArray4D<float>(*alhs);
auto rhs = builder.ConstantR4FromArray4D<float>(*arhs);
- auto conv = builder.Conv(lhs, rhs, {1, 1}, Padding::kValid);
+ builder.Conv(lhs, rhs, {1, 1}, Padding::kValid);
- ComputeAndCompare(&builder, conv, {}, error_spec_);
+ std::unique_ptr<Array4D<float>> aexpected =
+ ReferenceUtil::ConvArray4D(*alhs, *arhs, {1, 1}, Padding::kValid);
+
+ ComputeAndCompareR4<float>(&builder, *aexpected, {}, error_spec_);
}
TEST_F(ConvolutionTest, Convolve_1x1x1x2_1x1x1x2_Valid) {
ComputationBuilder builder(client_, TestName());
- Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
- Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
- auto input = builder.Parameter(0, input_shape, "input");
- auto filter = builder.Parameter(1, filter_shape, "filter");
- auto conv = builder.Conv(input, filter, {1, 1}, Padding::kValid);
-
- Array4D<float> input_data(1, 1, 1, 2);
- input_data.FillWithYX(Array2D<float>({
+ {
+ Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
+ Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2});
+ auto input = builder.Parameter(0, input_shape, "input");
+ auto filter = builder.Parameter(1, filter_shape, "filter");
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+ }
+
+ Array4D<float> input(1, 1, 1, 2);
+ input.FillWithYX(Array2D<float>({
{1, 2},
}));
- Array4D<float> filter_data(1, 1, 1, 2);
- filter_data.FillWithYX(Array2D<float>({
+ Array4D<float> filter(1, 1, 1, 2);
+ filter.FillWithYX(Array2D<float>({
{5, 6},
}));
- ComputeAndCompare(&builder, conv,
- {*Literal::CreateFromArray(input_data),
- *Literal::CreateFromArray(filter_data)},
- error_spec_);
+ std::unique_ptr<Array4D<float>> aexpected =
+ ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kValid);
+
+ auto input_literal =
+ client_->TransferToServer(*Literal::CreateR4FromArray4D(input))
+ .ConsumeValueOrDie();
+ auto filter_literal =
+ client_->TransferToServer(*Literal::CreateR4FromArray4D(filter))
+ .ConsumeValueOrDie();
+
+ ComputeAndCompareR4<float>(&builder, *aexpected,
+ {input_literal.get(), filter_literal.get()},
+ error_spec_);
}
// Tests valid padding for 2D convolution in raster space.
TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x2x2_Valid) {
ComputationBuilder builder(client_, TestName());
- Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
- Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 2, 2});
- auto input = builder.Parameter(0, input_shape, "input");
- auto filter = builder.Parameter(1, filter_shape, "filter");
- auto conv = builder.Conv(input, filter, {1, 1}, Padding::kValid);
+ {
+ Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
+ Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 2, 2});
+ auto input = builder.Parameter(0, input_shape, "input");
+ auto filter = builder.Parameter(1, filter_shape, "filter");
+ builder.Conv(input, filter, {1, 1}, Padding::kValid);
+ }
- Array4D<float> input_data(1, 1, 4, 4);
+ Array4D<float> input(1, 1, 4, 4);
// clang-format off
- input_data.FillWithYX(Array2D<float>({
+ input.FillWithYX(Array2D<float>({
{1, 2, 3, 4 },
{5, 6, 7, 8 },
{9, 10, 11, 12},
{13, 14, 15, 16},
}));
// clang-format on
- Array4D<float> filter_data(1, 1, 2, 2);
+ Array4D<float> filter(1, 1, 2, 2);
// clang-format off
- filter_data.FillWithYX(Array2D<float>({
+ filter.FillWithYX(Array2D<float>({
{5, 6},
{7, 8},
}));
// clang-format on
- ComputeAndCompare(&builder, conv,
- {*Literal::CreateFromArray(input_data),
- *Literal::CreateFromArray(filter_data)},
- error_spec_);
+
+ std::unique_ptr<Array4D<float>> aexpected =
+ ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kValid);
+
+ auto input_literal =
+ client_->TransferToServer(*Literal::CreateR4FromArray4D(input))
+ .ConsumeValueOrDie();
+ auto filter_literal =
+ client_->TransferToServer(*Literal::CreateR4FromArray4D(filter))
+ .ConsumeValueOrDie();
+
+ ComputeAndCompareR4<float>(&builder, *aexpected,
+ {input_literal.get(), filter_literal.get()},
+ error_spec_);
}
// Tests same padding for 2D convolution in raster space.
TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x2x2_Same) {
ComputationBuilder builder(client_, TestName());
- Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
- Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 2, 2});
- auto input = builder.Parameter(0, input_shape, "input");
- auto filter = builder.Parameter(1, filter_shape, "filter");
- auto conv = builder.Conv(input, filter, {1, 1}, Padding::kSame);
+ {
+ Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
+ Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 2, 2});
+ auto input = builder.Parameter(0, input_shape, "input");
+ auto filter = builder.Parameter(1, filter_shape, "filter");
+ builder.Conv(input, filter, {1, 1}, Padding::kSame);
+ }
- Array4D<float> input_data(1, 1, 4, 4);
+ Array4D<float> input(1, 1, 4, 4);
// clang-format off
- input_data.FillWithYX(Array2D<float>({
+ input.FillWithYX(Array2D<float>({
{1, 2, 3, 4 },
{5, 6, 7, 8 },
{9, 10, 11, 12},
{13, 14, 15, 16},
}));
// clang-format on
- Array4D<float> filter_data(1, 1, 2, 2);
+ Array4D<float> filter(1, 1, 2, 2);
// clang-format off
- filter_data.FillWithYX(Array2D<float>({
+ filter.FillWithYX(Array2D<float>({
{5, 6},
{7, 8},
}));
// clang-format on
- ComputeAndCompare(&builder, conv,
- {*Literal::CreateFromArray(input_data),
- *Literal::CreateFromArray(filter_data)},
- error_spec_);
+
+ std::unique_ptr<Array4D<float>> aexpected =
+ ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kSame);
+
+ auto input_literal =
+ client_->TransferToServer(*Literal::CreateR4FromArray4D(input))
+ .ConsumeValueOrDie();
+ auto filter_literal =
+ client_->TransferToServer(*Literal::CreateR4FromArray4D(filter))
+ .ConsumeValueOrDie();
+
+ ComputeAndCompareR4<float>(&builder, *aexpected,
+ {input_literal.get(), filter_literal.get()},
+ error_spec_);
}
// Tests same padding for 2D convolution in raster space with an odd sized
// kernel.
TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x3x3_Same) {
ComputationBuilder builder(client_, TestName());
- Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
- Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 3, 3});
- auto input = builder.Parameter(0, input_shape, "input");
- auto filter = builder.Parameter(1, filter_shape, "filter");
- auto conv = builder.Conv(input, filter, {1, 1}, Padding::kSame);
+ {
+ Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4});
+ Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 3, 3});
+ auto input = builder.Parameter(0, input_shape, "input");
+ auto filter = builder.Parameter(1, filter_shape, "filter");
+ builder.Conv(input, filter, {1, 1}, Padding::kSame);
+ }
- Array4D<float> input_data(1, 1, 4, 4);
+ Array4D<float> input(1, 1, 4, 4);
// clang-format off
- input_data.FillWithYX(Array2D<float>({
+ input.FillWithYX(Array2D<float>({
{1, 2, 3, 4 },
{5, 6, 7, 8 },
{9, 10, 11, 12},
{13, 14, 15, 16},
}));
// clang-format on
- Array4D<float> filter_data(1, 1, 3, 3);
+ Array4D<float> filter(1, 1, 3, 3);
// clang-format off
- filter_data.FillWithYX(Array2D<float>({
+ filter.FillWithYX(Array2D<float>({
{ 5, 6, 7},
{ 8, 9, 10},
{11, 12, 13},
}));
// clang-format on
- ComputeAndCompare(&builder, conv,
- {*Literal::CreateFromArray(input_data),
- *Literal::CreateFromArray(filter_data)},
- error_spec_);
+
+ std::unique_ptr<Array4D<float>> aexpected =
+ ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kSame);
+
+ auto input_literal =
+ client_->TransferToServer(*Literal::CreateR4FromArray4D(input))
+ .ConsumeValueOrDie();
+ auto filter_literal =
+ client_->TransferToServer(*Literal::CreateR4FromArray4D(filter))
+ .ConsumeValueOrDie();
+
+ ComputeAndCompareR4<float>(&builder, *aexpected,
+ {input_literal.get(), filter_literal.get()},
+ error_spec_);
}
XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) {