diff options
Diffstat (limited to 'tensorflow/compiler/xla/tests/convolution_test.cc')
-rw-r--r-- | tensorflow/compiler/xla/tests/convolution_test.cc | 160 |
1 files changed, 105 insertions, 55 deletions
diff --git a/tensorflow/compiler/xla/tests/convolution_test.cc b/tensorflow/compiler/xla/tests/convolution_test.cc index 7425f778a6..0cc2e5fb7e 100644 --- a/tensorflow/compiler/xla/tests/convolution_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_test.cc @@ -82,127 +82,177 @@ XLA_TEST_F(ConvolutionTest, ForwardPassConvolution_3x3x256_256_OutputZ_Iota) { ComputationBuilder builder(client_, TestName()); auto lhs = builder.ConstantR4FromArray4D<float>(*alhs); auto rhs = builder.ConstantR4FromArray4D<float>(*arhs); - auto conv = builder.Conv(lhs, rhs, {1, 1}, Padding::kValid); + builder.Conv(lhs, rhs, {1, 1}, Padding::kValid); - ComputeAndCompare(&builder, conv, {}, error_spec_); + std::unique_ptr<Array4D<float>> aexpected = + ReferenceUtil::ConvArray4D(*alhs, *arhs, {1, 1}, Padding::kValid); + + ComputeAndCompareR4<float>(&builder, *aexpected, {}, error_spec_); } TEST_F(ConvolutionTest, Convolve_1x1x1x2_1x1x1x2_Valid) { ComputationBuilder builder(client_, TestName()); - Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); - Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); - auto conv = builder.Conv(input, filter, {1, 1}, Padding::kValid); - - Array4D<float> input_data(1, 1, 1, 2); - input_data.FillWithYX(Array2D<float>({ + { + Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); + Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 1, 2}); + auto input = builder.Parameter(0, input_shape, "input"); + auto filter = builder.Parameter(1, filter_shape, "filter"); + builder.Conv(input, filter, {1, 1}, Padding::kValid); + } + + Array4D<float> input(1, 1, 1, 2); + input.FillWithYX(Array2D<float>({ {1, 2}, })); - Array4D<float> filter_data(1, 1, 1, 2); - filter_data.FillWithYX(Array2D<float>({ + Array4D<float> filter(1, 1, 1, 2); + filter.FillWithYX(Array2D<float>({ {5, 6}, })); - ComputeAndCompare(&builder, conv, - {*Literal::CreateFromArray(input_data), - *Literal::CreateFromArray(filter_data)}, - error_spec_); + std::unique_ptr<Array4D<float>> aexpected = + ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kValid); + + auto input_literal = + client_->TransferToServer(*Literal::CreateR4FromArray4D(input)) + .ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(*Literal::CreateR4FromArray4D(filter)) + .ConsumeValueOrDie(); + + ComputeAndCompareR4<float>(&builder, *aexpected, + {input_literal.get(), filter_literal.get()}, + error_spec_); } // Tests valid padding for 2D convolution in raster space. TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x2x2_Valid) { ComputationBuilder builder(client_, TestName()); - Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); - Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 2, 2}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); - auto conv = builder.Conv(input, filter, {1, 1}, Padding::kValid); + { + Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); + Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 2, 2}); + auto input = builder.Parameter(0, input_shape, "input"); + auto filter = builder.Parameter(1, filter_shape, "filter"); + builder.Conv(input, filter, {1, 1}, Padding::kValid); + } - Array4D<float> input_data(1, 1, 4, 4); + Array4D<float> input(1, 1, 4, 4); // clang-format off - input_data.FillWithYX(Array2D<float>({ + input.FillWithYX(Array2D<float>({ {1, 2, 3, 4 }, {5, 6, 7, 8 }, {9, 10, 11, 12}, {13, 14, 15, 16}, })); // clang-format on - Array4D<float> filter_data(1, 1, 2, 2); + Array4D<float> filter(1, 1, 2, 2); // clang-format off - filter_data.FillWithYX(Array2D<float>({ + filter.FillWithYX(Array2D<float>({ {5, 6}, {7, 8}, })); // clang-format on - ComputeAndCompare(&builder, conv, - {*Literal::CreateFromArray(input_data), - *Literal::CreateFromArray(filter_data)}, - error_spec_); + + std::unique_ptr<Array4D<float>> aexpected = + ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kValid); + + auto input_literal = + client_->TransferToServer(*Literal::CreateR4FromArray4D(input)) + .ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(*Literal::CreateR4FromArray4D(filter)) + .ConsumeValueOrDie(); + + ComputeAndCompareR4<float>(&builder, *aexpected, + {input_literal.get(), filter_literal.get()}, + error_spec_); } // Tests same padding for 2D convolution in raster space. TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x2x2_Same) { ComputationBuilder builder(client_, TestName()); - Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); - Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 2, 2}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); - auto conv = builder.Conv(input, filter, {1, 1}, Padding::kSame); + { + Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); + Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 2, 2}); + auto input = builder.Parameter(0, input_shape, "input"); + auto filter = builder.Parameter(1, filter_shape, "filter"); + builder.Conv(input, filter, {1, 1}, Padding::kSame); + } - Array4D<float> input_data(1, 1, 4, 4); + Array4D<float> input(1, 1, 4, 4); // clang-format off - input_data.FillWithYX(Array2D<float>({ + input.FillWithYX(Array2D<float>({ {1, 2, 3, 4 }, {5, 6, 7, 8 }, {9, 10, 11, 12}, {13, 14, 15, 16}, })); // clang-format on - Array4D<float> filter_data(1, 1, 2, 2); + Array4D<float> filter(1, 1, 2, 2); // clang-format off - filter_data.FillWithYX(Array2D<float>({ + filter.FillWithYX(Array2D<float>({ {5, 6}, {7, 8}, })); // clang-format on - ComputeAndCompare(&builder, conv, - {*Literal::CreateFromArray(input_data), - *Literal::CreateFromArray(filter_data)}, - error_spec_); + + std::unique_ptr<Array4D<float>> aexpected = + ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kSame); + + auto input_literal = + client_->TransferToServer(*Literal::CreateR4FromArray4D(input)) + .ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(*Literal::CreateR4FromArray4D(filter)) + .ConsumeValueOrDie(); + + ComputeAndCompareR4<float>(&builder, *aexpected, + {input_literal.get(), filter_literal.get()}, + error_spec_); } // Tests same padding for 2D convolution in raster space with an odd sized // kernel. TEST_F(ConvolutionTest, Convolve_1x1x4x4_1x1x3x3_Same) { ComputationBuilder builder(client_, TestName()); - Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); - Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 3, 3}); - auto input = builder.Parameter(0, input_shape, "input"); - auto filter = builder.Parameter(1, filter_shape, "filter"); - auto conv = builder.Conv(input, filter, {1, 1}, Padding::kSame); + { + Shape input_shape = ShapeUtil::MakeShape(F32, {1, 1, 4, 4}); + Shape filter_shape = ShapeUtil::MakeShape(F32, {1, 1, 3, 3}); + auto input = builder.Parameter(0, input_shape, "input"); + auto filter = builder.Parameter(1, filter_shape, "filter"); + builder.Conv(input, filter, {1, 1}, Padding::kSame); + } - Array4D<float> input_data(1, 1, 4, 4); + Array4D<float> input(1, 1, 4, 4); // clang-format off - input_data.FillWithYX(Array2D<float>({ + input.FillWithYX(Array2D<float>({ {1, 2, 3, 4 }, {5, 6, 7, 8 }, {9, 10, 11, 12}, {13, 14, 15, 16}, })); // clang-format on - Array4D<float> filter_data(1, 1, 3, 3); + Array4D<float> filter(1, 1, 3, 3); // clang-format off - filter_data.FillWithYX(Array2D<float>({ + filter.FillWithYX(Array2D<float>({ { 5, 6, 7}, { 8, 9, 10}, {11, 12, 13}, })); // clang-format on - ComputeAndCompare(&builder, conv, - {*Literal::CreateFromArray(input_data), - *Literal::CreateFromArray(filter_data)}, - error_spec_); + + std::unique_ptr<Array4D<float>> aexpected = + ReferenceUtil::ConvArray4D(input, filter, {1, 1}, Padding::kSame); + + auto input_literal = + client_->TransferToServer(*Literal::CreateR4FromArray4D(input)) + .ConsumeValueOrDie(); + auto filter_literal = + client_->TransferToServer(*Literal::CreateR4FromArray4D(filter)) + .ConsumeValueOrDie(); + + ComputeAndCompareR4<float>(&builder, *aexpected, + {input_literal.get(), filter_literal.get()}, + error_spec_); } XLA_TEST_F(ConvolutionTest, Convolve1D_1x2x5_1x2x2_Valid) { |