about summary refs log tree commit diff homepage
diff options
context:
space:
mode:
-rw-r--r-- tensorflow/compiler/xla/service/gpu/BUILD 2
-rw-r--r-- tensorflow/compiler/xla/service/gpu/gemm_thunk.cc 22
-rw-r--r-- tensorflow/compiler/xla/service/gpu/instruction_fusion.cc 84
-rw-r--r-- tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc 46
-rw-r--r-- tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc 36
-rw-r--r-- tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc 40
6 files changed, 180 insertions, 50 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 7cb7f55073..7ee039b3eb 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -388,8 +388,10 @@ cc_library(
deps = [
":ir_emission_utils",
"//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:instruction_fusion",
+ "//tensorflow/compiler/xla/service:pattern_matcher",
],
)
diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc
index 2ebb40a44e..79fca43d02 100644
--- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc
@@ -215,6 +215,25 @@ se::blas::ComputationType GetBlasComputationType(PrimitiveType type) {
}
}
+DotDimensionNumbers GetDimensionNumbers(const HloInstruction& hlo_instruction) {  // Returns the dot dimension numbers of a kDot, or of the dot inside an output fusion.
+ if (hlo_instruction.opcode() == HloOpcode::kDot) {  // Plain dot: read the numbers directly.
+ return hlo_instruction.dot_dimension_numbers();
+ }
+ CHECK_EQ(hlo_instruction.opcode(), HloOpcode::kFusion);  // Otherwise only an output fusion rooted at a multiply is accepted.
+ CHECK_EQ(hlo_instruction.fusion_kind(), HloInstruction::FusionKind::kOutput);
+ CHECK_EQ(hlo_instruction.fused_expression_root()->opcode(),
+ HloOpcode::kMultiply);
+ // The dot may be either operand of the root multiply: try operand 0, then fall back to operand 1.
+ const HloInstruction* dot =
+ hlo_instruction.fused_expression_root()->operand(0);
+ if (dot->opcode() != HloOpcode::kDot) {
+ dot = hlo_instruction.fused_expression_root()->operand(1);
+ }
+ CHECK_EQ(dot->opcode(), HloOpcode::kDot);  // CHECK-fails if neither operand of the multiply is a dot.
+
+ return dot->dot_dimension_numbers();
+}
+
} // namespace
GemmThunk::GemmThunk(const BufferAllocation::Slice& lhs_buffer,
@@ -281,8 +300,7 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
shape.dimensions(!is_row_major));
};
- const DotDimensionNumbers& dim_nums =
- hlo_instruction()->dot_dimension_numbers();
+ DotDimensionNumbers dim_nums = GetDimensionNumbers(*hlo_instruction());
const MatrixDescriptor lhs_descriptor = make_descriptor(
lhs_data, lhs_shape_, dim_nums.lhs_contracting_dimensions(0) == 0);
diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
index c5eb721185..5d5bef6b57 100644
--- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
@@ -17,7 +17,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
namespace gpu {
@@ -46,6 +48,15 @@ bool IsFusile(const HloInstruction& hlo) {
hlo.opcode() == HloOpcode::kTranspose;
}
+bool IsIEEEFloatingPointScalarConstant(const HloInstruction* constant) {  // True iff 'constant' is a scalar kConstant with element type F16, F32 or F64.
+ if (constant->opcode() != HloOpcode::kConstant ||  // Reject anything that is not a constant or not scalar-shaped.
+ !ShapeUtil::IsScalar(constant->shape())) {
+ return false;
+ }
+ auto type = constant->shape().element_type();
+ return type == F16 || type == F32 || type == F64;  // Every other element type is rejected.
+}
+
} // namespace
/*static*/ bool GpuInstructionFusion::IsExpensive(
@@ -66,34 +77,71 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
HloInstruction* producer = consumer->mutable_operand(operand_index);
// Check if we can use output fusion for (A @ B) * alpha
- if (producer->opcode() == HloOpcode::kDot) {
- if (consumer->opcode() == HloOpcode::kMultiply) {
- CHECK_EQ(consumer->operand_count(), 2);
- int64 other_operand_index = 1 - operand_index;
- const HloInstruction* alpha = consumer->operand(other_operand_index);
- if (alpha->opcode() == HloOpcode::kConstant &&
- ShapeUtil::IsScalar(alpha->shape())) {
+ if (consumer->operand_count() == 2 &&
+ (producer->opcode() == HloOpcode::kDot ||
+ (producer->opcode() == HloOpcode::kFusion &&
+ producer->fused_expression_root()->opcode() == HloOpcode::kDot))) {
+ int64 other_operand_index = 1 - operand_index;
+ const HloInstruction* alpha = consumer->operand(other_operand_index);
+ HloInstruction* op1 = nullptr;
+ HloInstruction* op2 = nullptr;
+ if (consumer->opcode() == HloOpcode::kFusion &&
+ consumer->fusion_kind() == HloInstruction::FusionKind::kLoop &&
+ Match(consumer->fused_expression_root(),
+ match::Op()
+ .WithOpcode(HloOpcode::kMultiply)
+ .WithOperand(0, match::Op(&op1))
+ .WithOperand(1, match::Op(&op2)))) {
+ CHECK(op1 != nullptr && op2 != nullptr);
+ // If 'consumer' is a fusion node, it should consist of a broadcast of a
+ // scalar constant fused into a multiply, but nothing more. So one operand
+ // should be a parameter, and the other should be a broadcast.
+ if (op1->opcode() != HloOpcode::kParameter) {
+ std::swap(op1, op2);
+ }
+ if (op1->opcode() != HloOpcode::kParameter ||
+ op2->opcode() != HloOpcode::kBroadcast) {
+ return false;
+ }
+ if (IsIEEEFloatingPointScalarConstant(alpha)) {
+ return true;
+ }
+ } else if (consumer->opcode() == HloOpcode::kMultiply) {
+ // Fuse if 'alpha' is a broadcast of a scalar constant.
+ if (alpha->opcode() == HloOpcode::kBroadcast &&
+ alpha->dimensions().empty() &&
+ IsIEEEFloatingPointScalarConstant(alpha->operand(0))) {
return true;
}
}
}
- // Only allow to fuse transpose into an output fusion.
+ // Only allow fusing transpose or broadcast into an output fusion that is
+ // implemented as a Gemm call.
if (consumer->opcode() == HloOpcode::kFusion &&
- consumer->fusion_kind() == HloInstruction::FusionKind::kOutput) {
- if (producer->opcode() != HloOpcode::kTranspose) {
- return false;
- }
- // Check that the transpose is the operand of a dot.
+ consumer->fusion_kind() == HloInstruction::FusionKind::kOutput &&
+ ImplementedAsGemm(*consumer)) {
auto producer_operand_index = consumer->operand_index(producer);
auto fused_parameter = consumer->fused_parameter(producer_operand_index);
const std::vector<HloInstruction*>& fused_parameter_users =
fused_parameter->users();
- return (fused_parameter_users.size() == 1 &&
- fused_parameter_users[0]->opcode() == HloOpcode::kDot);
+ if (fused_parameter_users.size() != 1) {
+ return false;
+ }
+ if (producer->opcode() == HloOpcode::kTranspose) {
+ // Check that the transpose is an operand of a dot.
+ return fused_parameter_users[0]->opcode() == HloOpcode::kDot;
+ }
+ if (producer->opcode() == HloOpcode::kBroadcast) {
+ // Check that the broadcast is a broadcast of a scalar constant into a
+ // multiply.
+ return producer->dimensions().empty() &&
+ IsIEEEFloatingPointScalarConstant(producer->operand(0)) &&
+ fused_parameter_users[0]->opcode() == HloOpcode::kMultiply;
+ }
}
- // Output fusion is not currently supported on GPUs.
+ // Other output fusions are not currently supported on GPUs.
if (producer->opcode() == HloOpcode::kFusion) {
return false;
}
@@ -134,7 +182,9 @@ HloInstruction::FusionKind GpuInstructionFusion::ChooseKind(
if (IsReductionToVector(*consumer)) {
return HloInstruction::FusionKind::kInput;
}
- if (producer->opcode() == HloOpcode::kDot) {
+ if (producer->opcode() == HloOpcode::kDot ||
+ (producer->opcode() == HloOpcode::kFusion &&
+ producer->fused_expression_root()->opcode() == HloOpcode::kDot)) {
return HloInstruction::FusionKind::kOutput;
}
if (HloOpcode::kFusion == consumer->opcode()) {
diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
index 6c9a805ad6..760e0e90f5 100644
--- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
@@ -108,8 +108,8 @@ TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfDotUnfused) {
HloComputation::Builder builder(TestName());
auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(S32, {1, 1}), "0"));
- auto dot1 = builder.AddInstruction(HloInstruction::CreateBinary(
- ShapeUtil::MakeShape(S32, {1, 1}), HloOpcode::kDot, param0, param0));
+ auto dot1 = builder.AddInstruction(HloInstruction::CreateCanonicalDot(
+ ShapeUtil::MakeShape(S32, {1, 1}), param0, param0));
auto reshape2 = builder.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(S32, {1, 1, 1}), dot1));
@@ -125,8 +125,8 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfDotUnfused) {
HloComputation::Builder builder(TestName());
auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(S32, {1, 1}), "0"));
- auto dot1 = builder.AddInstruction(HloInstruction::CreateBinary(
- ShapeUtil::MakeShape(S32, {1, 1}), HloOpcode::kDot, param0, param0));
+ auto dot1 = builder.AddInstruction(HloInstruction::CreateCanonicalDot(
+ ShapeUtil::MakeShape(S32, {1, 1}), param0, param0));
auto transpose2 = builder.AddInstruction(HloInstruction::CreateTranspose(
ShapeUtil::MakeShape(S32, {1, 1}), dot1, {0, 1}));
@@ -232,12 +232,13 @@ TEST_F(InstructionFusionTest, DotOutputFusion) {
auto module = tools::Parse(R"(
HloModule test_module
ENTRY OutputFusion {
- constant = f32[] constant(3)
+ alpha = f32[] constant(3)
+ broadcast = f32[4,4]{1,0} broadcast(alpha), dimensions={}
p0 = f32[4,3]{1,0} parameter(0)
p1 = f32[4,3]{1,0} parameter(1)
transpose = f32[3,4]{1,0} transpose(p1), dimensions={1, 0}
- dot = f32[4,4]{1,0} dot(p0, transpose)
- ROOT mul = f32[4,4] multiply(constant, dot)
+ dot = f32[4,4]{1,0} dot(p0, transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ ROOT mul = f32[4,4] multiply(dot, broadcast)
})")
.ValueOrDie();
@@ -247,10 +248,11 @@ TEST_F(InstructionFusionTest, DotOutputFusion) {
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, op::Fusion());
+ EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kOutput);
EXPECT_THAT(
root->fused_expression_root(),
- op::Multiply(op::Parameter(),
- op::Dot(op::Parameter(), op::Transpose(op::Parameter()))));
+ op::Multiply(op::Dot(op::Parameter(), op::Transpose(op::Parameter())),
+ op::Broadcast(op::Parameter())));
}
// Compute sum(1/p0), where p0 has type f32, twice. Check that the division is
@@ -309,5 +311,31 @@ TEST_F(InstructionFusionTest, IntegerDivIsNotCheap) {
.ValueOrDie());
}
+TEST_F(InstructionFusionTest, DotOutputFusionImpossible) {  // Here the scaled value is multiply(dot, dot) rather than the dot itself.
+ auto module = tools::Parse(R"(
+ HloModule test_module
+ ENTRY NoOutputFusion {
+ alpha = f32[] constant(3)
+ broadcast = f32[4,4]{1,0} broadcast(alpha), dimensions={}
+ p0 = f32[4,3]{1,0} parameter(0)
+ p1 = f32[3,4]{1,0} parameter(1)
+ dot = f32[4,4]{1,0} dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ d = f32[4,4]{1,0} multiply(dot, dot)
+ ROOT mul = f32[4,4] multiply(d, broadcast)
+ })")
+ .ValueOrDie();
+
+ EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true)  // The pass is still expected to report a change.
+ .Run(module.get())
+ .ValueOrDie());
+
+ HloInstruction* root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, op::Fusion());
+ EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kLoop);  // Expect a loop fusion, not an output fusion.
+ EXPECT_THAT(root->fused_expression_root(),  // Inside the fusion the inner multiply sees only parameters, i.e. the dot was not fused in.
+ op::Multiply(op::Multiply(op::Parameter(), op::Parameter()),
+ op::Broadcast(op::Parameter())));
+}
+
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
index 96199035b9..22e7150995 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
@@ -59,6 +59,25 @@ bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape,
!ShapeUtil::HasZeroElements(lhs_shape) &&
!ShapeUtil::HasZeroElements(rhs_shape);
}
+
+bool DotImplementedAsGemm(const HloInstruction& dot) {  // Returns true if the operand and result shapes of 'dot' pass AreValidGemmShapes.
+ CHECK_EQ(dot.opcode(), HloOpcode::kDot);  // Callers must pass a kDot instruction.
+ const Shape& lhs_shape = dot.operand(0)->shape();
+ const Shape& rhs_shape = dot.operand(1)->shape();
+
+ // If gemm can accept the operand shapes, use it rather than a custom
+ // kernel.
+ if (AreValidGemmShapes(lhs_shape, rhs_shape, dot.shape())) {
+ // The size of the reduction dimension should match. The shape inference
+ // guarantees this invariant, so the check here is for programming
+ // errors.
+ const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers();
+ CHECK_EQ(lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)),  // Only the first contracting dimension on each side is compared.
+ rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0)));
+ return true;
+ }
+ return false;  // Shapes were not accepted by AreValidGemmShapes.
+}
} // namespace
bool ImplementedAsGemm(const HloInstruction& hlo) {
@@ -69,20 +88,7 @@ bool ImplementedAsGemm(const HloInstruction& hlo) {
// For certain types of Dot, we can call pre-canned BLAS gemm.
if (hlo.opcode() == HloOpcode::kDot) {
- const Shape& lhs_shape = hlo.operand(0)->shape();
- const Shape& rhs_shape = hlo.operand(1)->shape();
-
- // If gemm can accept the operand shapes, use it rather than a custom
- // kernel.
- if (AreValidGemmShapes(lhs_shape, rhs_shape, hlo.shape())) {
- // The size of the reduction dimension should match. The shape inference
- // guarantees this invariant, so the check here is for programming
- // errors.
- const DotDimensionNumbers& dim_numbers = hlo.dot_dimension_numbers();
- CHECK_EQ(lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)),
- rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0)));
- return true;
- }
+ return DotImplementedAsGemm(hlo);
}
if (hlo.opcode() == HloOpcode::kFusion &&
@@ -94,7 +100,7 @@ bool ImplementedAsGemm(const HloInstruction& hlo) {
dot = hlo.fused_expression_root()->operand(1);
}
if (dot->opcode() == HloOpcode::kDot) {
- return ImplementedAsGemm(*dot);
+ return DotImplementedAsGemm(*dot);
}
}
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 83d90296df..0d7ba4cf9a 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -2194,6 +2194,21 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildInfeedThunk(
/*destination_buffer=*/GetAllocationSlice(*inst), inst);
}
+namespace {
+double GetScalarConstantAsDouble(const Literal& literal) {  // Reads the value at index {} of 'literal' and returns it as a double.
+ switch (literal.shape().element_type()) {  // Dispatch on the literal's element type; only F16, F32 and F64 are handled.
+ case F16:
+ return static_cast<double>(literal.Get<Eigen::half>({}));  // Eigen::half is converted with an explicit cast.
+ case F32:
+ return literal.Get<float>({});  // float widens to double implicitly on return.
+ case F64:
+ return literal.Get<double>({});
+ default:
+ LOG(FATAL) << "Unsupported type.";  // Any other element type is logged at FATAL severity.
+ }
+}
+} // namespace
+
std::unique_ptr<Thunk> IrEmitterUnnested::BuildGemmThunk(
const HloInstruction* inst) {
if (inst->opcode() == HloOpcode::kDot) {
@@ -2218,6 +2233,17 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildGemmThunk(
if (dot->opcode() != HloOpcode::kDot) {
std::swap(dot, alpha);
}
+ if (alpha->opcode() == HloOpcode::kBroadcast) {
+ alpha = alpha->operand(0);
+ }
+ alpha = inst->operand(alpha->parameter_number());
+ // TODO(b/74185543): Remove the following if block once we support fusion
+ // with a non-constant as well. Then we will just always use the constant
+ // on the device.
+ if (alpha->opcode() == HloOpcode::kCopy) {
+ alpha = alpha->operand(0);
+ }
+
DCHECK(dot->opcode() == HloOpcode::kDot);
const HloInstruction* lhs_parameter = StripTranspose(*dot->operand(0));
const HloInstruction* rhs_parameter = StripTranspose(*dot->operand(1));
@@ -2229,13 +2255,13 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildGemmThunk(
inst->operand(rhs_parameter->parameter_number());
return MakeUnique<GemmThunk>(
- GetAllocationSlice(*lhs), // The buffer assigned to LHS.
- GetAllocationSlice(*rhs), // The buffer assigned to RHS.
- GetAllocationSlice(*mul), // The output buffer.
- lhs->shape(), // The shape of LHS.
- rhs->shape(), // The shape of RHS.
- inst->shape(), // The shape of the output.
- alpha->literal().Get<double>({0}), // alpha.
+ GetAllocationSlice(*lhs), // The buffer assigned to LHS.
+ GetAllocationSlice(*rhs), // The buffer assigned to RHS.
+ GetAllocationSlice(*inst), // The output buffer.
+ lhs->shape(), // The shape of LHS.
+ rhs->shape(), // The shape of RHS.
+ inst->shape(), // The shape of the output.
+ GetScalarConstantAsDouble(alpha->literal()), // alpha.
inst);
}