author    | 2018-08-03 04:05:04 -0700
committer | 2018-08-03 04:08:18 -0700
commit    | 37b48fac2c365c49373467abf5fc58c4678e700e (patch)
tree      | 4d1195fb433912acf01557c94fae619160729b01 /tensorflow
parent    | f74a3af20bb24b1de199e50bf9b2a405090a6666 (diff)
[XLA:GPU] Forward batched dot to cublas instead of expanding it
This gives a large speedup for users of batched dot. This is a minimal implementation, without autotuning and without support for strided batched gemm.
PiperOrigin-RevId: 207247740
Diffstat (limited to 'tensorflow')
9 files changed, 260 insertions, 51 deletions
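Before the diff itself: the core of the change is in gemm_thunk.cc below. When the dot has batch dimensions, the thunk now issues a single batched GEMM call through StreamExecutor's ThenBlasGemmBatched instead of relying on DotDecomposer to expand the dot. As a rough illustration of the cuBLAS-level pattern this maps to, here is a standalone sketch with made-up helper names (not the XLA code, which never calls cuBLAS directly): it runs `batch` independent column-major GEMMs C_i = A_i * B_i with one library call by building arrays of per-batch device pointers.

    #include <cublas_v2.h>
    #include <cuda_runtime.h>
    #include <vector>

    // a, b, c point to device buffers holding `batch` contiguous column-major
    // matrices of sizes m*k, k*n and m*n respectively (batch dimension major).
    void BatchedGemm(cublasHandle_t handle, const float* a, const float* b,
                     float* c, int m, int n, int k, int batch) {
      std::vector<const float*> a_ptrs(batch);
      std::vector<const float*> b_ptrs(batch);
      std::vector<float*> c_ptrs(batch);
      for (int i = 0; i < batch; ++i) {
        a_ptrs[i] = a + static_cast<size_t>(i) * m * k;  // start of batch i in A
        b_ptrs[i] = b + static_cast<size_t>(i) * k * n;  // start of batch i in B
        c_ptrs[i] = c + static_cast<size_t>(i) * m * n;  // start of batch i in C
      }
      // cublas<T>gemmBatched expects the pointer arrays themselves to live in
      // device memory, so copy them over first.
      const float** d_a = nullptr;
      const float** d_b = nullptr;
      float** d_c = nullptr;
      cudaMalloc(&d_a, batch * sizeof(float*));
      cudaMalloc(&d_b, batch * sizeof(float*));
      cudaMalloc(&d_c, batch * sizeof(float*));
      cudaMemcpy(d_a, a_ptrs.data(), batch * sizeof(float*), cudaMemcpyHostToDevice);
      cudaMemcpy(d_b, b_ptrs.data(), batch * sizeof(float*), cudaMemcpyHostToDevice);
      cudaMemcpy(d_c, c_ptrs.data(), batch * sizeof(float*), cudaMemcpyHostToDevice);

      const float alpha = 1.0f, beta = 0.0f;
      // One call issues all `batch` GEMMs; lda/ldb/ldc are the column-major
      // leading dimensions, i.e. the number of rows of each matrix.
      cublasSgemmBatched(handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &alpha, d_a,
                         /*lda=*/m, d_b, /*ldb=*/k, &beta, d_c, /*ldc=*/m, batch);

      cudaFree(d_a);
      cudaFree(d_b);
      cudaFree(d_c);
    }

The TODO(b/112111608) in the hunk refers to the other batched entry point, cublas<T>gemmStridedBatched (available since CUDA 8), which takes one base pointer plus a fixed element stride per operand and avoids materializing these pointer arrays.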
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index e0aae3866b..4947dd278e 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -636,7 +636,6 @@ cc_library(
         "//tensorflow/compiler/xla/service:buffer_liveness",
         "//tensorflow/compiler/xla/service:call_inliner",
         "//tensorflow/compiler/xla/service:conditional_simplifier",
-        "//tensorflow/compiler/xla/service:dot_decomposer",
         "//tensorflow/compiler/xla/service:executable",
         "//tensorflow/compiler/xla/service:flatten_call_graph",
         "//tensorflow/compiler/xla/service:hlo",
@@ -749,6 +748,8 @@ tf_cc_test(
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/compiler/xla/service:computation_layout",
         "//tensorflow/compiler/xla/service:hlo",
+        "//tensorflow/compiler/xla/service:hlo_matchers",
+        "//tensorflow/compiler/xla/service:hlo_parser",
         "//tensorflow/compiler/xla/tests:hlo_test_base",
         "//tensorflow/compiler/xla/tests:xla_internal_test_main",  # build_cleaner: keep
     ],
diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc
index dbcbabdc52..e9ba1f13eb 100644
--- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc
@@ -32,18 +32,27 @@ namespace {
 // dimensions.
 struct MatrixDescriptor {
   MatrixDescriptor(se::DeviceMemoryBase matrix_data, bool needs_transpose,
-                   int64 matrix_num_rows, int64 matrix_num_cols)
+                   int64 matrix_num_rows, int64 matrix_num_cols,
+                   int64 matrix_batch_size)
       : data(matrix_data),
         transpose(needs_transpose),
         num_rows(matrix_num_rows),
-        num_cols(matrix_num_cols) {}
+        num_cols(matrix_num_cols),
+        batch_size(matrix_batch_size) {}

   se::DeviceMemoryBase data;
   bool transpose;  // Whether this matrix needs to be transposed.
   int64 num_rows;
   int64 num_cols;
+  int64 batch_size;
 };

+template <typename T>
+se::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory) {
+  se::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory));
+  return se::DeviceMemory<T>(wrapped);
+}
+
 // Performs a gemm call without an explicit algorithm on lhs_matrix and
 // rhs_matrix, and stores the result to output_matrix.
 template <typename Element>
@@ -51,6 +60,9 @@ bool DoGemm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix,
             MatrixDescriptor output_matrix, double alpha, se::Stream* stream) {
   DCHECK(!output_matrix.transpose);

+  const int64 batch_size = lhs_matrix.batch_size;
+  CHECK_EQ(batch_size, rhs_matrix.batch_size);
+  CHECK_EQ(batch_size, output_matrix.batch_size);
   se::DeviceMemory<Element> lhs_data(lhs_matrix.data);
   se::DeviceMemory<Element> rhs_data(rhs_matrix.data);
   se::DeviceMemory<Element> output_data(output_matrix.data);
@@ -61,13 +73,54 @@ bool DoGemm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix,
                            : se::blas::Transpose::kNoTranspose;
   auto k = lhs_matrix.transpose ? lhs_matrix.num_rows : lhs_matrix.num_cols;

+  if (batch_size == 1) {
+    return stream
+        ->ThenBlasGemm(
+            lhs_transpose, rhs_transpose, output_matrix.num_rows,
+            output_matrix.num_cols, /*size of reduce dim=*/k, /*alpha=*/alpha,
+            lhs_data, /*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data,
+            /*leading dim of RHS=*/rhs_matrix.num_rows, /*beta=*/0.0,
+            &output_data, /*leading dim of output=*/output_matrix.num_rows)
+        .ok();
+  }
+
+  // Create the buffers for batched gemm.
+  // TODO(b/112111608): We could avoid all of this and also make it faster by
+  // using cuBLAS 8's strided batched gemm.
+  using DeviceMemoryType = se::DeviceMemory<Element>;
+  std::vector<DeviceMemoryType> a_device_memory;
+  std::vector<DeviceMemoryType> b_device_memory;
+  std::vector<DeviceMemoryType> c_device_memory;
+  std::vector<DeviceMemoryType*> a_ptrs;
+  std::vector<DeviceMemoryType*> b_ptrs;
+  std::vector<DeviceMemoryType*> c_ptrs;
+  a_device_memory.reserve(batch_size);
+  b_device_memory.reserve(batch_size);
+  c_device_memory.reserve(batch_size);
+  a_ptrs.reserve(batch_size);
+  b_ptrs.reserve(batch_size);
+  c_ptrs.reserve(batch_size);
+  auto* a_base_ptr = static_cast<Element*>(lhs_data.opaque());
+  auto* b_base_ptr = static_cast<Element*>(rhs_data.opaque());
+  auto* c_base_ptr = static_cast<Element*>(output_data.opaque());
+  for (int64 i = 0; i < batch_size; ++i) {
+    a_device_memory.push_back(AsDeviceMemory(
+        a_base_ptr + i * lhs_matrix.num_rows * lhs_matrix.num_cols));
+    b_device_memory.push_back(AsDeviceMemory(
+        b_base_ptr + i * rhs_matrix.num_rows * rhs_matrix.num_cols));
+    c_device_memory.push_back(AsDeviceMemory(
+        c_base_ptr + i * output_matrix.num_rows * output_matrix.num_cols));
+    a_ptrs.push_back(&a_device_memory.back());
+    b_ptrs.push_back(&b_device_memory.back());
+    c_ptrs.push_back(&c_device_memory.back());
+  }
   return stream
-      ->ThenBlasGemm(
-          lhs_transpose, rhs_transpose, output_matrix.num_rows,
-          output_matrix.num_cols, /*size of reduce dim=*/k, /*alpha=*/alpha,
-          lhs_data, /*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data,
-          /*leading dim of RHS=*/rhs_matrix.num_rows, /*beta=*/0.0,
-          &output_data, /*leading dim of output=*/output_matrix.num_rows)
+      ->ThenBlasGemmBatched(
+          lhs_transpose, rhs_transpose, output_matrix.num_rows,
+          output_matrix.num_cols, /*size of reduce dim=*/k, /*alpha=*/alpha,
+          a_ptrs, /*leading dim of LHS=*/lhs_matrix.num_rows, b_ptrs,
+          /*leading dim of RHS=*/rhs_matrix.num_rows, /*beta=*/0.0, c_ptrs,
+          /*leading dim of output=*/output_matrix.num_rows, batch_size)
       .ok();
 }
@@ -94,6 +147,10 @@ bool DoGemmWithAlgorithm(MatrixDescriptor lhs_matrix,
                          se::blas::ProfileResult* output_profile_result) {
   DCHECK(!output_matrix.transpose);

+  CHECK_EQ(1, lhs_matrix.batch_size);
+  CHECK_EQ(1, rhs_matrix.batch_size);
+  CHECK_EQ(1, output_matrix.batch_size);
+
   se::DeviceMemory<Element> lhs_data(lhs_matrix.data);
   se::DeviceMemory<Element> rhs_data(rhs_matrix.data);
   se::DeviceMemory<Element> output_data(output_matrix.data);
@@ -270,12 +327,37 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
   se::DeviceMemoryBase output_data =
       buffer_allocations.GetDeviceAddress(output_buffer_);

+  DotDimensionNumbers dim_nums = GetDimensionNumbers(*hlo_instruction());
+  CHECK_EQ(dim_nums.lhs_batch_dimensions_size(),
+           dim_nums.rhs_batch_dimensions_size());
+  CHECK_EQ(dim_nums.lhs_batch_dimensions_size() + 2,
+           ShapeUtil::Rank(output_shape_));
+
+  int64 row_dim = dim_nums.lhs_batch_dimensions_size();
+  int64 col_dim = dim_nums.lhs_batch_dimensions_size() + 1;
+  int64 batch_size = std::accumulate(output_shape_.dimensions().begin(),
+                                     output_shape_.dimensions().end() - 2, 1,
+                                     std::multiplies<int64>());
+
+  // Check that the batch dims don't cover the last two dims.
+  for (int64 batch_dim : dim_nums.lhs_batch_dimensions()) {
+    CHECK_NE(row_dim, batch_dim);
+    CHECK_NE(col_dim, batch_dim);
+  }
+
+  // Verify that the non-batch dimensions are minor-most. This is required for
+  // efficient access.
+  for (const auto* shape : {&lhs_shape_, &rhs_shape_, &output_shape_}) {
+    CHECK_LT(shape->layout().minor_to_major(row_dim), 2);
+    CHECK_LT(shape->layout().minor_to_major(col_dim), 2);
+  }
+
   // BLAS gemm reduces rows of LHS and columns of RHS. The Dot operator between
   // matrices reduces dimension 1 of LHS and dimension 0 of RHS regardless of
   // their layout. Therefore, we should treat dimension 0 as row and dimension 1
   // as column when mapping a matrix Dot to BLAS gemm.
-  int64 output_num_rows = output_shape_.dimensions(0);
-  int64 output_num_cols = output_shape_.dimensions(1);
+  int64 output_num_rows = output_shape_.dimensions(row_dim);
+  int64 output_num_cols = output_shape_.dimensions(col_dim);

   // BLAS gemm expects the inputs and the output are in column-major order.
   // Therefore, we need to convert dot between row-major matrices to that
@@ -298,31 +380,37 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
   // the leading dimension of the LHS matrix of gemm is the number of rows in
   // B^T and thus the number of columns in B.

-  auto make_descriptor = [this](se::DeviceMemoryBase data, const Shape& shape,
-                                bool transpose) -> MatrixDescriptor {
-    bool is_row_major = LayoutUtil::Minor(shape.layout(), 0) != 0;
-    bool layout_mismatch = LayoutUtil::Minor(shape.layout(), 0) !=
-                           LayoutUtil::Minor(output_shape_.layout(), 0);
-    return MatrixDescriptor(data, transpose ^ layout_mismatch,
-                            shape.dimensions(is_row_major),
-                            shape.dimensions(!is_row_major));
+  auto make_descriptor = [&](se::DeviceMemoryBase data, const Shape& shape,
+                             bool transpose) -> MatrixDescriptor {
+    bool is_row_major = LayoutUtil::Minor(shape.layout(), row_dim) != 0;
+    bool layout_mismatch = LayoutUtil::Minor(shape.layout(), row_dim) !=
+                           LayoutUtil::Minor(output_shape_.layout(), row_dim);
+    return MatrixDescriptor(
+        data, transpose ^ layout_mismatch,
+        shape.dimensions(row_dim + static_cast<int64>(is_row_major)),
+        shape.dimensions(row_dim + static_cast<int64>(!is_row_major)),
+        batch_size);
   };

-  DotDimensionNumbers dim_nums = GetDimensionNumbers(*hlo_instruction());
-
   const MatrixDescriptor lhs_descriptor = make_descriptor(
-      lhs_data, lhs_shape_, dim_nums.lhs_contracting_dimensions(0) == 0);
+      lhs_data, lhs_shape_, dim_nums.lhs_contracting_dimensions(0) == row_dim);
   const MatrixDescriptor rhs_descriptor = make_descriptor(
-      rhs_data, rhs_shape_, dim_nums.rhs_contracting_dimensions(0) == 1);
+      rhs_data, rhs_shape_, dim_nums.rhs_contracting_dimensions(0) == col_dim);

   // Dispatches to a regular cublas gemm, a gemm-with-algorithm, or attempts to
   // autotune this gemm to figure out the best algorithm.
-  auto launch = [this](MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix,
-                       MatrixDescriptor output_matrix, se::Stream* stream) {
+  auto launch = [&](MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix,
+                    MatrixDescriptor output_matrix, se::Stream* stream) {
     PrimitiveType element_type = output_shape_.element_type();
     se::blas::ComputationType computation_type =
         GetBlasComputationType(element_type);

+    // TODO(b/112111608): Implement auto tune for batched gemm.
+    if (batch_size != 1) {
+      return GetGemmFn(element_type)(lhs_matrix, rhs_matrix, output_matrix,
+                                     alpha_, stream);
+    }
+
     auto thunk_name = [&] {
       return hlo_instruction() != nullptr ? hlo_instruction()->ToString()
                                           : "<null>";
@@ -368,16 +456,16 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
   auto op_profiler =
       profiler->MakeScopedInstructionProfiler(hlo_instruction());
   bool launch_ok;
-  if (LayoutUtil::Minor(output_shape_.layout(), 0) == 0) {
-    launch_ok = launch(
-        lhs_descriptor, rhs_descriptor,
-        MatrixDescriptor(output_data, false, output_num_rows, output_num_cols),
-        stream);
+  if (LayoutUtil::Minor(output_shape_.layout(), row_dim) == 0) {
+    launch_ok = launch(lhs_descriptor, rhs_descriptor,
+                       MatrixDescriptor(output_data, false, output_num_rows,
+                                        output_num_cols, batch_size),
+                       stream);
   } else {
-    launch_ok = launch(
-        rhs_descriptor, lhs_descriptor,
-        MatrixDescriptor(output_data, false, output_num_cols, output_num_rows),
-        stream);
+    launch_ok = launch(rhs_descriptor, lhs_descriptor,
+                       MatrixDescriptor(output_data, false, output_num_cols,
+                                        output_num_rows, batch_size),
+                       stream);
   }

   if (!launch_ok) {
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
index 6ac5dfbcd5..d033faee8d 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
@@ -176,6 +176,38 @@ Status GpuLayoutAssignment::AddBackendConstraints(
       TF_RETURN_IF_ERROR(
           AddBackendConstraintsToDnnConvCustomCall(instruction, constraints));
     }
+
+    // For batched dot we require the default layout.
+    // TODO(b/112111608): This is overly conservative, the only real restriction
+    // is that batch dimensions must be major.
+    if (instruction->opcode() == HloOpcode::kDot &&
+        ImplementedAsGemm(*instruction) &&
+        instruction->dot_dimension_numbers().lhs_batch_dimensions_size() > 0) {
+      // Verify that the batch dims come before the row and col dims.
+      const DotDimensionNumbers& dim_nums =
+          instruction->dot_dimension_numbers();
+      CHECK_EQ(dim_nums.lhs_batch_dimensions_size(),
+               dim_nums.rhs_batch_dimensions_size());
+      CHECK_EQ(dim_nums.lhs_batch_dimensions_size() + 2,
+               ShapeUtil::Rank(instruction->shape()));
+      for (int64 batch_dim : dim_nums.lhs_batch_dimensions()) {
+        CHECK_LT(batch_dim, ShapeUtil::Rank(instruction->shape()) - 2);
+      }
+
+      // Set both inputs and the output to default layout.
+      Shape op0_shape = instruction->operand(0)->shape();
+      LayoutUtil::SetToDefaultLayout(&op0_shape);
+      Shape op1_shape = instruction->operand(1)->shape();
+      LayoutUtil::SetToDefaultLayout(&op1_shape);
+      Shape output_shape = instruction->shape();
+      LayoutUtil::SetToDefaultLayout(&output_shape);
+      TF_RETURN_IF_ERROR(
+          constraints->SetOperandLayout(op0_shape, instruction, 0));
+      TF_RETURN_IF_ERROR(
+          constraints->SetOperandLayout(op1_shape, instruction, 1));
+      TF_RETURN_IF_ERROR(
+          constraints->SetInstructionLayout(output_shape, instruction));
+    }
   }
   return Status::OK();
 }
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
index 95f78ae293..286547ebae 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
@@ -20,8 +20,10 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_matchers.h"
 #include "tensorflow/compiler/xla/service/hlo_module.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
 #include "tensorflow/compiler/xla/shape_layout.h"
 #include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
@@ -31,6 +33,8 @@ namespace xla {
 namespace gpu {
 namespace {

+namespace op = xla::testing::opcode_matchers;
+
 using LayoutAssignmentTest = HloTestBase;

 TEST_F(LayoutAssignmentTest, Elementwise) {
@@ -327,6 +331,33 @@ TEST_F(LayoutAssignmentTest, BatchNormGrad) {
   }
 }

+TEST_F(LayoutAssignmentTest, DotLayout) {
+  const char* hlo_text = R"(
+  HloModule DotLayout
+  ENTRY dot {
+    p0 = f32[8,8,256,64]{3,1,2,0} parameter(0)
+    p1 = f32[8,8,256,64]{3,1,2,0} parameter(1)
+    ROOT dot.1330.10585 = f32[8,8,256,256]{3,2,1,0} dot(p0, p1),
+      lhs_batch_dims={0,1}, lhs_contracting_dims={3},
+      rhs_batch_dims={0,1}, rhs_contracting_dims={3}
+  })";
+
+  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+                          ParseHloString(hlo_text));
+
+  ComputationLayout computation_layout(
+      module->entry_computation()->ComputeProgramShape());
+  GpuLayoutAssignment layout_assignment(&computation_layout,
+                                        backend().default_stream_executor());
+  EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
+
+  Shape expected_shape =
+      ShapeUtil::MakeShapeWithLayout(F32, {8, 8, 256, 64}, {3, 2, 1, 0});
+  EXPECT_THAT(module->entry_computation()->root_instruction(),
+              op::Dot(op::ShapeWithLayout(expected_shape),
+                      op::ShapeWithLayout(expected_shape)));
+}
+
 }  // namespace
 }  // namespace gpu
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
index 6352b330d1..d74c1a0243 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
@@ -38,14 +38,16 @@ namespace gpu {
 namespace {

 // Return whether the given shape is a matrix with no padding.
-bool IsRank2WithNoPadding(const Shape& shape) {
-  return ShapeUtil::Rank(shape) == 2 && !LayoutUtil::IsPadded(shape);
+bool IsRank2WithNoPadding(const Shape& shape, int64 batch_dimensions_size) {
+  return ShapeUtil::Rank(shape) == batch_dimensions_size + 2 &&
+         !LayoutUtil::IsPadded(shape);
 }

 // In a gemm operation where output = lhs * rhs, check whether the given shapes
 // are valid for the operation.
 bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape,
-                        const Shape& output_shape) {
+                        const Shape& output_shape,
+                        int64 batch_dimensions_size) {
   // The inputs and the output must
   // 1) be matrices with no padding and a non-zero number of elements,
   // 2) have an allowed element type.
@@ -53,9 +55,10 @@ bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape,
   bool type_is_allowed =
       (output_primitive_type == F16 || output_primitive_type == F32 ||
        output_primitive_type == F64);
-  return type_is_allowed && IsRank2WithNoPadding(lhs_shape) &&
-         IsRank2WithNoPadding(rhs_shape) &&
-         IsRank2WithNoPadding(output_shape) &&
+  return type_is_allowed &&
+         IsRank2WithNoPadding(lhs_shape, batch_dimensions_size) &&
+         IsRank2WithNoPadding(rhs_shape, batch_dimensions_size) &&
+         IsRank2WithNoPadding(output_shape, batch_dimensions_size) &&
          !ShapeUtil::IsZeroElementArray(lhs_shape) &&
          !ShapeUtil::IsZeroElementArray(rhs_shape);
 }
@@ -64,14 +67,15 @@ bool DotImplementedAsGemm(const HloInstruction& dot) {
   CHECK_EQ(dot.opcode(), HloOpcode::kDot);
   const Shape& lhs_shape = dot.operand(0)->shape();
   const Shape& rhs_shape = dot.operand(1)->shape();
+  const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers();

   // If gemm can accept the operand shapes, use it rather than a custom
   // kernel.
-  if (AreValidGemmShapes(lhs_shape, rhs_shape, dot.shape())) {
+  if (AreValidGemmShapes(lhs_shape, rhs_shape, dot.shape(),
+                         dim_numbers.lhs_batch_dimensions_size())) {
     // The size of the reduction dimension should match. The shape inference
     // guarantees this invariant, so the check here is for programming
     // errors.
-    const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers();
     CHECK_EQ(lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)),
              rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0)));
     return true;
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
index 290e2f73dc..541cacf697 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
@@ -454,6 +454,9 @@ Status IrEmitter::HandleDot(HloInstruction* dot) {
   const Shape& lhs_shape = lhs_instruction->shape();
   const Shape& rhs_shape = rhs_instruction->shape();

+  const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
+  CHECK_EQ(dnums.lhs_batch_dimensions_size(),
+           dnums.rhs_batch_dimensions_size());
   // TODO(b/110211620): Convert to use i32 index_type when it is possible.
   llvm::Type* index_type = b_.getInt64Ty();
@@ -489,9 +492,15 @@ Status IrEmitter::HandleDot(HloInstruction* dot) {
   const int64 lhs_reduction_dimension =
       ShapeUtil::GetDimensionNumber(lhs_shape, -1);
   const int64 rhs_reduction_dimension =
-      ShapeUtil::Rank(rhs_shape) >= 2
+      ShapeUtil::Rank(rhs_shape) >= 2 + dnums.lhs_batch_dimensions_size()
           ? ShapeUtil::GetDimensionNumber(rhs_shape, -2)
-          : 0;
+          : dnums.lhs_batch_dimensions_size();
+
+  // Check that the batch dims don't cover the last two dims.
+  for (int64 batch_dim : dnums.lhs_batch_dimensions()) {
+    CHECK_NE(lhs_reduction_dimension, batch_dim);
+    CHECK_NE(rhs_reduction_dimension, batch_dim);
+  }

   // Verify the reduction dimension in the two operands are the same size.
   TF_RET_CHECK(lhs_shape.dimensions(lhs_reduction_dimension) ==
@@ -506,6 +515,13 @@ Status IrEmitter::HandleDot(HloInstruction* dot) {
   llvm_ir::IrArray::Index rhs_index = loop_nest.EmitOperandArrayLoopNest(
       rhs_array, /*dimension_to_skip=*/rhs_reduction_dimension, "rhs");

+  // We don't have to iterate over the batch dimensions in both arrays, simplify
+  // the loop nest of the rhs.
+  for (int i = 0; i != dnums.lhs_batch_dimensions_size(); ++i) {
+    DCHECK(c_linear_search(dnums.lhs_batch_dimensions(), i));
+    rhs_index[i] = lhs_index[i];
+  }
+
   // Create the reduction loop which does the sum of products reduction.
   std::unique_ptr<llvm_ir::ForLoop> reduction_loop = loop_nest.AddLoop(
       /*start_index=*/0,
@@ -568,7 +584,9 @@ Status IrEmitter::HandleDot(HloInstruction* dot) {
       target_index.push_back(lhs_index[dimension]);
     }
   }
-  for (size_t dimension = 0; dimension < rhs_index.size(); ++dimension) {
+  // Skip over the batch dimensions to not have them in the index twice.
+  for (size_t dimension = dnums.lhs_batch_dimensions_size();
+       dimension < rhs_index.size(); ++dimension) {
     if (dimension != rhs_reduction_dimension) {
       target_index.push_back(rhs_index[dimension]);
     }
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index f61a977ad4..d5ecae88ed 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -345,11 +345,6 @@ Status IrEmitterUnnested::DefaultAction(HloInstruction* hlo) {
 }

 Status IrEmitterUnnested::HandleDot(HloInstruction* dot) {
-  const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
-  if (dnums.lhs_batch_dimensions_size() > 0 ||
-      dnums.rhs_batch_dimensions_size() > 0) {
-    return Unimplemented("Dot with batch dimensions not implemented.");
-  }
   if (ImplementedAsGemm(*dot)) {
     thunk_sequence_->emplace_back(BuildGemmThunk(dot));
     return Status::OK();
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
index 7a683ede54..8fa0439006 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
@@ -34,7 +34,6 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/buffer_liveness.h"
 #include "tensorflow/compiler/xla/service/call_inliner.h"
 #include "tensorflow/compiler/xla/service/conditional_simplifier.h"
-#include "tensorflow/compiler/xla/service/dot_decomposer.h"
 #include "tensorflow/compiler/xla/service/flatten_call_graph.h"
 #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h"
 #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h"
@@ -148,7 +147,6 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
   // support BF16 operations without directly implementing a BF16 lowering for
   // most ops.
   pipeline.AddPass<HloElementTypeConverter>(BF16, F32);
-  pipeline.AddPass<DotDecomposer>();

   {
     auto& pass =
diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc
index cfd36abf47..f11d274aab 100644
--- a/tensorflow/compiler/xla/tests/dot_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc
@@ -612,7 +612,7 @@ XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) {
       {x_data.get(), y_data.get()}, this->error_spec_);
 }

-XLA_TYPED_TEST(DotOperationTest_F16F32F64, GeneralMatMul) {
+XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMul) {
   using T = TypeParam;

   XlaBuilder builder(this->TestName());
@@ -648,6 +648,48 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, GeneralMatMul) {
       {x_data.get(), y_data.get()}, this->error_spec_);
 }

+XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulMultipleBatch) {
+  using T = TypeParam;
+
+  XlaBuilder builder(this->TestName());
+  auto x = Parameter(&builder, 0, ShapeUtil::MakeShapeWithType<T>({2, 2, 2, 2}),
+                     "x");
+  auto y = Parameter(&builder, 1, ShapeUtil::MakeShapeWithType<T>({2, 2, 2, 2}),
+                     "y");
+
+  DotDimensionNumbers dnums;
+  dnums.add_lhs_contracting_dimensions(3);
+  dnums.add_rhs_contracting_dimensions(2);
+  dnums.add_lhs_batch_dimensions(0);
+  dnums.add_lhs_batch_dimensions(1);
+  dnums.add_rhs_batch_dimensions(0);
+  dnums.add_rhs_batch_dimensions(1);
+
+  DotGeneral(x, y, dnums);
+
+  auto x_data =
+      this->client_
+          ->TransferToServer(*LiteralUtil::CreateR4FromArray4D<T>(
+              {{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}},
+               {{{9.0f, 10.0f}, {11.0f, 12.0f}},
+                {{13.0f, 14.0f}, {15.0f, 16.0f}}}}))
+          .ConsumeValueOrDie();
+
+  auto y_data =
+      this->client_
+          ->TransferToServer(*LiteralUtil::CreateR4FromArray4D<T>(
+              {{{{1.0f, 0.0f}, {0.0f, 1.0f}}, {{1.0f, 0.0f}, {0.0f, 1.0f}}},
+               {{{0.0f, 1.0f}, {1.0f, 0.0f}}, {{0.0f, 1.0f}, {1.0f, 0.0f}}}}))
+          .ConsumeValueOrDie();
+
+  this->template ComputeAndCompareR4<T>(
+      &builder,
+      /*expected=*/
+      {{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}},
+       {{{10.0f, 9.0f}, {12.0f, 11.0f}}, {{14.0f, 13.0f}, {16.0f, 15.0f}}}},
+      {x_data.get(), y_data.get()}, this->error_spec_);
+}
+
 XLA_TYPED_TEST(DotOperationTest_F16F32F64, TransposeFolding) {
   using T = TypeParam;
   for (bool transpose_lhs : {false, true}) {
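As a concrete illustration of the addressing scheme the new thunk relies on (and that the layout-assignment change enforces by pinning operands to the default layout), here is a small standalone sketch, not XLA code, using the shape from the DotLayout test above:

    #include <cstdint>
    #include <functional>
    #include <numeric>
    #include <vector>

    // Product of all but the two minor-most dimensions, mirroring the
    // std::accumulate call added to GemmThunk::ExecuteOnStream.
    int64_t BatchSize(const std::vector<int64_t>& dims) {
      return std::accumulate(dims.begin(), dims.end() - 2, int64_t{1},
                             std::multiplies<int64_t>());
    }

    // For f32[8,8,256,64] with the default layout {3,2,1,0}:
    //   BatchSize({8, 8, 256, 64}) == 64, i.e. 64 batched 256x64 matrices,
    //   and batch element i starts i * 256 * 64 elements into the buffer --
    //   exactly the offsets used when filling a_ptrs/b_ptrs/c_ptrs in DoGemm.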
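One more note on the dispatch at the end of GemmThunk::ExecuteOnStream, which the batched path reuses unchanged: cuBLAS only writes column-major output, and a row-major M x N matrix occupies memory exactly like a column-major N x M matrix. Combined with the identity

    C = A * B   implies   C^T = B^T * A^T,

a row-major product can be obtained from a column-major gemm by passing the operands in swapped order, which is why the else branch calls launch(rhs_descriptor, lhs_descriptor, ...) with the output's row and column counts exchanged.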