aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-05-15 01:22:13 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-15 01:24:49 -0700
commitaf06f858edec0499d13561e9c5a9867a28833c5d (patch)
treec6eb511f6ffcf0f11f8e525a776ecce4647e1bff
parent3d84d0691c321f4c8539dbe2c61ab66cda4d18b4 (diff)
Reland improve fusion logic of (a dot b) * alpha
The previous fusion approach didn't work because a multiplication by a scalar value will be changed into an explicit broadcast. Another issue that is fixed in this CL is retrieving the constant value from the literal. This depends on the PrimitiveType, before we always assumed it to be double. Also when checking ImplementedAsGemm() we should not call it recursively, but instead just the check related to kDot. Finally add an execution test and adjust the fusion logic test. The fix for the issue that caused the revert is that we check earlier that consumer->operand_count() is 2. Also, we fix the call to Get() to pass {} instead of {0}. And we handle an output fusion node in GemmThunk to extract the dimension numbers from the dot operation. PiperOrigin-RevId: 196631031
-rw-r--r--tensorflow/compiler/xla/service/gpu/BUILD2
-rw-r--r--tensorflow/compiler/xla/service/gpu/gemm_thunk.cc22
-rw-r--r--tensorflow/compiler/xla/service/gpu/instruction_fusion.cc84
-rw-r--r--tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc46
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc36
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc40
6 files changed, 180 insertions, 50 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 7cb7f55073..7ee039b3eb 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -388,8 +388,10 @@ cc_library(
deps = [
":ir_emission_utils",
"//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:instruction_fusion",
+ "//tensorflow/compiler/xla/service:pattern_matcher",
],
)
diff --git a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc
index 2ebb40a44e..79fca43d02 100644
--- a/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc
+++ b/tensorflow/compiler/xla/service/gpu/gemm_thunk.cc
@@ -215,6 +215,25 @@ se::blas::ComputationType GetBlasComputationType(PrimitiveType type) {
}
}
+DotDimensionNumbers GetDimensionNumbers(const HloInstruction& hlo_instruction) {
+ if (hlo_instruction.opcode() == HloOpcode::kDot) {
+ return hlo_instruction.dot_dimension_numbers();
+ }
+ CHECK_EQ(hlo_instruction.opcode(), HloOpcode::kFusion);
+ CHECK_EQ(hlo_instruction.fusion_kind(), HloInstruction::FusionKind::kOutput);
+ CHECK_EQ(hlo_instruction.fused_expression_root()->opcode(),
+ HloOpcode::kMultiply);
+ // Try to find the dot inside the output fusion node.
+ const HloInstruction* dot =
+ hlo_instruction.fused_expression_root()->operand(0);
+ if (dot->opcode() != HloOpcode::kDot) {
+ dot = hlo_instruction.fused_expression_root()->operand(1);
+ }
+ CHECK_EQ(dot->opcode(), HloOpcode::kDot);
+
+ return dot->dot_dimension_numbers();
+}
+
} // namespace
GemmThunk::GemmThunk(const BufferAllocation::Slice& lhs_buffer,
@@ -281,8 +300,7 @@ Status GemmThunk::ExecuteOnStream(const BufferAllocations& buffer_allocations,
shape.dimensions(!is_row_major));
};
- const DotDimensionNumbers& dim_nums =
- hlo_instruction()->dot_dimension_numbers();
+ DotDimensionNumbers dim_nums = GetDimensionNumbers(*hlo_instruction());
const MatrixDescriptor lhs_descriptor = make_descriptor(
lhs_data, lhs_shape_, dim_nums.lhs_contracting_dimensions(0) == 0);
diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
index c5eb721185..5d5bef6b57 100644
--- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
@@ -17,7 +17,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/service/pattern_matcher.h"
#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
namespace xla {
namespace gpu {
@@ -46,6 +48,15 @@ bool IsFusile(const HloInstruction& hlo) {
hlo.opcode() == HloOpcode::kTranspose;
}
+bool IsIEEEFloatingPointScalarConstant(const HloInstruction* constant) {
+ if (constant->opcode() != HloOpcode::kConstant ||
+ !ShapeUtil::IsScalar(constant->shape())) {
+ return false;
+ }
+ auto type = constant->shape().element_type();
+ return type == F16 || type == F32 || type == F64;
+}
+
} // namespace
/*static*/ bool GpuInstructionFusion::IsExpensive(
@@ -66,34 +77,71 @@ bool GpuInstructionFusion::ShouldFuse(HloInstruction* consumer,
HloInstruction* producer = consumer->mutable_operand(operand_index);
// Check if we can use output fusion for (A @ B) * alpha
- if (producer->opcode() == HloOpcode::kDot) {
- if (consumer->opcode() == HloOpcode::kMultiply) {
- CHECK_EQ(consumer->operand_count(), 2);
- int64 other_operand_index = 1 - operand_index;
- const HloInstruction* alpha = consumer->operand(other_operand_index);
- if (alpha->opcode() == HloOpcode::kConstant &&
- ShapeUtil::IsScalar(alpha->shape())) {
+ if (consumer->operand_count() == 2 &&
+ (producer->opcode() == HloOpcode::kDot ||
+ (producer->opcode() == HloOpcode::kFusion &&
+ producer->fused_expression_root()->opcode() == HloOpcode::kDot))) {
+ int64 other_operand_index = 1 - operand_index;
+ const HloInstruction* alpha = consumer->operand(other_operand_index);
+ HloInstruction* op1 = nullptr;
+ HloInstruction* op2 = nullptr;
+ if (consumer->opcode() == HloOpcode::kFusion &&
+ consumer->fusion_kind() == HloInstruction::FusionKind::kLoop &&
+ Match(consumer->fused_expression_root(),
+ match::Op()
+ .WithOpcode(HloOpcode::kMultiply)
+ .WithOperand(0, match::Op(&op1))
+ .WithOperand(1, match::Op(&op2)))) {
+ CHECK(op1 != nullptr && op2 != nullptr);
+ // If 'consumer' is a fusion node, it should consist of a broadcast of a
+ // scalar constant fused into a multiply, but nothing more. So one operand
+ // should be a parameter, and the other should be a broadcast.
+ if (op1->opcode() != HloOpcode::kParameter) {
+ std::swap(op1, op2);
+ }
+ if (op1->opcode() != HloOpcode::kParameter ||
+ op2->opcode() != HloOpcode::kBroadcast) {
+ return false;
+ }
+ if (IsIEEEFloatingPointScalarConstant(alpha)) {
+ return true;
+ }
+ } else if (consumer->opcode() == HloOpcode::kMultiply) {
+ // Fuse if 'alpha' is a broadcast of a scalar constant.
+ if (alpha->opcode() == HloOpcode::kBroadcast &&
+ alpha->dimensions().empty() &&
+ IsIEEEFloatingPointScalarConstant(alpha->operand(0))) {
return true;
}
}
}
- // Only allow to fuse transpose into an output fusion.
+ // Only allow fusing transpose or broadcast into an output fusion that is
+ // implemented as a Gemm call.
if (consumer->opcode() == HloOpcode::kFusion &&
- consumer->fusion_kind() == HloInstruction::FusionKind::kOutput) {
- if (producer->opcode() != HloOpcode::kTranspose) {
- return false;
- }
- // Check that the transpose is the operand of a dot.
+ consumer->fusion_kind() == HloInstruction::FusionKind::kOutput &&
+ ImplementedAsGemm(*consumer)) {
auto producer_operand_index = consumer->operand_index(producer);
auto fused_parameter = consumer->fused_parameter(producer_operand_index);
const std::vector<HloInstruction*>& fused_parameter_users =
fused_parameter->users();
- return (fused_parameter_users.size() == 1 &&
- fused_parameter_users[0]->opcode() == HloOpcode::kDot);
+ if (fused_parameter_users.size() != 1) {
+ return false;
+ }
+ if (producer->opcode() == HloOpcode::kTranspose) {
+ // Check that the transpose is an operand of a dot.
+ return fused_parameter_users[0]->opcode() == HloOpcode::kDot;
+ }
+ if (producer->opcode() == HloOpcode::kBroadcast) {
+ // Check that the broadcast is a broadcast of a scalar constant into a
+ // multiply.
+ return producer->dimensions().empty() &&
+ IsIEEEFloatingPointScalarConstant(producer->operand(0)) &&
+ fused_parameter_users[0]->opcode() == HloOpcode::kMultiply;
+ }
}
- // Output fusion is not currently supported on GPUs.
+ // Other output fusions are not currently supported on GPUs.
if (producer->opcode() == HloOpcode::kFusion) {
return false;
}
@@ -134,7 +182,9 @@ HloInstruction::FusionKind GpuInstructionFusion::ChooseKind(
if (IsReductionToVector(*consumer)) {
return HloInstruction::FusionKind::kInput;
}
- if (producer->opcode() == HloOpcode::kDot) {
+ if (producer->opcode() == HloOpcode::kDot ||
+ (producer->opcode() == HloOpcode::kFusion &&
+ producer->fused_expression_root()->opcode() == HloOpcode::kDot)) {
return HloInstruction::FusionKind::kOutput;
}
if (HloOpcode::kFusion == consumer->opcode()) {
diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
index 6c9a805ad6..760e0e90f5 100644
--- a/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion_test.cc
@@ -108,8 +108,8 @@ TEST_F(InstructionFusionTest, PotentialBitcastReshapeOfDotUnfused) {
HloComputation::Builder builder(TestName());
auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(S32, {1, 1}), "0"));
- auto dot1 = builder.AddInstruction(HloInstruction::CreateBinary(
- ShapeUtil::MakeShape(S32, {1, 1}), HloOpcode::kDot, param0, param0));
+ auto dot1 = builder.AddInstruction(HloInstruction::CreateCanonicalDot(
+ ShapeUtil::MakeShape(S32, {1, 1}), param0, param0));
auto reshape2 = builder.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(S32, {1, 1, 1}), dot1));
@@ -125,8 +125,8 @@ TEST_F(InstructionFusionTest, PotentialBitcastTransposeOfDotUnfused) {
HloComputation::Builder builder(TestName());
auto param0 = builder.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(S32, {1, 1}), "0"));
- auto dot1 = builder.AddInstruction(HloInstruction::CreateBinary(
- ShapeUtil::MakeShape(S32, {1, 1}), HloOpcode::kDot, param0, param0));
+ auto dot1 = builder.AddInstruction(HloInstruction::CreateCanonicalDot(
+ ShapeUtil::MakeShape(S32, {1, 1}), param0, param0));
auto transpose2 = builder.AddInstruction(HloInstruction::CreateTranspose(
ShapeUtil::MakeShape(S32, {1, 1}), dot1, {0, 1}));
@@ -232,12 +232,13 @@ TEST_F(InstructionFusionTest, DotOutputFusion) {
auto module = tools::Parse(R"(
HloModule test_module
ENTRY OutputFusion {
- constant = f32[] constant(3)
+ alpha = f32[] constant(3)
+ broadcast = f32[4,4]{1,0} broadcast(alpha), dimensions={}
p0 = f32[4,3]{1,0} parameter(0)
p1 = f32[4,3]{1,0} parameter(1)
transpose = f32[3,4]{1,0} transpose(p1), dimensions={1, 0}
- dot = f32[4,4]{1,0} dot(p0, transpose)
- ROOT mul = f32[4,4] multiply(constant, dot)
+ dot = f32[4,4]{1,0} dot(p0, transpose), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ ROOT mul = f32[4,4] multiply(dot, broadcast)
})")
.ValueOrDie();
@@ -247,10 +248,11 @@ TEST_F(InstructionFusionTest, DotOutputFusion) {
HloInstruction* root = module->entry_computation()->root_instruction();
EXPECT_THAT(root, op::Fusion());
+ EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kOutput);
EXPECT_THAT(
root->fused_expression_root(),
- op::Multiply(op::Parameter(),
- op::Dot(op::Parameter(), op::Transpose(op::Parameter()))));
+ op::Multiply(op::Dot(op::Parameter(), op::Transpose(op::Parameter())),
+ op::Broadcast(op::Parameter())));
}
// Compute sum(1/p0), where p0 has type f32, twice. Check that the division is
@@ -309,5 +311,31 @@ TEST_F(InstructionFusionTest, IntegerDivIsNotCheap) {
.ValueOrDie());
}
+TEST_F(InstructionFusionTest, DotOutputFusionImpossible) {
+ auto module = tools::Parse(R"(
+ HloModule test_module
+ ENTRY NoOutputFusion {
+ alpha = f32[] constant(3)
+ broadcast = f32[4,4]{1,0} broadcast(alpha), dimensions={}
+ p0 = f32[4,3]{1,0} parameter(0)
+ p1 = f32[3,4]{1,0} parameter(1)
+ dot = f32[4,4]{1,0} dot(p0, p1), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ d = f32[4,4]{1,0} multiply(dot, dot)
+ ROOT mul = f32[4,4] multiply(d, broadcast)
+ })")
+ .ValueOrDie();
+
+ EXPECT_TRUE(GpuInstructionFusion(/*may_duplicate=*/true)
+ .Run(module.get())
+ .ValueOrDie());
+
+ HloInstruction* root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root, op::Fusion());
+ EXPECT_EQ(root->fusion_kind(), HloInstruction::FusionKind::kLoop);
+ EXPECT_THAT(root->fused_expression_root(),
+ op::Multiply(op::Multiply(op::Parameter(), op::Parameter()),
+ op::Broadcast(op::Parameter())));
+}
+
} // namespace gpu
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
index 96199035b9..22e7150995 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emission_utils.cc
@@ -59,6 +59,25 @@ bool AreValidGemmShapes(const Shape& lhs_shape, const Shape& rhs_shape,
!ShapeUtil::HasZeroElements(lhs_shape) &&
!ShapeUtil::HasZeroElements(rhs_shape);
}
+
+bool DotImplementedAsGemm(const HloInstruction& dot) {
+ CHECK_EQ(dot.opcode(), HloOpcode::kDot);
+ const Shape& lhs_shape = dot.operand(0)->shape();
+ const Shape& rhs_shape = dot.operand(1)->shape();
+
+ // If gemm can accept the operand shapes, use it rather than a custom
+ // kernel.
+ if (AreValidGemmShapes(lhs_shape, rhs_shape, dot.shape())) {
+ // The size of the reduction dimension should match. The shape inference
+ // guarantees this invariant, so the check here is for programming
+ // errors.
+ const DotDimensionNumbers& dim_numbers = dot.dot_dimension_numbers();
+ CHECK_EQ(lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)),
+ rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0)));
+ return true;
+ }
+ return false;
+}
} // namespace
bool ImplementedAsGemm(const HloInstruction& hlo) {
@@ -69,20 +88,7 @@ bool ImplementedAsGemm(const HloInstruction& hlo) {
// For certain types of Dot, we can call pre-canned BLAS gemm.
if (hlo.opcode() == HloOpcode::kDot) {
- const Shape& lhs_shape = hlo.operand(0)->shape();
- const Shape& rhs_shape = hlo.operand(1)->shape();
-
- // If gemm can accept the operand shapes, use it rather than a custom
- // kernel.
- if (AreValidGemmShapes(lhs_shape, rhs_shape, hlo.shape())) {
- // The size of the reduction dimension should match. The shape inference
- // guarantees this invariant, so the check here is for programming
- // errors.
- const DotDimensionNumbers& dim_numbers = hlo.dot_dimension_numbers();
- CHECK_EQ(lhs_shape.dimensions(dim_numbers.lhs_contracting_dimensions(0)),
- rhs_shape.dimensions(dim_numbers.rhs_contracting_dimensions(0)));
- return true;
- }
+ return DotImplementedAsGemm(hlo);
}
if (hlo.opcode() == HloOpcode::kFusion &&
@@ -94,7 +100,7 @@ bool ImplementedAsGemm(const HloInstruction& hlo) {
dot = hlo.fused_expression_root()->operand(1);
}
if (dot->opcode() == HloOpcode::kDot) {
- return ImplementedAsGemm(*dot);
+ return DotImplementedAsGemm(*dot);
}
}
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 83d90296df..0d7ba4cf9a 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -2194,6 +2194,21 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildInfeedThunk(
/*destination_buffer=*/GetAllocationSlice(*inst), inst);
}
+namespace {
+double GetScalarConstantAsDouble(const Literal& literal) {
+ switch (literal.shape().element_type()) {
+ case F16:
+ return static_cast<double>(literal.Get<Eigen::half>({}));
+ case F32:
+ return literal.Get<float>({});
+ case F64:
+ return literal.Get<double>({});
+ default:
+ LOG(FATAL) << "Unsupported type.";
+ }
+}
+} // namespace
+
std::unique_ptr<Thunk> IrEmitterUnnested::BuildGemmThunk(
const HloInstruction* inst) {
if (inst->opcode() == HloOpcode::kDot) {
@@ -2218,6 +2233,17 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildGemmThunk(
if (dot->opcode() != HloOpcode::kDot) {
std::swap(dot, alpha);
}
+ if (alpha->opcode() == HloOpcode::kBroadcast) {
+ alpha = alpha->operand(0);
+ }
+ alpha = inst->operand(alpha->parameter_number());
+ // TODO(b/74185543): Remove the following if block once we support fusion
+ // with a non-constant as well. Then we will just always use the constant
+ // on the device.
+ if (alpha->opcode() == HloOpcode::kCopy) {
+ alpha = alpha->operand(0);
+ }
+
DCHECK(dot->opcode() == HloOpcode::kDot);
const HloInstruction* lhs_parameter = StripTranspose(*dot->operand(0));
const HloInstruction* rhs_parameter = StripTranspose(*dot->operand(1));
@@ -2229,13 +2255,13 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildGemmThunk(
inst->operand(rhs_parameter->parameter_number());
return MakeUnique<GemmThunk>(
- GetAllocationSlice(*lhs), // The buffer assigned to LHS.
- GetAllocationSlice(*rhs), // The buffer assigned to RHS.
- GetAllocationSlice(*mul), // The output buffer.
- lhs->shape(), // The shape of LHS.
- rhs->shape(), // The shape of RHS.
- inst->shape(), // The shape of the output.
- alpha->literal().Get<double>({0}), // alpha.
+ GetAllocationSlice(*lhs), // The buffer assigned to LHS.
+ GetAllocationSlice(*rhs), // The buffer assigned to RHS.
+ GetAllocationSlice(*inst), // The output buffer.
+ lhs->shape(), // The shape of LHS.
+ rhs->shape(), // The shape of RHS.
+ inst->shape(), // The shape of the output.
+ GetScalarConstantAsDouble(alpha->literal()), // alpha.
inst);
}