author     Benjamin Kramer <kramerb@google.com>  2018-08-03 04:05:04 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-08-03 04:08:18 -0700
commit     37b48fac2c365c49373467abf5fc58c4678e700e (patch)
tree       4d1195fb433912acf01557c94fae619160729b01 /tensorflow
parent     f74a3af20bb24b1de199e50bf9b2a405090a6666 (diff)
[XLA:GPU] Forward batched dot to cublas instead of expanding it
This gives a huge speedup for users of batch dot. This is a minimal implementation, without autotuning and without support for strided batch gemm.

PiperOrigin-RevId: 207247740
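For illustration, a dot with batch dimensions of the form below (adapted from the new test in gpu_layout_assignment_test.cc; the module and instruction names are illustrative) is no longer decomposed by DotDecomposer and instead lowers to a single GemmThunk that issues one cuBLAS batched gemm over the 8*8 = 64 batch slices:

    HloModule BatchDotExample

    ENTRY dot {
      p0 = f32[8,8,256,64]{3,2,1,0} parameter(0)
      p1 = f32[8,8,256,64]{3,2,1,0} parameter(1)
      ROOT batch_dot = f32[8,8,256,256]{3,2,1,0} dot(p0, p1),
          lhs_batch_dims={0,1}, lhs_contracting_dims={3},
          rhs_batch_dims={0,1}, rhs_contracting_dims={3}
    }

The new layout-assignment constraint pins both operands and the output of such a dot to the default layout, so the batch dimensions stay major and the gemm thunk can address each batch slice with a plain pointer offset of num_rows * num_cols elements.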
Diffstat (limited to 'tensorflow')
 -rw-r--r--  tensorflow/compiler/xla/service/gpu/BUILD                          |   3
 -rw-r--r--  tensorflow/compiler/xla/service/gpu/gemm_thunk.cc                  | 150
 -rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc       |  32
 -rw-r--r--  tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc  |  31
 -rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc           |  20
 -rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emitter.cc                  |  24
 -rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc         |   5
 -rw-r--r--  tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc              |   2
 -rw-r--r--  tensorflow/compiler/xla/tests/dot_operation_test.cc                |  44
 9 files changed, 260 insertions, 51 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index e0aae3866b..4947dd278e 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -636,7 +636,6 @@ cc_library(
"//tensorflow/compiler/xla/service:buffer_liveness",
"//tensorflow/compiler/xla/service:call_inliner",
"//tensorflow/compiler/xla/service:conditional_simplifier",
- "//tensorflow/compiler/xla/service:dot_decomposer",
"//tensorflow/compiler/xla/service:executable",
"//tensorflow/compiler/xla/service:flatten_call_graph",
"//tensorflow/compiler/xla/service:hlo",
@@ -749,6 +748,8 @@ tf_cc_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:computation_layout",
"//tensorflow/compiler/xla/service:hlo",
+ "//tensorflow/compiler/xla/service:hlo_matchers",
+ "//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main", # build_cleaner: keep
],
diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc
index dbcbabdc52..e9ba1f13eb 100644
--- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc
@@ -32,18 +32,27 @@ namespace {
// dimensions.
struct MatrixDescriptor {
MatrixDescriptor(se::DeviceMemoryBase matrix_data, bool needs_transpose,
- int64 matrix_num_rows, int64 matrix_num_cols)
+ int64 matrix_num_rows, int64 matrix_num_cols,
+ int64 matrix_batch_size)
: data(matrix_data),
transpose(needs_transpose),
num_rows(matrix_num_rows),
- num_cols(matrix_num_cols) {}
+ num_cols(matrix_num_cols),
+ batch_size(matrix_batch_size) {}
se::DeviceMemoryBase data;
bool transpose; // Whether this matrix needs to be transposed.
int64 num_rows;
int64 num_cols;
+ int64 batch_size;
};
+template <typename T>
+se::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory) {
+ se::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory));
+ return se::DeviceMemory<T>(wrapped);
+}
+
// Performs a gemm call without an explicit algorithm on lhs_matrix and
// rhs_matrix, and stores the result to output_matrix.
template <typename Element>
@@ -51,6 +60,9 @@ bool DoGemm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix,
MatrixDescriptor output_matrix, double alpha, se::Stream* stream) {
DCHECK(!output_matrix.transpose);
+ const int64 batch_size = lhs_matrix.batch_size;
+ CHECK_EQ(batch_size, rhs_matrix.batch_size);
+ CHECK_EQ(batch_size, output_matrix.batch_size);
se::DeviceMemory<Element> lhs_data(lhs_matrix.data);
se::DeviceMemory<Element> rhs_data(rhs_matrix.data);
se::DeviceMemory<Element> output_data(output_matrix.data);
@@ -61,13 +73,54 @@ bool DoGemm(MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix,
: se::blas::Transpose::kNoTranspose;
auto k = lhs_matrix.transpose ? lhs_matrix.num_rows : lhs_matrix.num_cols;
+ if (batch_size == 1) {
+ return stream
+ ->ThenBlasGemm(
+ lhs_transpose, rhs_transpose, output_matrix.num_rows,
+ output_matrix.num_cols, /*size of reduce dim=*/k, /*alpha=*/alpha,
+ lhs_data, /*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data,
+ /*leading dim of RHS=*/rhs_matrix.num_rows, /*beta=*/0.0,
+ &output_data, /*leading dim of output=*/output_matrix.num_rows)
+ .ok();
+ }
+
+ // Create the buffers for batched gemm.
+ // TODO(b/112111608): We could avoid all of this and also make it faster by
+ // using cuBLAS 8's strided batched gemm.
+ using DeviceMemoryType = se::DeviceMemory<Element>;
+ std::vector<DeviceMemoryType> a_device_memory;
+ std::vector<DeviceMemoryType> b_device_memory;
+ std::vector<DeviceMemoryType> c_device_memory;
+ std::vector<DeviceMemoryType*> a_ptrs;
+ std::vector<DeviceMemoryType*> b_ptrs;
+ std::vector<DeviceMemoryType*> c_ptrs;
+ a_device_memory.reserve(batch_size);
+ b_device_memory.reserve(batch_size);
+ c_device_memory.reserve(batch_size);
+ a_ptrs.reserve(batch_size);
+ b_ptrs.reserve(batch_size);
+ c_ptrs.reserve(batch_size);
+ auto* a_base_ptr = static_cast<Element*>(lhs_data.opaque());
+ auto* b_base_ptr = static_cast<Element*>(rhs_data.opaque());
+ auto* c_base_ptr = static_cast<Element*>(output_data.opaque());
+ for (int64 i = 0; i < batch_size; ++i) {
+ a_device_memory.push_back(AsDeviceMemory(
+ a_base_ptr + i * lhs_matrix.num_rows * lhs_matrix.num_cols));
+ b_device_memory.push_back(AsDeviceMemory(
+ b_base_ptr + i * rhs_matrix.num_rows * rhs_matrix.num_cols));
+ c_device_memory.push_back(AsDeviceMemory(
+ c_base_ptr + i * output_matrix.num_rows * output_matrix.num_cols));
+ a_ptrs.push_back(&a_device_memory.back());
+ b_ptrs.push_back(&b_device_memory.back());
+ c_ptrs.push_back(&c_device_memory.back());
+ }
return stream
- ->ThenBlasGemm(
+ ->ThenBlasGemmBatched(
lhs_transpose, rhs_transpose, output_matrix.num_rows,
output_matrix.num_cols, /*size of reduce dim=*/k, /*alpha=*/alpha,
- lhs_data, /*leading dim of LHS=*/lhs_matrix.num_rows, rhs_data,
- /*leading dim of RHS=*/rhs_matrix.num_rows, /*beta=*/0.0,
- &output_data, /*leading dim of output=*/output_matrix.num_rows)
+ a_ptrs, /*leading dim of LHS=*/lhs_matrix.num_rows, b_ptrs,
+ /*leading dim of RHS=*/rhs_matrix.num_rows, /*beta=*/0.0, c_ptrs,
+ /*leading dim of output=*/output_matrix.num_rows, batch_size)
.ok();
}
@@ -94,6 +147,10 @@ bool DoGemmWithAlgorithm(MatrixDescriptor lhs_matrix,
se::blas::ProfileResult* output_profile_result) {
DCHECK(!output_matrix.transpose);
+ CHECK_EQ(1, lhs_matrix.batch_size);
+ CHECK_EQ(1, rhs_matrix.batch_size);
+ CHECK_EQ(1, output_matrix.batch_size);
+
se::DeviceMemory<Element> lhs_data(lhs_matrix.data);
se::DeviceMemory<Element> rhs_data(rhs_matrix.data);
se::DeviceMemory<Element> output_data(output_matrix.data);
@@ -270,12 +327,37 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
se::DeviceMemoryBase output_data =
buffer_allocations.GetDeviceAddress(output_buffer_);
+ DotDimensionNumbers dim_nums = GetDimensionNumbers(*hlo_instruction());
+ CHECK_EQ(dim_nums.lhs_batch_dimensions_size(),
+ dim_nums.rhs_batch_dimensions_size());
+ CHECK_EQ(dim_nums.lhs_batch_dimensions_size() + 2,
+ ShapeUtil::Rank(output_shape_));
+
+ int64 row_dim = dim_nums.lhs_batch_dimensions_size();
+ int64 col_dim = dim_nums.lhs_batch_dimensions_size() + 1;
+ int64 batch_size = std::accumulate(output_shape_.dimensions().begin(),
+ output_shape_.dimensions().end() - 2, 1,
+ std::multiplies<int64>());
+
+ // Check that the batch dims don't cover the last two dims.
+ for (int64 batch_dim : dim_nums.lhs_batch_dimensions()) {
+ CHECK_NE(row_dim, batch_dim);
+ CHECK_NE(col_dim, batch_dim);
+ }
+
+ // Verify that the non-batch dimensions are minor-most. This is required for
+ // efficient access.
+ for (const auto* shape : {&lhs_shape_, &rhs_shape_, &output_shape_}) {
+ CHECK_LT(shape->layout().minor_to_major(row_dim), 2);
+ CHECK_LT(shape->layout().minor_to_major(col_dim), 2);
+ }
+
// BLAS gemm reduces rows of LHS and columns of RHS. The Dot operator between
// matrices reduces dimension 1 of LHS and dimension 0 of RHS regardless of
// their layout. Therefore, we should treat dimension 0 as row and dimension 1
// as column when mapping a matrix Dot to BLAS gemm.
- int64 output_num_rows = output_shape_.dimensions(0);
- int64 output_num_cols = output_shape_.dimensions(1);
+ int64 output_num_rows = output_shape_.dimensions(row_dim);
+ int64 output_num_cols = output_shape_.dimensions(col_dim);
// BLAS gemm expects the inputs and the output are in column-major order.
// Therefore, we need to convert dot between row-major matrices to that
@@ -298,31 +380,37 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
// the leading dimension of the LHS matrix of gemm is the number of rows in
// B^T and thus the number of columns in B.
- auto make_descriptor = [this](se::DeviceMemoryBase data, const Shape& shape,
- bool transpose) -> MatrixDescriptor {
- bool is_row_major = LayoutUtil::Minor(shape.layout(), 0) != 0;
- bool layout_mismatch = LayoutUtil::Minor(shape.layout(), 0) !=
- LayoutUtil::Minor(output_shape_.layout(), 0);
- return MatrixDescriptor(data, transpose ^ layout_mismatch,
- shape.dimensions(is_row_major),
- shape.dimensions(!is_row_major));
+ auto make_descriptor = [&](se::DeviceMemoryBase data, const Shape& shape,
+ bool transpose) -> MatrixDescriptor {
+ bool is_row_major = LayoutUtil::Minor(shape.layout(), row_dim) != 0;
+ bool layout_mismatch = LayoutUtil::Minor(shape.layout(), row_dim) !=
+ LayoutUtil::Minor(output_shape_.layout(), row_dim);
+ return MatrixDescriptor(
+ data, transpose ^ layout_mismatch,
+ shape.dimensions(row_dim + static_cast<int64>(is_row_major)),
+ shape.dimensions(row_dim + static_cast<int64>(!is_row_major)),
+ batch_size);
};
- DotDimensionNumbers dim_nums = GetDimensionNumbers(*hlo_instruction());
-
const MatrixDescriptor lhs_descriptor = make_descriptor(
- lhs_data, lhs_shape_, dim_nums.lhs_contracting_dimensions(0) == 0);
+ lhs_data, lhs_shape_, dim_nums.lhs_contracting_dimensions(0) == row_dim);
const MatrixDescriptor rhs_descriptor = make_descriptor(
- rhs_data, rhs_shape_, dim_nums.rhs_contracting_dimensions(0) == 1);
+ rhs_data, rhs_shape_, dim_nums.rhs_contracting_dimensions(0) == col_dim);
// Dispatches to a regular cublas gemm, a gemm-with-algorithm, or attempts to
// autotune this gemm to figure out the best algorithm.
- auto launch = [this](MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix,
- MatrixDescriptor output_matrix, se::Stream* stream) {
+ auto launch = [&](MatrixDescriptor lhs_matrix, MatrixDescriptor rhs_matrix,
+ MatrixDescriptor output_matrix, se::Stream* stream) {
PrimitiveType element_type = output_shape_.element_type();
se::blas::ComputationType computation_type =
GetBlasComputationType(element_type);
+ // TODO(b/112111608): Implement auto tune for batched gemm.
+ if (batch_size != 1) {
+ return GetGemmFn(element_type)(lhs_matrix, rhs_matrix, output_matrix,
+ alpha_, stream);
+ }
+
auto thunk_name = [&] {
return hlo_instruction() != nullptr ? hlo_instruction()->ToString()
: "<null>";
@@ -368,16 +456,16 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
auto op_profiler = profiler->MakeScopedInstructionProfiler(hlo_instruction());
bool launch_ok;
- if (LayoutUtil::Minor(output_shape_.layout(), 0) == 0) {
- launch_ok = launch(
- lhs_descriptor, rhs_descriptor,
- MatrixDescriptor(output_data, false, output_num_rows, output_num_cols),
- stream);
+ if (LayoutUtil::Minor(output_shape_.layout(), row_dim) == 0) {
+ launch_ok = launch(lhs_descriptor, rhs_descriptor,
+ MatrixDescriptor(output_data, false, output_num_rows,
+ output_num_cols, batch_size),
+ stream);
} else {
- launch_ok = launch(
- rhs_descriptor, lhs_descriptor,
- MatrixDescriptor(output_data, false, output_num_cols, output_num_rows),
- stream);
+ launch_ok = launch(rhs_descriptor, lhs_descriptor,
+ MatrixDescriptor(output_data, false, output_num_cols,
+ output_num_rows, batch_size),
+ stream);
}
if (!launch_ok) {
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
index 6ac5dfbcd5..d033faee8d 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
@@ -176,6 +176,38 @@ Status GpuLayoutAssignment::AddBackendConstraints(
TF_RETURN_IF_ERROR(
AddBackendConstraintsToDnnConvCustomCall(instruction, constraints));
}
+
+ // For batched dot we require the default layout.
+ // TODO(b/112111608): This is overly conservative, the only real restriction
+ // is that batch dimensions must be major.
+ if (instruction->opcode() == HloOpcode::kDot &&
+ ImplementedAsGemm(*instruction) &&
+ instruction->dot_dimension_numbers().lhs_batch_dimensions_size() > 0) {
+ // Verify that the batch dims come before the row and col dims.
+ const DotDimensionNumbers& dim_nums =
+ instruction->dot_dimension_numbers();
+ CHECK_EQ(dim_nums.lhs_batch_dimensions_size(),
+ dim_nums.rhs_batch_dimensions_size());
+ CHECK_EQ(dim_nums.lhs_batch_dimensions_size() + 2,
+ ShapeUtil::Rank(instruction->shape()));
+ for (int64 batch_dim : dim_nums.lhs_batch_dimensions()) {
+ CHECK_LT(batch_dim, ShapeUtil::Rank(instruction->shape()) - 2);
+ }
+
+ // Set both inputs and the output to default layout.
+ Shape op0_shape = instruction->operand(0)->shape();
+ LayoutUtil::SetToDefaultLayout(&op0_shape);
+ Shape op1_shape = instruction->operand(1)->shape();
+ LayoutUtil::SetToDefaultLayout(&op1_shape);
+ Shape output_shape = instruction->shape();
+ LayoutUtil::SetToDefaultLayout(&output_shape);
+ TF_RETURN_IF_ERROR(
+ constraints->SetOperandLayout(op0_shape, instruction, 0));
+ TF_RETURN_IF_ERROR(
+ constraints->SetOperandLayout(op1_shape, instruction, 1));
+ TF_RETURN_IF_ERROR(
+ constraints->SetInstructionLayout(output_shape, instruction));
+ }
}
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
index 95f78ae293..286547ebae 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment_test.cc
@@ -20,8 +20,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_matchers.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
@@ -31,6 +33,8 @@ namespace xla {
namespace gpu {
namespace {
+namespace op = xla::testing::opcode_matchers;
+
using LayoutAssignmentTest = HloTestBase;
TEST_F(LayoutAssignmentTest, Elementwise) {
@@ -327,6 +331,33 @@ TEST_F(LayoutAssignmentTest, BatchNormGrad) {
}
}
+TEST_F(LayoutAssignmentTest, DotLayout) {
+ const char* hlo_text = R"(
+ HloModule DotLayout
+ ENTRY dot {
+ p0 = f32[8,8,256,64]{3,1,2,0} parameter(0)
+ p1 = f32[8,8,256,64]{3,1,2,0} parameter(1)
+ ROOT dot.1330.10585 = f32[8,8,256,256]{3,2,1,0} dot(p0, p1),
+ lhs_batch_dims={0,1}, lhs_contracting_dims={3},
+ rhs_batch_dims={0,1}, rhs_contracting_dims={3}
+ })";
+
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(hlo_text));
+
+ ComputationLayout computation_layout(
+ module->entry_computation()->ComputeProgramShape());
+ GpuLayoutAssignment layout_assignment(&computation_layout,
+ backend().default_stream_executor());
+ EXPECT_TRUE(layout_assignment.Run(module.get()).ValueOrDie());
+
+ Shape expected_shape =
+ ShapeUtil::MakeShapeWithLayout(F32, {8, 8, 256, 64}, {3, 2, 1, 0});
+ EXPECT_THAT(module->entry_computation()->root_instruction(),
+ op::Dot(op::ShapeWithLayout(expected_shape),
+ op::ShapeWithLayout(expected_shape)));
+}
+
} // namespace
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
index 6352b330d1..d74c1a0243 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
@@ -38,14 +38,16 @@ namespace gpu {
namespace {
// Return whether the given shape is a matrix with no padding.
-bool IsRank2WithNoPadding(const Shape& shape) {
- return ShapeUtil::Rank(shape) == 2 && !LayoutUtil::IsPadded(shape);
+bool IsRank2WithNoPadding(const Shape& shape, int64 batch_dimensions_size) {
+ return ShapeUtil::Rank(shape) == batch_dimensions_size + 2 &&
+ !LayoutUtil::IsPadded(shape);
}
// In a gemm operation where output = lhs * rhs, check whether the given shapes
// are valid for the operation.
bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape,
- const Shape& output_shape) {
+ const Shape& output_shape,
+ int64 batch_dimensions_size) {
// The inputs and the output must
// 1) be matrices with no padding and a non-zero number of elements,
// 2) have an allowed element type.
@@ -53,9 +55,10 @@ bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape,
bool type_is_allowed =
(output_primitive_type == F16 || output_primitive_type == F32 ||
output_primitive_type == F64);
- return type_is_allowed && IsRank2WithNoPadding(lhs_shape) &&
- IsRank2WithNoPadding(rhs_shape) &&
- IsRank2WithNoPadding(output_shape) &&
+ return type_is_allowed &&
+ IsRank2WithNoPadding(lhs_shape, batch_dimensions_size) &&
+ IsRank2WithNoPadding(rhs_shape, batch_dimensions_size) &&
+ IsRank2WithNoPadding(output_shape, batch_dimensions_size) &&
!ShapeUtil::IsZeroElementArray(lhs_shape) &&
!ShapeUtil::IsZeroElementArray(rhs_shape);
}
@@ -64,14 +67,15 @@ bool DotImplementedAsGemm(const HloInstruction& dot) {
CHECK_EQ(dot.opcode(), HloOpcode::kDot);
const Shape& lhs_shape = dot.operand(0)->shape();
const Shape& rhs_shape = dot.operand(1)->shape();
+ const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers();
// If gemm can accept the operand shapes, use it rather than a custom
// kernel.
- if (AreValidGemmShapes(lhs_shape, rhs_shape, dot.shape())) {
+ if (AreValidGemmShapes(lhs_shape, rhs_shape, dot.shape(),
+ dim_numbers.lhs_batch_dimensions_size())) {
// The size of the reduction dimension should match. The shape inference
// guarantees this invariant, so the check here is for programming
// errors.
- const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers();
CHECK_EQ(lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)),
rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0)));
return true;
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
index 290e2f73dc..541cacf697 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter.cc
@@ -454,6 +454,9 @@ Status IrEmitter::HandleDot(HloInstruction* dot) {
const Shape& lhs_shape = lhs_instruction->shape();
const Shape& rhs_shape = rhs_instruction->shape();
+ const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
+ CHECK_EQ(dnums.lhs_batch_dimensions_size(),
+ dnums.rhs_batch_dimensions_size());
// TODO(b/110211620): Convert to use i32 index_type when it is possible.
llvm::Type* index_type = b_.getInt64Ty();
@@ -489,9 +492,15 @@ Status IrEmitter::HandleDot(HloInstruction* dot) {
const int64 lhs_reduction_dimension =
ShapeUtil::GetDimensionNumber(lhs_shape, -1);
const int64 rhs_reduction_dimension =
- ShapeUtil::Rank(rhs_shape) >= 2
+ ShapeUtil::Rank(rhs_shape) >= 2 + dnums.lhs_batch_dimensions_size()
? ShapeUtil::GetDimensionNumber(rhs_shape, -2)
- : 0;
+ : dnums.lhs_batch_dimensions_size();
+
+ // Check that the batch dims don't cover the last two dims.
+ for (int64 batch_dim : dnums.lhs_batch_dimensions()) {
+ CHECK_NE(lhs_reduction_dimension, batch_dim);
+ CHECK_NE(rhs_reduction_dimension, batch_dim);
+ }
// Verify the reduction dimension in the two operands are the same size.
TF_RET_CHECK(lhs_shape.dimensions(lhs_reduction_dimension) ==
@@ -506,6 +515,13 @@ Status IrEmitter::HandleDot(HloInstruction* dot) {
llvm_ir::IrArray::Index rhs_index = loop_nest.EmitOperandArrayLoopNest(
rhs_array, /*dimension_to_skip=*/rhs_reduction_dimension, "rhs");
+ // We don't have to iterate over the batch dimensions in both arrays, simplify
+ // the loop nest of the rhs.
+ for (int i = 0; i != dnums.lhs_batch_dimensions_size(); ++i) {
+ DCHECK(c_linear_search(dnums.lhs_batch_dimensions(), i));
+ rhs_index[i] = lhs_index[i];
+ }
+
// Create the reduction loop which does the sum of products reduction.
std::unique_ptr<llvm_ir::ForLoop> reduction_loop = loop_nest.AddLoop(
/*start_index=*/0,
@@ -568,7 +584,9 @@ Status IrEmitter::HandleDot(HloInstruction* dot) {
target_index.push_back(lhs_index[dimension]);
}
}
- for (size_t dimension = 0; dimension < rhs_index.size(); ++dimension) {
+ // Skip over the batch dimensions to not have them in the index twice.
+ for (size_t dimension = dnums.lhs_batch_dimensions_size();
+ dimension < rhs_index.size(); ++dimension) {
if (dimension != rhs_reduction_dimension) {
target_index.push_back(rhs_index[dimension]);
}
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index f61a977ad4..d5ecae88ed 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -345,11 +345,6 @@ Status IrEmitterUnnested::DefaultAction(HloInstruction* hlo) {
}
Status IrEmitterUnnested::HandleDot(HloInstruction* dot) {
- const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
- if (dnums.lhs_batch_dimensions_size() > 0 ||
- dnums.rhs_batch_dimensions_size() > 0) {
- return Unimplemented("Dot with batch dimensions not implemented.");
- }
if (ImplementedAsGemm(*dot)) {
thunk_sequence_->emplace_back(BuildGemmThunk(dot));
return Status::OK();
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
index 7a683ede54..8fa0439006 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
@@ -34,7 +34,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/conditional_simplifier.h"
-#include "tensorflow/compiler/xla/service/dot_decomposer.h"
#include "tensorflow/compiler/xla/service/flatten_call_graph.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h"
#include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h"
@@ -148,7 +147,6 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
// support BF16 operations without directly implementing a BF16 lowering for
// most ops.
pipeline.AddPass<HloElementTypeConverter>(BF16, F32);
- pipeline.AddPass<DotDecomposer>();
{
auto& pass =
diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc
index cfd36abf47..f11d274aab 100644
--- a/tensorflow/compiler/xla/tests/dot_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc
@@ -612,7 +612,7 @@ XLA_TYPED_TEST(DotOperationTestForBatchMatMul, Types) {
{x_data.get(), y_data.get()}, this->error_spec_);
}
-XLA_TYPED_TEST(DotOperationTest_F16F32F64, GeneralMatMul) {
+XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMul) {
using T = TypeParam;
XlaBuilder builder(this->TestName());
@@ -648,6 +648,48 @@ XLA_TYPED_TEST(DotOperationTest_F16F32F64, GeneralMatMul) {
{x_data.get(), y_data.get()}, this->error_spec_);
}
+XLA_TYPED_TEST(DotOperationTest_F16F32F64CF64, GeneralMatMulMultipleBatch) {
+ using T = TypeParam;
+
+ XlaBuilder builder(this->TestName());
+ auto x = Parameter(&builder, 0, ShapeUtil::MakeShapeWithType<T>({2, 2, 2, 2}),
+ "x");
+ auto y = Parameter(&builder, 1, ShapeUtil::MakeShapeWithType<T>({2, 2, 2, 2}),
+ "y");
+
+ DotDimensionNumbers dnums;
+ dnums.add_lhs_contracting_dimensions(3);
+ dnums.add_rhs_contracting_dimensions(2);
+ dnums.add_lhs_batch_dimensions(0);
+ dnums.add_lhs_batch_dimensions(1);
+ dnums.add_rhs_batch_dimensions(0);
+ dnums.add_rhs_batch_dimensions(1);
+
+ DotGeneral(x, y, dnums);
+
+ auto x_data =
+ this->client_
+ ->TransferToServer(*LiteralUtil::CreateR4FromArray4D<T>(
+ {{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}},
+ {{{9.0f, 10.0f}, {11.0f, 12.0f}},
+ {{13.0f, 14.0f}, {15.0f, 16.0f}}}}))
+ .ConsumeValueOrDie();
+
+ auto y_data =
+ this->client_
+ ->TransferToServer(*LiteralUtil::CreateR4FromArray4D<T>(
+ {{{{1.0f, 0.0f}, {0.0f, 1.0f}}, {{1.0f, 0.0f}, {0.0f, 1.0f}}},
+ {{{0.0f, 1.0f}, {1.0f, 0.0f}}, {{0.0f, 1.0f}, {1.0f, 0.0f}}}}))
+ .ConsumeValueOrDie();
+
+ this->template ComputeAndCompareR4<T>(
+ &builder,
+ /*expected=*/
+ {{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}},
+ {{{10.0f, 9.0f}, {12.0f, 11.0f}}, {{14.0f, 13.0f}, {16.0f, 15.0f}}}},
+ {x_data.get(), y_data.get()}, this->error_spec_);
+}
+
XLA_TYPED_TEST(DotOperationTest_F16F32F64, TransposeFolding) {
using T = TypeParam;
for (bool transpose_lhs : {false, true}) {