diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc | 103 |
1 files changed, 56 insertions, 47 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc index 27d2c3e491..cc38db27e2 100644 --- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc @@ -29,12 +29,13 @@ limitations under the License. #include "llvm/IR/Intrinsics.h" #include "llvm/IR/Module.h" #include "llvm/IR/Type.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h" #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h" +#include "tensorflow/compiler/xla/service/llvm_ir/math_ops.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" @@ -67,8 +68,8 @@ bool IsFPLiteralWithValue(const HloInstruction* operand, float value) { GpuElementalIrEmitter::GpuElementalIrEmitter( const HloModuleConfig& hlo_module_config, llvm::Module* module, - llvm::IRBuilder<>* ir_builder, NestedComputer compute_nested) - : ElementalIrEmitter(hlo_module_config, module, ir_builder), + llvm::IRBuilder<>* b, NestedComputer compute_nested) + : ElementalIrEmitter(hlo_module_config, module, b), hlo_module_config_(hlo_module_config), compute_nested_(std::move(compute_nested)) {} @@ -92,8 +93,8 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLibdeviceMathCall( cast_result_to_fp16 = true; for (int64 i = 0; i < operands.size(); ++i) { if (input_types[i] == F16) { - converted_operands[i] = ir_builder_->CreateFPCast( - converted_operands[i], ir_builder_->getFloatTy()); + converted_operands[i] = + b_->CreateFPCast(converted_operands[i], b_->getFloatTy()); converted_input_types[i] = F32; } } @@ -112,7 +113,7 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLibdeviceMathCall( converted_input_types, output_type) .ValueOrDie(); if (cast_result_to_fp16) { - result = ir_builder_->CreateFPCast(result, ir_builder_->getHalfTy()); + result = b_->CreateFPCast(result, b_->getHalfTy()); } return result; } @@ -215,7 +216,7 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitPowerOp( // LLVM's NVPTX backend knows how to transform 1/sqrt(A) into the NVPTX // rsqrt.approx instruction. TF_ASSIGN_OR_RETURN(auto* sqrt, make_sqrt()); - return ir_builder_->CreateFDiv(llvm::ConstantFP::get(llvm_ty, 1), sqrt); + return b_->CreateFDiv(llvm::ConstantFP::get(llvm_ty, 1), sqrt); } VLOG(10) << "emitting pow as regular call to pow(): " << op->ToString(); @@ -277,6 +278,16 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitFloatUnaryOp( PrimitiveType output_type = op->shape().element_type(); switch (op->opcode()) { case HloOpcode::kTanh: + // If we don't care much about precision, emit a fast approximation of + // tanh. + if (hlo_module_config_.debug_options().xla_enable_fast_math()) { + // Upcast F16 to F32 if necessary. + llvm::Type* type = + input_type == F16 ? b_->getFloatTy() : operand_value->getType(); + llvm::Value* input = b_->CreateFPCast(operand_value, type); + llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b_, input); + return b_->CreateFPCast(fast_tanh, operand_value->getType()); + } return EmitLibdeviceMathCall("__nv_tanh", {operand_value}, {input_type}, output_type); default: @@ -302,32 +313,31 @@ llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall( // Declares the callee if it is not declared already. llvm::Function* callee = llvm::cast<llvm::Function>( - ir_builder_->GetInsertBlock()->getModule()->getOrInsertFunction( + b_->GetInsertBlock()->getModule()->getOrInsertFunction( llvm_ir::AsStringRef(callee_name), callee_type)); for (auto attribute : attributes) { callee->addFnAttr(attribute); } - return ir_builder_->CreateCall(callee, llvm_ir::AsArrayRef(operands)); + return b_->CreateCall(callee, llvm_ir::AsArrayRef(operands)); } llvm::Value* GpuElementalIrEmitter::EmitThreadId() const { - llvm::Value* block_id = ir_builder_->CreateIntCast( + llvm::Value* block_id = b_->CreateIntCast( llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x, - {}, {}, ir_builder_), - ir_builder_->getIntNTy(128), /*isSigned=*/true, "block.id"); - llvm::Value* thread_id_in_block = ir_builder_->CreateIntCast( + {}, {}, b_), + b_->getIntNTy(128), /*isSigned=*/true, "block.id"); + llvm::Value* thread_id_in_block = b_->CreateIntCast( llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x, - {}, {}, ir_builder_), - ir_builder_->getIntNTy(128), /*isSigned=*/true, "thread.id"); - llvm::Value* threads_per_block = ir_builder_->CreateIntCast( + {}, {}, b_), + b_->getIntNTy(128), /*isSigned=*/true, "thread.id"); + llvm::Value* threads_per_block = b_->CreateIntCast( llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_read_ptx_sreg_ntid_x, - {}, {}, ir_builder_), - ir_builder_->getIntNTy(128), /*isSigned=*/true, "threads_per_block"); - return ir_builder_->CreateNSWAdd( - ir_builder_->CreateNSWMul(block_id, threads_per_block), - thread_id_in_block); + {}, {}, b_), + b_->getIntNTy(128), /*isSigned=*/true, "threads_per_block"); + return b_->CreateNSWAdd(b_->CreateNSWMul(block_id, threads_per_block), + thread_id_in_block); } llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( @@ -373,12 +383,12 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( PrimitiveType operand_element_type = operand->shape().element_type(); llvm::Value* accum_ptr = llvm_ir::EmitAllocaAtFunctionEntry( llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_), - "reduce_window_accum_ptr", ir_builder_); + "reduce_window_accum_ptr", b_); { TF_ASSIGN_OR_RETURN(llvm::Value * init_value, operand_to_generator.at(hlo->operand(1))( IrArray::Index(index.GetType()))); - ir_builder_->CreateStore(init_value, accum_ptr); + b_->CreateStore(init_value, accum_ptr); } llvm::Type* index_type = index.GetType(); @@ -386,7 +396,7 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( return index.GetConstantWithIndexType(c); }; - llvm_ir::ForLoopNest loops(IrName(hlo), ir_builder_, index_type); + llvm_ir::ForLoopNest loops(IrName(hlo), b_, index_type); std::vector<int64> window_size; for (const auto& dim : window.dimensions()) { window_size.push_back(dim.size()); @@ -395,15 +405,15 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( ShapeUtil::MakeShape(operand_element_type, window_size), "window"); CHECK_EQ(window_index.size(), index.size()); - SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), ir_builder_); + SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b_); IrArray::Index input_index(index_type, index.size()); - llvm::Value* in_bounds = ir_builder_->getInt1(true); + llvm::Value* in_bounds = b_->getInt1(true); for (size_t i = 0; i < index.size(); ++i) { - llvm::Value* stridden_index = ir_builder_->CreateNSWMul( + llvm::Value* stridden_index = b_->CreateNSWMul( index[i], index_typed_const(window.dimensions(i).stride())); - input_index[i] = ir_builder_->CreateNSWSub( - ir_builder_->CreateNSWAdd(stridden_index, window_index[i]), + input_index[i] = b_->CreateNSWSub( + b_->CreateNSWAdd(stridden_index, window_index[i]), index_typed_const(window.dimensions(i).padding_low())); // We must check whether 0 ≤ input_index[i] < bound, as otherwise @@ -411,16 +421,16 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( // comparison is equivalent to the unsigned comparison // input_index[i] < bound, as a negative value wraps to a large // positive value. - in_bounds = ir_builder_->CreateAnd( + in_bounds = b_->CreateAnd( in_bounds, - ir_builder_->CreateICmpULT( + b_->CreateICmpULT( input_index[i], index_typed_const(operand->shape().dimensions(i)))); } llvm_ir::LlvmIfData if_data = - llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_); - SetToFirstInsertPoint(if_data.true_block, ir_builder_); + llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", b_); + SetToFirstInsertPoint(if_data.true_block, b_); // We are not in pad, so do the computation. TF_ASSIGN_OR_RETURN(llvm::Value * input_value, @@ -428,26 +438,26 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( TF_ASSIGN_OR_RETURN( llvm::Value * accum_value, compute_nested_(*hlo->to_apply(), - {ir_builder_->CreateLoad(accum_ptr), input_value})); - ir_builder_->CreateStore(accum_value, accum_ptr); + {b_->CreateLoad(accum_ptr), input_value})); + b_->CreateStore(accum_value, accum_ptr); - SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), ir_builder_); - return ir_builder_->CreateLoad(accum_ptr); + SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b_); + return b_->CreateLoad(accum_ptr); }; case HloOpcode::kReduce: return [=, &operand_to_generator]( const IrArray::Index& output_index) -> StatusOr<llvm::Value*> { const HloInstruction* operand = hlo->operand(0); llvm::Value* accum_ptr = - ir_builder()->CreateAlloca(llvm_ir::PrimitiveTypeToIrType( + b()->CreateAlloca(llvm_ir::PrimitiveTypeToIrType( hlo->shape().element_type(), module_)); llvm::Type* index_type = output_index.GetType(); TF_ASSIGN_OR_RETURN(llvm::Value * init_value, operand_to_generator.at(hlo->operand(1))( IrArray::Index(index_type))); - ir_builder()->CreateStore(init_value, accum_ptr); + b()->CreateStore(init_value, accum_ptr); - llvm_ir::ForLoopNest loops(IrName(hlo), ir_builder_, index_type); + llvm_ir::ForLoopNest loops(IrName(hlo), b_, index_type); IrArray::Index input_index = loops.AddLoopsForShapeOnDimensions( operand->shape(), hlo->dimensions(), "reduction_dim"); if (!ShapeUtil::IsScalar(hlo->shape())) { @@ -462,18 +472,17 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator( CHECK_EQ(output_index.size(), j); } - SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), ir_builder()); + SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b()); TF_ASSIGN_OR_RETURN( llvm::Value * input_value, operand_to_generator.at(hlo->operand(0))(input_index)); TF_ASSIGN_OR_RETURN( llvm::Value * accum_value, - compute_nested_( - *hlo->to_apply(), - {ir_builder()->CreateLoad(accum_ptr), input_value})); - ir_builder()->CreateStore(accum_value, accum_ptr); - SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), ir_builder()); - return ir_builder()->CreateLoad(accum_ptr); + compute_nested_(*hlo->to_apply(), + {b()->CreateLoad(accum_ptr), input_value})); + b()->CreateStore(accum_value, accum_ptr); + SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b()); + return b()->CreateLoad(accum_ptr); }; default: return ElementalIrEmitter::MakeElementGenerator(hlo, |