diff options
author | 2018-04-23 15:50:56 -0700 | |
---|---|---|
committer | 2018-04-23 15:53:42 -0700 | |
commit | c8a1eeb98ca394d0330bead37b446bce998bb3d5 (patch) | |
tree | 39e87e036717dcf7a5a08405e802d88c2a34e6d0 | |
parent | 2f2d4745836fdcf4bf365644017a900d98bd6206 (diff) |
[XLA] Redesign: migrate convolution tests.
PiperOrigin-RevId: 193998684
-rw-r--r-- | tensorflow/compiler/xla/BUILD | 2 | ||||
-rw-r--r-- | tensorflow/compiler/xla/reference_util.cc | 6 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc | 38 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/convolution_variants_test.cc | 167 |
4 files changed, 116 insertions, 97 deletions
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD index 88f37433a5..1af9cb6d2a 100644 --- a/tensorflow/compiler/xla/BUILD +++ b/tensorflow/compiler/xla/BUILD @@ -605,8 +605,8 @@ cc_library( ":util", ":window_util", ":xla_data_proto", - "//tensorflow/compiler/xla/client:computation_builder", "//tensorflow/compiler/xla/client:padding", + "//tensorflow/compiler/xla/client/xla_client:xla_builder", "//tensorflow/compiler/xla/service:hlo", "//tensorflow/compiler/xla/service:hlo_evaluator", "//tensorflow/compiler/xla/service:shape_inference", diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc index ad3a28e119..df9dbc5830 100644 --- a/tensorflow/compiler/xla/reference_util.cc +++ b/tensorflow/compiler/xla/reference_util.cc @@ -18,7 +18,7 @@ limitations under the License. #include <array> #include <utility> -#include "tensorflow/compiler/xla/client/computation_builder.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h" #include "tensorflow/compiler/xla/service/hlo_evaluator.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" @@ -90,7 +90,7 @@ std::unique_ptr<Array2D<T>> MatmulArray2DImpl( Padding padding) { return ConvArray3DGeneralDimensionsDilated( lhs, rhs, kernel_stride, padding, 1, 1, - ComputationBuilder::CreateDefaultConvDimensionNumbers(1)); + XlaBuilder::CreateDefaultConvDimensionNumbers(1)); } /*static*/ std::unique_ptr<Array3D<float>> @@ -140,7 +140,7 @@ ReferenceUtil::ConvArray3DGeneralDimensionsDilated( std::pair<int64, int64> kernel_stride, Padding padding) { return ConvArray4DGeneralDimensions( lhs, rhs, kernel_stride, padding, - ComputationBuilder::CreateDefaultConvDimensionNumbers()); + XlaBuilder::CreateDefaultConvDimensionNumbers()); } /* static */ std::unique_ptr<Array4D<float>> diff --git a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc index 896b34fb6e..b5a42e3059 100644 --- a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc @@ -18,9 +18,9 @@ limitations under the License. #include <memory> #include "tensorflow/compiler/xla/array4d.h" -#include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/statusor.h" @@ -34,13 +34,35 @@ limitations under the License. namespace xla { namespace { +StatusOr<ConvolutionDimensionNumbers> CreateConvDimensionNumbers( + int64 input_batch, int64 input_feature, int64 input_first_spatial, + int64 input_second_spatial, int64 output_batch, int64 output_feature, + int64 output_first_spatial, int64 output_second_spatial, + int64 kernel_output_feature, int64 kernel_input_feature, + int64 kernel_first_spatial, int64 kernel_second_spatial) { + ConvolutionDimensionNumbers dimension_numbers; + dimension_numbers.set_input_batch_dimension(input_batch); + dimension_numbers.set_input_feature_dimension(input_feature); + dimension_numbers.add_input_spatial_dimensions(input_first_spatial); + dimension_numbers.add_input_spatial_dimensions(input_second_spatial); + dimension_numbers.set_kernel_output_feature_dimension(kernel_output_feature); + dimension_numbers.set_kernel_input_feature_dimension(kernel_input_feature); + dimension_numbers.add_kernel_spatial_dimensions(kernel_first_spatial); + dimension_numbers.add_kernel_spatial_dimensions(kernel_second_spatial); + dimension_numbers.set_output_batch_dimension(output_batch); + dimension_numbers.set_output_feature_dimension(output_feature); + dimension_numbers.add_output_spatial_dimensions(output_first_spatial); + dimension_numbers.add_output_spatial_dimensions(output_second_spatial); + TF_RETURN_IF_ERROR(XlaBuilder::Validate(dimension_numbers)); + return dimension_numbers; +} + class ConvolutionDimensionNumbersTest : public ClientLibraryTestBase {}; // Tests the convolution operation with invalid input dimension numbers. TEST_F(ConvolutionDimensionNumbersTest, InvalidInputDimensionNumbers) { auto dimension_numbers_status = - ComputationBuilder::CreateConvDimensionNumbers(0, 2, 2, 3, 0, 1, 2, 3, 0, - 1, 2, 3); + CreateConvDimensionNumbers(0, 2, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3); ASSERT_FALSE(dimension_numbers_status.ok()); ASSERT_THAT(dimension_numbers_status.status().error_message(), ::testing::HasSubstr("input are not unique")); @@ -49,8 +71,7 @@ TEST_F(ConvolutionDimensionNumbersTest, InvalidInputDimensionNumbers) { // Tests the convolution operation with invalid weight dimension numbers. TEST_F(ConvolutionDimensionNumbersTest, InvalidWeightDimensionNumbers) { auto dimension_numbers_status = - ComputationBuilder::CreateConvDimensionNumbers(0, 1, 2, 3, 0, 1, 2, 3, 0, - 2, 2, 3); + CreateConvDimensionNumbers(0, 1, 2, 3, 0, 1, 2, 3, 0, 2, 2, 3); ASSERT_FALSE(dimension_numbers_status.ok()); ASSERT_THAT(dimension_numbers_status.status().error_message(), ::testing::HasSubstr("weight are not unique")); @@ -59,8 +80,7 @@ TEST_F(ConvolutionDimensionNumbersTest, InvalidWeightDimensionNumbers) { // Tests the convolution operation with invalid output dimension numbers. TEST_F(ConvolutionDimensionNumbersTest, InvalidOutputDimensionNumbers) { auto dimension_numbers_status = - ComputationBuilder::CreateConvDimensionNumbers(0, 1, 2, 3, 0, 2, 2, 3, 0, - 1, 2, 3); + CreateConvDimensionNumbers(0, 1, 2, 3, 0, 2, 2, 3, 0, 1, 2, 3); ASSERT_FALSE(dimension_numbers_status.ok()); ASSERT_THAT(dimension_numbers_status.status().error_message(), ::testing::HasSubstr("output are not unique")); @@ -76,14 +96,14 @@ XLA_TEST_F(ConvolutionDimensionNumbersTest, client_->TransferToServer(*Literal::CreateR4FromArray4D(*weight_array)) .ConsumeValueOrDie(); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input = builder.ConstantR4FromArray4D<float>(*input_array); auto weight = builder.Parameter(0, ShapeUtil::MakeShape(F32, {4, 3, 1, 1}), "weight"); auto conv1 = builder.Conv(input, weight, {1, 1}, Padding::kValid); ConvolutionDimensionNumbers dim_nums = - ComputationBuilder::CreateDefaultConvDimensionNumbers(); + XlaBuilder::CreateDefaultConvDimensionNumbers(); // Swap batch_dimension and feature_dimension. int64 old_input_batch_dim = dim_nums.input_batch_dimension(); int64 old_output_batch_dim = dim_nums.output_batch_dimension(); diff --git a/tensorflow/compiler/xla/tests/convolution_variants_test.cc b/tensorflow/compiler/xla/tests/convolution_variants_test.cc index 9c1145def8..50d6e25d86 100644 --- a/tensorflow/compiler/xla/tests/convolution_variants_test.cc +++ b/tensorflow/compiler/xla/tests/convolution_variants_test.cc @@ -28,6 +28,7 @@ limitations under the License. #include "tensorflow/compiler/xla/client/computation_builder.h" #include "tensorflow/compiler/xla/client/local_client.h" #include "tensorflow/compiler/xla/client/padding.h" +#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h" #include "tensorflow/compiler/xla/literal_util.h" #include "tensorflow/compiler/xla/reference_util.h" #include "tensorflow/compiler/xla/tests/client_library_test_base.h" @@ -52,7 +53,7 @@ class ConvolutionVariantsTest : public ClientLibraryTestBase { }; XLA_TEST_F(ConvolutionVariantsTest, Minimal) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); const Array4D<float> input_array(1, 1, 1, 1, {2}); auto input = builder.ConstantR4FromArray4D<float>(input_array); @@ -67,7 +68,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Minimal) { } XLA_TEST_F(ConvolutionVariantsTest, MinimalWithBatch) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); const Array4D<float> input_array(5, 1, 1, 1, {1, 2, 3, 4, 5}); auto input = builder.ConstantR4FromArray4D<float>(input_array); @@ -82,7 +83,7 @@ XLA_TEST_F(ConvolutionVariantsTest, MinimalWithBatch) { } XLA_TEST_F(ConvolutionVariantsTest, Flat1x1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D<float> input_array(2, 1, 3, 4); input_array.FillWithMultiples(1); @@ -99,7 +100,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Flat1x1) { } XLA_TEST_F(ConvolutionVariantsTest, Deep1x1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D<float> input_array(1, 2, 1, 1, {10, 1}); auto input = builder.ConstantR4FromArray4D<float>(input_array); @@ -114,7 +115,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Deep1x1) { } XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in1x2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D<float> input_array(1, 1, 1, 2, {1, 2}); auto input = builder.ConstantR4FromArray4D<float>(input_array); @@ -129,7 +130,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in1x2) { } XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in1x3) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D<float> input_array(1, 1, 1, 3, {1, 2, 3}); auto input = builder.ConstantR4FromArray4D<float>(input_array); @@ -144,7 +145,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in1x3) { } XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in2x2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D<float> input_array(1, 1, 2, 2, {1, 2, 3, 4}); auto input = builder.ConstantR4FromArray4D<float>(input_array); @@ -159,7 +160,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in2x2) { } XLA_TEST_F(ConvolutionVariantsTest, Filter2x1in2x2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D<float> input_array(1, 1, 2, 2, {1, 2, 3, 4}); auto input = builder.ConstantR4FromArray4D<float>(input_array); @@ -174,7 +175,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x1in2x2) { } XLA_TEST_F(ConvolutionVariantsTest, Filter2x2in2x2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D<float> input_array(1, 1, 2, 2, {1, 2, 3, 4}); auto input = builder.ConstantR4FromArray4D<float>(input_array); @@ -189,7 +190,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x2in2x2) { } XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in2x3WithDepthAndBatch) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D<float> input_array( 2, 2, 2, 3, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0, // plane 0 @@ -210,7 +211,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in2x3WithDepthAndBatch) { } XLA_TEST_F(ConvolutionVariantsTest, Filter1x1stride1x2in1x4) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D<float> input_array(1, 1, 1, 4, {1, 2, 3, 4}); auto input = builder.ConstantR4FromArray4D<float>(input_array); @@ -225,7 +226,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1stride1x2in1x4) { } XLA_TEST_F(ConvolutionVariantsTest, Filter1x1stride1x2in1x5) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D<float> input_array(1, 1, 1, 5, {1, 2, 3, 4, 5}); auto input = builder.ConstantR4FromArray4D<float>(input_array); @@ -240,7 +241,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1stride1x2in1x5) { } XLA_TEST_F(ConvolutionVariantsTest, Filter1x3stride1x2in1x4) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D<float> input_array(1, 1, 1, 4, {1, 2, 3, 4}); auto input = builder.ConstantR4FromArray4D<float>(input_array); @@ -255,7 +256,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x3stride1x2in1x4) { } XLA_TEST_F(ConvolutionVariantsTest, Filter1x3stride1x2in1x5) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D<float> input_array(1, 1, 1, 5, {1, 2, 3, 4, 5}); auto input = builder.ConstantR4FromArray4D<float>(input_array); @@ -270,7 +271,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x3stride1x2in1x5) { } XLA_TEST_F(ConvolutionVariantsTest, Filter1x1stride2x2in3x3) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D<float> input_array(1, 1, 3, 3, {1, 2, 3, 4, 5, 6, 7, 8, 9}); auto input = builder.ConstantR4FromArray4D<float>(input_array); @@ -285,7 +286,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1stride2x2in3x3) { } XLA_TEST_F(ConvolutionVariantsTest, Filter3x1in1x1Padded) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D<float> input_array(1, 1, 1, 1, {1}); auto input = builder.ConstantR4FromArray4D<float>(input_array); @@ -300,7 +301,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter3x1in1x1Padded) { } XLA_TEST_F(ConvolutionVariantsTest, Filter5x1in3x1Padded) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D<float> input_array(1, 1, 1, 3, {1, 2, 3}); auto input = builder.ConstantR4FromArray4D<float>(input_array); @@ -315,7 +316,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter5x1in3x1Padded) { } XLA_TEST_F(ConvolutionVariantsTest, Filter3x3in2x2Padded) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D<float> input_array(1, 1, 2, 2, {1, 2, 3, 4}); auto input = builder.ConstantR4FromArray4D<float>(input_array); @@ -333,7 +334,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter3x3in2x2Padded) { } XLA_TEST_F(ConvolutionVariantsTest, Filter1x1in2x1WithPaddingAndDepth) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D<float> input_array(1, 2, 1, 2, {1, 2, 3, 4}); auto input = builder.ConstantR4FromArray4D<float>(input_array); @@ -348,7 +349,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1in2x1WithPaddingAndDepth) { } XLA_TEST_F(ConvolutionVariantsTest, Filter2x2Stride1x1Input3x3) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D<float> input_array(1, 1, 3, 3, {1, 2, 3, 4, 5, 6, 7, 8, 9}); auto input = builder.ConstantR4FromArray4D<float>(input_array); @@ -363,7 +364,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x2Stride1x1Input3x3) { } XLA_TEST_F(ConvolutionVariantsTest, Filter1x2Stride1x1Input1x3) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D<float> input_array(1, 1, 1, 3, {1, 2, 3}); auto input = builder.ConstantR4FromArray4D<float>(input_array); @@ -378,7 +379,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2Stride1x1Input1x3) { } XLA_TEST_F(ConvolutionVariantsTest, Filter2x1x8x8Input1x1x8x8) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector<float> input_data(64); std::iota(input_data.begin(), input_data.end(), 0.0); @@ -398,7 +399,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x1x8x8Input1x1x8x8) { } XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input16x1x1x1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector<float> input_data(16 * 1 * 1 * 1); std::iota(input_data.begin(), input_data.end(), 1.0); @@ -419,7 +420,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input16x1x1x1) { } XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x2Input16x1x2x2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); constexpr int bs = 16; constexpr int kx = 2; @@ -450,7 +451,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x2Input16x1x2x2) { } XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x2Input3x1x2x2) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); constexpr int kx = 2; constexpr int ky = 2; @@ -482,7 +483,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x2Input3x1x2x2) { } XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x8x8Input16x1x8x8) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D<float> input_array(16, 1, 8, 8); for (int i0 = 0; i0 < 16; ++i0) { @@ -510,7 +511,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x8x8Input16x1x8x8) { } XLA_TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input1x2x8x8) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector<float> input_data(2 * 8 * 8); std::iota(input_data.begin(), input_data.end(), 0.0); @@ -536,7 +537,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input1x2x8x8) { } XLA_TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input2x2x8x8) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector<float> input_data(2 * 2 * 8 * 8); std::iota(input_data.begin(), input_data.end(), 0.0); @@ -562,7 +563,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input2x2x8x8) { } XLA_TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input32x2x8x8) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector<float> input_data(32 * 2 * 8 * 8); std::iota(input_data.begin(), input_data.end(), 0.0); @@ -602,7 +603,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input32x2x8x8) { } XLA_TEST_F(ConvolutionVariantsTest, Filter16x16x1x1Input16x16x1x1) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); Array4D<float> input_array(16, 16, 1, 1); Array4D<float> filter_array(16, 16, 1, 1); @@ -628,7 +629,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter16x16x1x1Input16x16x1x1) { } XLA_TEST_F(ConvolutionVariantsTest, FlatRhsDilation) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector<float> input_data(1 * 1 * 4 * 6); std::iota(input_data.begin(), input_data.end(), 0.0); @@ -640,14 +641,14 @@ XLA_TEST_F(ConvolutionVariantsTest, FlatRhsDilation) { builder.ConvGeneralDilated( /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{}, /*padding=*/{}, /*lhs_dilation=*/{}, /*rhs_dilation=*/{2, 2}, - ComputationBuilder::CreateDefaultConvDimensionNumbers()); + XlaBuilder::CreateDefaultConvDimensionNumbers()); Array4D<float> expected(1, 1, 2, 2, {3924, 4257, 5922, 6255}); ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_); } XLA_TEST_F(ConvolutionVariantsTest, FlatLhsDilation1D) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector<float> input_data(1 * 1 * 1 * 5); std::iota(input_data.begin(), input_data.end(), 1.0); @@ -659,14 +660,14 @@ XLA_TEST_F(ConvolutionVariantsTest, FlatLhsDilation1D) { builder.ConvGeneralDilated( /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{}, /*padding=*/{}, /*lhs_dilation=*/{1, 2}, /*rhs_dilation=*/{}, - ComputationBuilder::CreateDefaultConvDimensionNumbers()); + XlaBuilder::CreateDefaultConvDimensionNumbers()); Array4D<float> expected(1, 1, 1, 8, {10, 2, 20, 3, 30, 4, 40, 5}); ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_); } XLA_TEST_F(ConvolutionVariantsTest, FlatLhsDilation) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector<float> input_data(1 * 1 * 3 * 4); std::iota(input_data.begin(), input_data.end(), 1.0); @@ -682,8 +683,7 @@ XLA_TEST_F(ConvolutionVariantsTest, FlatLhsDilation) { builder.ConvGeneralDilated( /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{2, 1}, /*padding=*/{{1, 0}, {0, 0}}, /*lhs_dilation=*/{3, 2}, - /*rhs_dilation=*/{}, - ComputationBuilder::CreateDefaultConvDimensionNumbers()); + /*rhs_dilation=*/{}, XlaBuilder::CreateDefaultConvDimensionNumbers()); Array4D<float> expected(1, 1, 3, 5, {204, 40, 406, 60, 608, // @@ -693,7 +693,7 @@ XLA_TEST_F(ConvolutionVariantsTest, FlatLhsDilation) { } XLA_TEST_F(ConvolutionVariantsTest, NegativePaddingOnBothEnds) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector<float> input_data(1 * 1 * 1 * 5); std::iota(input_data.begin(), input_data.end(), 1.0); @@ -705,14 +705,14 @@ XLA_TEST_F(ConvolutionVariantsTest, NegativePaddingOnBothEnds) { builder.ConvGeneral( /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{}, /*padding=*/{{0, 0}, {-1, -1}}, - ComputationBuilder::CreateDefaultConvDimensionNumbers()); + XlaBuilder::CreateDefaultConvDimensionNumbers()); Array4D<float> expected(1, 1, 1, 2, {23, 34}); ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_); } XLA_TEST_F(ConvolutionVariantsTest, NegativePaddingLowAndPositivePaddingHigh) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector<float> input_data(1 * 1 * 1 * 5); std::iota(input_data.begin(), input_data.end(), 1.0); @@ -724,14 +724,14 @@ XLA_TEST_F(ConvolutionVariantsTest, NegativePaddingLowAndPositivePaddingHigh) { builder.ConvGeneral( /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{}, /*padding=*/{{0, 0}, {-1, 2}}, - ComputationBuilder::CreateDefaultConvDimensionNumbers()); + XlaBuilder::CreateDefaultConvDimensionNumbers()); Array4D<float> expected(1, 1, 1, 5, {23, 34, 45, 50, 0}); ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_); } XLA_TEST_F(ConvolutionVariantsTest, PositivePaddingLowAndNegativePaddingHigh) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector<float> input_data(1 * 1 * 1 * 5); std::iota(input_data.begin(), input_data.end(), 1.0); @@ -743,14 +743,14 @@ XLA_TEST_F(ConvolutionVariantsTest, PositivePaddingLowAndNegativePaddingHigh) { builder.ConvGeneral( /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{}, /*padding=*/{{0, 0}, {2, -1}}, - ComputationBuilder::CreateDefaultConvDimensionNumbers()); + XlaBuilder::CreateDefaultConvDimensionNumbers()); Array4D<float> expected(1, 1, 1, 5, {0, 1, 12, 23, 34}); ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_); } XLA_TEST_F(ConvolutionVariantsTest, PositivePaddingAndDilation) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector<float> input_data(1 * 1 * 1 * 5); std::iota(input_data.begin(), input_data.end(), 1.0); @@ -763,7 +763,7 @@ XLA_TEST_F(ConvolutionVariantsTest, PositivePaddingAndDilation) { /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{}, /*padding=*/{{0, 0}, {3, 2}}, /*lhs_dilation=*/{1, 2}, /*rhs_dilation=*/{1, 2}, - ComputationBuilder::CreateDefaultConvDimensionNumbers()); + XlaBuilder::CreateDefaultConvDimensionNumbers()); // input: // [1, 2, 3, 4, 5] --dilate-> [1, 0, 2, 0, 3, 0, 4, 0, 5] @@ -775,7 +775,7 @@ XLA_TEST_F(ConvolutionVariantsTest, PositivePaddingAndDilation) { ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_); } XLA_TEST_F(ConvolutionVariantsTest, NegativePaddingAndDilation) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector<float> input_data(1 * 1 * 1 * 5); std::iota(input_data.begin(), input_data.end(), 1.0); @@ -788,7 +788,7 @@ XLA_TEST_F(ConvolutionVariantsTest, NegativePaddingAndDilation) { /*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{}, /*padding=*/{{0, 0}, {-3, -2}}, /*lhs_dilation=*/{1, 2}, /*rhs_dilation=*/{1, 2}, - ComputationBuilder::CreateDefaultConvDimensionNumbers()); + XlaBuilder::CreateDefaultConvDimensionNumbers()); // input: // [1, 2, 3, 4, 5] --dilate-> [1, 0, 2, 0, 3, 0, 4, 0, 5] @@ -821,7 +821,7 @@ XLA_TEST_F(ConvolutionVariantsTest, RandomData_Input1x1x2x3_Filter2x1x1x2) { Array4D<float> input_array(bs, iz, iy, ix, input_data); Array4D<float> filter_array(oz, iz, ky, kx, kernel_data); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input = builder.ConstantR4FromArray4D<float>(input_array); auto filter = builder.ConstantR4FromArray4D<float>(filter_array); builder.Conv(input, filter, {1, 1}, Padding::kValid); @@ -854,7 +854,7 @@ XLA_TEST_F(ConvolutionVariantsTest, RandomData_Input1x16x1x1_Filter1x16x1x1) { Array4D<float> input_array(bs, iz, iy, ix, input_data); Array4D<float> filter_array(oz, iz, ky, kx, kernel_data); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input = builder.ConstantR4FromArray4D<float>(input_array); auto filter = builder.ConstantR4FromArray4D<float>(filter_array); builder.Conv(input, filter, {1, 1}, Padding::kValid); @@ -887,7 +887,7 @@ XLA_TEST_F(ConvolutionVariantsTest, RandomData_Input16x16x1x1_Filter1x16x1x1) { Array4D<float> input_array(bs, iz, iy, ix, input_data); Array4D<float> filter_array(oz, iz, ky, kx, kernel_data); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input = builder.ConstantR4FromArray4D<float>(input_array); auto filter = builder.ConstantR4FromArray4D<float>(filter_array); builder.Conv(input, filter, {1, 1}, Padding::kValid); @@ -920,7 +920,7 @@ XLA_TEST_F(ConvolutionVariantsTest, RandomData_Input16x16x1x1_Filter16x16x1x1) { Array4D<float> input_array(bs, iz, iy, ix, input_data); Array4D<float> filter_array(oz, iz, ky, kx, kernel_data); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input = builder.ConstantR4FromArray4D<float>(input_array); auto filter = builder.ConstantR4FromArray4D<float>(filter_array); builder.Conv(input, filter, {1, 1}, Padding::kValid); @@ -954,7 +954,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Array4D<float> input_array(bs, iz, iy, ix, input_data); Array4D<float> filter_array(oz, iz, ky, kx, kernel_data); - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto input = builder.ConstantR4FromArray4D<float>(input_array); auto filter = builder.ConstantR4FromArray4D<float>(filter_array); builder.Conv(input, filter, {1, 1}, Padding::kValid); @@ -966,7 +966,7 @@ XLA_TEST_F(ConvolutionVariantsTest, } XLA_TEST_F(ConvolutionVariantsTest, Filter1x2x1x1Input1x2x3x1GeneralPadding) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector<float> input_data(1 * 2 * 3 * 1); std::iota(input_data.begin(), input_data.end(), 1.0); @@ -1010,7 +1010,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2x1x1Input1x2x3x1GeneralPadding) { } XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1GeneralPadding) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector<float> input_data(1 * 2 * 3 * 1); std::iota(input_data.begin(), input_data.end(), 1.0); @@ -1054,7 +1054,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1GeneralPadding) { } XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1NoPadding) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector<float> input_data(1 * 2 * 3 * 1); std::iota(input_data.begin(), input_data.end(), 1.0); @@ -1095,7 +1095,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1NoPadding) { } XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x3Input1x2x3x2NoPadding) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); std::vector<float> input_data(1 * 2 * 3 * 2); std::iota(input_data.begin(), input_data.end(), 1.0); @@ -1147,7 +1147,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x3Input1x2x3x2NoPadding) { // BackwardInputConv([1,2,3], [5,6], padding_low=0, padding_high=1) XLA_TEST_F(ConvolutionVariantsTest, BackwardInputLowPaddingLessThanHighPadding) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto gradients = builder.ConstantR4FromArray4D<float>( Array4D<float>(1, 1, 1, 3, /*values=*/{1, 2, 3})); @@ -1166,19 +1166,18 @@ XLA_TEST_F(ConvolutionVariantsTest, // BackwardInputConv([1], [1,10,100], stride=3, padding=(2,1)) XLA_TEST_F(ConvolutionVariantsTest, BackwardInputLowPaddingGreaterThanHighPadding) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto gradients = builder.ConstantR4FromArray4D<float>( Array4D<float>(1, 1, 1, 1, /*values=*/{1})); auto weights = builder.ConstantR4FromArray4D<float>( Array4D<float>(1, 1, 1, 3, /*values=*/{1, 10, 100})); auto mirrored_weights = builder.Rev(weights, {2, 3}); - builder.ConvGeneralDilated( - gradients, mirrored_weights, - /*window_strides=*/{1, 1}, - /*padding=*/{{0, 0}, {0, 3}}, - /*lhs_dilation=*/{1, 3}, /*rhs_dilation=*/{}, - ComputationBuilder::CreateDefaultConvDimensionNumbers()); + builder.ConvGeneralDilated(gradients, mirrored_weights, + /*window_strides=*/{1, 1}, + /*padding=*/{{0, 0}, {0, 3}}, + /*lhs_dilation=*/{1, 3}, /*rhs_dilation=*/{}, + XlaBuilder::CreateDefaultConvDimensionNumbers()); ComputeAndCompareR4<float>(&builder, {{{{100, 0}}}}, {}, error_spec_); } @@ -1187,7 +1186,7 @@ XLA_TEST_F(ConvolutionVariantsTest, // into // BackwardInputConv([1], [1,10,100], padding=(1,1)) XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto gradients = builder.ConstantR4FromArray4D<float>( Array4D<float>(1, 1, 1, 1, /*values=*/{1})); @@ -1208,7 +1207,7 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding) { // However, XLA:GPU doesn't actually fuse it because PadInsertion doesn't // support negative padding on backward convolution yet (b/32744257). XLA_TEST_F(ConvolutionVariantsTest, BackwardInputWithNegativePaddingHigh) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto gradients = builder.ConstantR4FromArray4D<float>( Array4D<float>(1, 1, 1, 3, /*values=*/{1, 2, 3})); @@ -1224,7 +1223,7 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardInputWithNegativePaddingHigh) { XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterLowPaddingLessThanHighPadding) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // activations: 1,2,3,4 ---pad--> 0,1,2,3,4,0,0 // gradients: 100,10,1 -dilate-> 100,0,10,0,1 @@ -1240,7 +1239,7 @@ XLA_TEST_F(ConvolutionVariantsTest, /*window_strides=*/{1, 1}, /*padding=*/{{0, 0}, {1, 2}}, /*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 2}, - ComputationBuilder::CreateDefaultConvDimensionNumbers()); + XlaBuilder::CreateDefaultConvDimensionNumbers()); builder.Transpose(forward_conv, {0, 1, 2, 3}); ComputeAndCompareR4<float>(&builder, {{{{24, 130, 240}}}}, {}, error_spec_); @@ -1248,7 +1247,7 @@ XLA_TEST_F(ConvolutionVariantsTest, XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterLowPaddingGreaterThanHighPadding) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // activations: 1,2,3,4 ---pad--> 0,0,1,2,3,4 // gradients: 100,10,1 -dilate-> 100,0,10,0,1 @@ -1266,14 +1265,14 @@ XLA_TEST_F(ConvolutionVariantsTest, /*window_strides=*/{1, 1}, /*padding=*/{{0, 0}, {2, 0}}, /*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 2}, - ComputationBuilder::CreateDefaultConvDimensionNumbers()); + XlaBuilder::CreateDefaultConvDimensionNumbers()); builder.Transpose(forward_conv, {0, 1, 2, 3}); ComputeAndCompareR4<float>(&builder, {{{{13, 24}}}}, {}, error_spec_); } XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); // activations: 1,2,3,4 ---pad--> 0,0,1,2,3,4,0 // gradients: 100,10,1 -dilate-> 100,0,10,0,1 @@ -1293,14 +1292,14 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding) { /*window_strides=*/{1, 1}, /*padding=*/{{0, 0}, {2, 1}}, /*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 2}, - ComputationBuilder::CreateDefaultConvDimensionNumbers()); + XlaBuilder::CreateDefaultConvDimensionNumbers()); builder.Transpose(forward_conv, {0, 1, 2, 3}); ComputeAndCompareR4<float>(&builder, {{{{13, 24, 130}}}}, {}, error_spec_); } XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding1D) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto gradients = builder.ConstantR3FromArray3D<float>( Array3D<float>(1, 1, 1, /*value=*/1)); @@ -1314,26 +1313,26 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding1D) { } XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding1D) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto activations = builder.ConstantR3FromArray3D<float>(Array3D<float>({{{1, 2, 3, 4}}})); auto gradients = builder.ConstantR3FromArray3D<float>(Array3D<float>({{{100, 10, 1}}})); - auto forward_conv = builder.ConvGeneralDilated( - activations, gradients, - /*window_strides=*/{1}, - /*padding=*/{{2, 1}}, - /*lhs_dilation=*/{}, /*rhs_dilation=*/{2}, - ComputationBuilder::CreateDefaultConvDimensionNumbers( - /*num_spatial_dims=*/1)); + auto forward_conv = + builder.ConvGeneralDilated(activations, gradients, + /*window_strides=*/{1}, + /*padding=*/{{2, 1}}, + /*lhs_dilation=*/{}, /*rhs_dilation=*/{2}, + XlaBuilder::CreateDefaultConvDimensionNumbers( + /*num_spatial_dims=*/1)); builder.Transpose(forward_conv, {0, 1, 2}); ComputeAndCompareR3<float>(&builder, {{{13, 24, 130}}}, {}, error_spec_); } XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding3D) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto gradients_flat = Literal::CreateR1<float>({1}); auto gradients_literal = @@ -1357,7 +1356,7 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding3D) { } XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) { - ComputationBuilder builder(client_, TestName()); + XlaBuilder builder(TestName()); auto activations_flat = Literal::CreateR1<float>({1, 2, 3, 4}); auto activations_literal = @@ -1378,7 +1377,7 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) { /*window_strides=*/{1, 1, 1}, /*padding=*/{{0, 0}, {0, 0}, {2, 1}}, /*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 1, 2}, - ComputationBuilder::CreateDefaultConvDimensionNumbers( + XlaBuilder::CreateDefaultConvDimensionNumbers( /*num_spatial_dims=*/3)); builder.Transpose(forward_conv, {0, 1, 2, 3, 4}); ComputeAndCompareLiteral(&builder, *expected_literal, {}, error_spec_); |