aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc')
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc40
1 files changed, 33 insertions, 7 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 83d90296df..0d7ba4cf9a 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -2194,6 +2194,21 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildInfeedThunk(
/*destination_buffer=*/GetAllocationSlice(*inst), inst);
}
+namespace {
+double GetScalarConstantAsDouble(const Literal& literal) {
+ switch (literal.shape().element_type()) {
+ case F16:
+ return static_cast<double>(literal.Get<Eigen::half>({}));
+ case F32:
+ return literal.Get<float>({});
+ case F64:
+ return literal.Get<double>({});
+ default:
+ LOG(FATAL) << "Unsupported type.";
+ }
+}
+} // namespace
+
std::unique_ptr<Thunk> IrEmitterUnnested::BuildGemmThunk(
const HloInstruction* inst) {
if (inst->opcode() == HloOpcode::kDot) {
@@ -2218,6 +2233,17 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildGemmThunk(
if (dot->opcode() != HloOpcode::kDot) {
std::swap(dot, alpha);
}
+ if (alpha->opcode() == HloOpcode::kBroadcast) {
+ alpha = alpha->operand(0);
+ }
+ alpha = inst->operand(alpha->parameter_number());
+ // TODO(b/74185543): Remove the following if block once we support fusion
+ // with a non-constant as well. Then we will just always use the constant
+ // on the device.
+ if (alpha->opcode() == HloOpcode::kCopy) {
+ alpha = alpha->operand(0);
+ }
+
DCHECK(dot->opcode() == HloOpcode::kDot);
const HloInstruction* lhs_parameter = StripTranspose(*dot->operand(0));
const HloInstruction* rhs_parameter = StripTranspose(*dot->operand(1));
@@ -2229,13 +2255,13 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildGemmThunk(
inst->operand(rhs_parameter->parameter_number());
return MakeUnique<GemmThunk>(
- GetAllocationSlice(*lhs), // The buffer assigned to LHS.
- GetAllocationSlice(*rhs), // The buffer assigned to RHS.
- GetAllocationSlice(*mul), // The output buffer.
- lhs->shape(), // The shape of LHS.
- rhs->shape(), // The shape of RHS.
- inst->shape(), // The shape of the output.
- alpha->literal().Get<double>({0}), // alpha.
+ GetAllocationSlice(*lhs), // The buffer assigned to LHS.
+ GetAllocationSlice(*rhs), // The buffer assigned to RHS.
+ GetAllocationSlice(*inst), // The output buffer.
+ lhs->shape(), // The shape of LHS.
+ rhs->shape(), // The shape of RHS.
+ inst->shape(), // The shape of the output.
+ GetScalarConstantAsDouble(alpha->literal()), // alpha.
inst);
}