aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
diff options
context:
space:
mode:
authorGravatar Sanjoy Das <sanjoy@google.com>2018-05-04 22:04:20 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-05-07 15:40:20 -0700
commit150089e6e67e4492f098cdd8f9f2f48dc9f9cc56 (patch)
tree778d8f20ab300ceea85a36c22d150570ff9530f8 /tensorflow/compiler/xla/service/cpu/ir_emitter.cc
parent939fc534a4b2f227ee337e7dcfa82ec9b6337814 (diff)
Remove uses of the kTransposeDot fusion
I didn't remove the enum itself, but after this change removing the enum should be a simple NFC change (famous last words!). This will make it easier to implement BatchDot on CPU. The change removes usages of kTransposeDot by: - Teaching TransposeFolding to "fuse" transposes into dots by flipping the lhs_contracting_dims/rhs_contracting_dims fields. - Replacing the notion of transpose_lhs/transpose_rhs in the IR emitters with "has a non-canonical LHS contraction dimension"/"has a non-canonical RHS contraction dimension" where the canonical LHS and RHS contraction dims [0] are 1 and 0. Some tests were getting away with creating Dot instructions with their dimensions numbers unset. I've fixed these to create canonical dot operations instead. It is possible (but hard to tell without trying) that some of the IR emission logic and Eigen runtime calls can now be simplified further. For instance, instead of passing in a `transpose_lhs` and `transpose_rhs` to the Eigen GEMM routines, we could instead pass in the LHS and RHS contraction dimensions directly. [0] See HloInstruction::CreateCanonicalDot. PiperOrigin-RevId: 195514907
Diffstat (limited to 'tensorflow/compiler/xla/service/cpu/ir_emitter.cc')
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.cc57
1 files changed, 7 insertions, 50 deletions
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index 6347ee2a2a..12f50e00b5 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -827,13 +827,6 @@ Status IrEmitter::HandleDot(HloInstruction* dot) {
"Dot with multiple contracting dimensions not implemented.");
}
- if (dnums.lhs_contracting_dimensions(0) !=
- std::min(lhs->shape().dimensions_size() - 1, 1) ||
- dnums.rhs_contracting_dimensions(0) != 0) {
- return Unimplemented(
- "Dot with non-standard contracting dimensions not implemented.");
- }
-
llvm_ir::IrArray lhs_array(GetIrArrayFor(lhs));
llvm_ir::IrArray rhs_array(GetIrArrayFor(rhs));
@@ -850,8 +843,7 @@ Status IrEmitter::HandleDot(HloInstruction* dot) {
// Dot operation is complicated so we delegate to a helper class.
return DotOpEmitter::EmitDotOperation(
- *dot, /*transpose_lhs=*/false, /*transpose_rhs=*/false, target_array,
- lhs_array, rhs_array, /*addend_array=*/nullptr,
+ *dot, target_array, lhs_array, rhs_array, /*addend_array=*/nullptr,
GetExecutableRunOptionsArgument(), &ir_builder_, hlo_module_config_,
target_machine_features_);
}
@@ -2085,45 +2077,10 @@ static const HloInstruction* StripTranspose(const HloInstruction& hlo) {
}
Status IrEmitter::HandleFusion(HloInstruction* fusion) {
- auto* root = fusion->fused_expression_root();
- if (fusion->fusion_kind() == HloInstruction::FusionKind::kTransposeDot) {
- DCHECK(root->opcode() == HloOpcode::kDot);
- const HloInstruction* lhs_parameter = StripTranspose(*root->operand(0));
- const HloInstruction* rhs_parameter = StripTranspose(*root->operand(1));
- DCHECK(lhs_parameter->opcode() == HloOpcode::kParameter &&
- rhs_parameter->opcode() == HloOpcode::kParameter);
- const HloInstruction* lhs =
- fusion->operand(lhs_parameter->parameter_number());
- const HloInstruction* rhs =
- fusion->operand(rhs_parameter->parameter_number());
-
- TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
- /*instruction=*/*root, /*operands=*/{lhs, rhs},
- /*supported_types=*/{F16, F32, F64}));
-
- llvm_ir::IrArray lhs_array(GetIrArrayFor(lhs));
- llvm_ir::IrArray rhs_array(GetIrArrayFor(rhs));
+ CHECK_NE(fusion->fusion_kind(), HloInstruction::FusionKind::kTransposeDot);
- Shape target_shape = fusion->shape();
- TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion));
- llvm_ir::IrArray target_array = GetIrArrayFor(fusion);
- VLOG(2) << "HandleFusion kTransposeDot: ";
- VLOG(2) << " lhs operand: "
- << llvm_ir::DumpToString(*lhs_array.GetBasePointer());
- VLOG(2) << " rhs operand: "
- << llvm_ir::DumpToString(*rhs_array.GetBasePointer());
- VLOG(2) << " target: "
- << llvm_ir::DumpToString(*target_array.GetBasePointer());
-
- // Dot operation is complicated so we delegate to a helper class.
- TF_RETURN_IF_ERROR(DotOpEmitter::EmitDotOperation(
- *root, root->operand(0)->IsRank2Transpose(),
- root->operand(1)->IsRank2Transpose(), target_array, lhs_array,
- rhs_array, /*addend_array=*/nullptr, GetExecutableRunOptionsArgument(),
- &ir_builder_, hlo_module_config_, target_machine_features_));
- return Status::OK();
- } else if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace(fusion,
- assignment_)) {
+ auto* root = fusion->fused_expression_root();
+ if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace(fusion, assignment_)) {
VLOG(3) << "HandleFusion FusedDynamicUpdateSliceInPlace";
CpuElementalIrEmitter elemental_emitter(hlo_module_config_, this, module_);
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(fusion));
@@ -2166,9 +2123,9 @@ Status IrEmitter::HandleFusion(HloInstruction* fusion) {
GetIrArrayFor(fusion->operand(addend_param_number)));
TF_RETURN_IF_ERROR(DotOpEmitter::EmitDotOperation(
- *dot, /*transpose_lhs=*/false, /*transpose_rhs=*/false, target_array,
- lhs_array, rhs_array, &addend_array, GetExecutableRunOptionsArgument(),
- &ir_builder_, hlo_module_config_, target_machine_features_));
+ *dot, target_array, lhs_array, rhs_array, &addend_array,
+ GetExecutableRunOptionsArgument(), &ir_builder_, hlo_module_config_,
+ target_machine_features_));
return Status::OK();
} else {
return Unimplemented("Fusion kind not implemented on CPU");