Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc')
-rw-r--r--  tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc  103
1 file changed, 56 insertions(+), 47 deletions(-)
diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
index 27d2c3e491..cc38db27e2 100644
--- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
@@ -29,12 +29,13 @@ limitations under the License.
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Type.h"
-#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/math_ops.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -67,8 +68,8 @@ bool IsFPLiteralWithValue(const HloInstruction* operand, float value) {
GpuElementalIrEmitter::GpuElementalIrEmitter(
const HloModuleConfig& hlo_module_config, llvm::Module* module,
- llvm::IRBuilder<>* ir_builder, NestedComputer compute_nested)
- : ElementalIrEmitter(hlo_module_config, module, ir_builder),
+ llvm::IRBuilder<>* b, NestedComputer compute_nested)
+ : ElementalIrEmitter(hlo_module_config, module, b),
hlo_module_config_(hlo_module_config),
compute_nested_(std::move(compute_nested)) {}
@@ -92,8 +93,8 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLibdeviceMathCall(
cast_result_to_fp16 = true;
for (int64 i = 0; i < operands.size(); ++i) {
if (input_types[i] == F16) {
- converted_operands[i] = ir_builder_->CreateFPCast(
- converted_operands[i], ir_builder_->getFloatTy());
+ converted_operands[i] =
+ b_->CreateFPCast(converted_operands[i], b_->getFloatTy());
converted_input_types[i] = F32;
}
}
@@ -112,7 +113,7 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitLibdeviceMathCall(
converted_input_types, output_type)
.ValueOrDie();
if (cast_result_to_fp16) {
- result = ir_builder_->CreateFPCast(result, ir_builder_->getHalfTy());
+ result = b_->CreateFPCast(result, b_->getHalfTy());
}
return result;
}
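
Note on the F16 handling in EmitLibdeviceMathCall: libdevice only exposes f32/f64 math entry points, so half operands are widened, the call runs in f32, and the result is narrowed back when cast_result_to_fp16 is set. A minimal sketch of the same pattern one precision level up (float computed via double), for illustration only:

#include <cmath>

// Widen, call the math routine at the wider precision, narrow the result:
// the same shape as the CreateFPCast in / libdevice call / CreateFPCast out
// round trip above.
float TanhViaDouble(float x) {
  double wide = static_cast<double>(x);
  double result = std::tanh(wide);
  return static_cast<float>(result);
}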
@@ -215,7 +216,7 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitPowerOp(
// LLVM's NVPTX backend knows how to transform 1/sqrt(A) into the NVPTX
// rsqrt.approx instruction.
TF_ASSIGN_OR_RETURN(auto* sqrt, make_sqrt());
- return ir_builder_->CreateFDiv(llvm::ConstantFP::get(llvm_ty, 1), sqrt);
+ return b_->CreateFDiv(llvm::ConstantFP::get(llvm_ty, 1), sqrt);
}
VLOG(10) << "emitting pow as regular call to pow(): " << op->ToString();
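
The rsqrt special case emits a literal 1/sqrt(A) on purpose. A scalar sketch of what the generated IR computes, assuming fast math; the fusion into rsqrt.approx happens later in the NVPTX backend:

#include <cmath>

// pow(a, -0.5) emitted as fdiv(1.0, sqrt(a)); NVPTX pattern-matches this
// pair into a single rsqrt.approx instruction.
float PowNegHalf(float a) {
  return 1.0f / std::sqrt(a);
}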
@@ -277,6 +278,16 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitFloatUnaryOp(
PrimitiveType output_type = op->shape().element_type();
switch (op->opcode()) {
case HloOpcode::kTanh:
+ // If we don't care much about precision, emit a fast approximation of
+ // tanh.
+ if (hlo_module_config_.debug_options().xla_enable_fast_math()) {
+ // Upcast F16 to F32 if necessary.
+ llvm::Type* type =
+ input_type == F16 ? b_->getFloatTy() : operand_value->getType();
+ llvm::Value* input = b_->CreateFPCast(operand_value, type);
+ llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b_, input);
+ return b_->CreateFPCast(fast_tanh, operand_value->getType());
+ }
return EmitLibdeviceMathCall("__nv_tanh", {operand_value}, {input_type},
output_type);
default:
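
The new fast-math path calls llvm_ir::EmitFastTanh from the freshly included math_ops.h. A hedged scalar sketch of the standard technique, a clamped odd rational polynomial in x: the coefficients below follow Eigen's single-precision tanh approximation and are illustrative, not necessarily the exact constants in math_ops.cc.

#include <algorithm>

float FastTanhSketch(float x) {
  // tanh is effectively +/-1 outside a small range, so clamp first.
  x = std::clamp(x, -9.0f, 9.0f);
  const float x2 = x * x;
  // Numerator x * P(x^2) (odd) evaluated with Horner's rule.
  float p = -2.76076847742355e-16f;
  p = p * x2 + 2.00018790482477e-13f;
  p = p * x2 - 8.60467152213735e-11f;
  p = p * x2 + 5.12229709037114e-08f;
  p = p * x2 + 1.48572235717979e-05f;
  p = p * x2 + 6.37261928875436e-04f;
  p = p * x2 + 4.89352455891786e-03f;
  p = p * x;
  // Denominator Q(x^2) (even).
  float q = 1.19825839466702e-06f;
  q = q * x2 + 1.18534705686654e-04f;
  q = q * x2 + 2.26843463243900e-03f;
  q = q * x2 + 4.89352518554385e-03f;
  return p / q;
}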
@@ -302,32 +313,31 @@ llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall(
// Declares the callee if it is not declared already.
llvm::Function* callee = llvm::cast<llvm::Function>(
- ir_builder_->GetInsertBlock()->getModule()->getOrInsertFunction(
+ b_->GetInsertBlock()->getModule()->getOrInsertFunction(
llvm_ir::AsStringRef(callee_name), callee_type));
for (auto attribute : attributes) {
callee->addFnAttr(attribute);
}
- return ir_builder_->CreateCall(callee, llvm_ir::AsArrayRef(operands));
+ return b_->CreateCall(callee, llvm_ir::AsArrayRef(operands));
}
llvm::Value* GpuElementalIrEmitter::EmitThreadId() const {
- llvm::Value* block_id = ir_builder_->CreateIntCast(
+ llvm::Value* block_id = b_->CreateIntCast(
llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_read_ptx_sreg_ctaid_x,
- {}, {}, ir_builder_),
- ir_builder_->getIntNTy(128), /*isSigned=*/true, "block.id");
- llvm::Value* thread_id_in_block = ir_builder_->CreateIntCast(
+ {}, {}, b_),
+ b_->getIntNTy(128), /*isSigned=*/true, "block.id");
+ llvm::Value* thread_id_in_block = b_->CreateIntCast(
llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_read_ptx_sreg_tid_x,
- {}, {}, ir_builder_),
- ir_builder_->getIntNTy(128), /*isSigned=*/true, "thread.id");
- llvm::Value* threads_per_block = ir_builder_->CreateIntCast(
+ {}, {}, b_),
+ b_->getIntNTy(128), /*isSigned=*/true, "thread.id");
+ llvm::Value* threads_per_block = b_->CreateIntCast(
llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::nvvm_read_ptx_sreg_ntid_x,
- {}, {}, ir_builder_),
- ir_builder_->getIntNTy(128), /*isSigned=*/true, "threads_per_block");
- return ir_builder_->CreateNSWAdd(
- ir_builder_->CreateNSWMul(block_id, threads_per_block),
- thread_id_in_block);
+ {}, {}, b_),
+ b_->getIntNTy(128), /*isSigned=*/true, "threads_per_block");
+ return b_->CreateNSWAdd(b_->CreateNSWMul(block_id, threads_per_block),
+ thread_id_in_block);
}
llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator(
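
EmitThreadId linearizes the CUDA grid position. Its scalar meaning is sketched below, with the GCC/Clang __int128 extension standing in for the i128 the IR uses; widening before the multiply-add is what makes the no-signed-wrap flags safe for any realizable launch configuration:

#include <cstdint>

// global_id = block_id * threads_per_block + thread_id_in_block,
// computed at 128-bit width so the NSW multiply-add cannot overflow.
__int128 GlobalThreadId(int64_t block_id, int64_t threads_per_block,
                        int64_t thread_id_in_block) {
  return static_cast<__int128>(block_id) * threads_per_block +
         thread_id_in_block;
}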
@@ -373,12 +383,12 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator(
PrimitiveType operand_element_type = operand->shape().element_type();
llvm::Value* accum_ptr = llvm_ir::EmitAllocaAtFunctionEntry(
llvm_ir::PrimitiveTypeToIrType(operand_element_type, module_),
- "reduce_window_accum_ptr", ir_builder_);
+ "reduce_window_accum_ptr", b_);
{
TF_ASSIGN_OR_RETURN(llvm::Value * init_value,
operand_to_generator.at(hlo->operand(1))(
IrArray::Index(index.GetType())));
- ir_builder_->CreateStore(init_value, accum_ptr);
+ b_->CreateStore(init_value, accum_ptr);
}
llvm::Type* index_type = index.GetType();
@@ -386,7 +396,7 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator(
return index.GetConstantWithIndexType(c);
};
- llvm_ir::ForLoopNest loops(IrName(hlo), ir_builder_, index_type);
+ llvm_ir::ForLoopNest loops(IrName(hlo), b_, index_type);
std::vector<int64> window_size;
for (const auto& dim : window.dimensions()) {
window_size.push_back(dim.size());
@@ -395,15 +405,15 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator(
ShapeUtil::MakeShape(operand_element_type, window_size), "window");
CHECK_EQ(window_index.size(), index.size());
- SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), ir_builder_);
+ SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b_);
IrArray::Index input_index(index_type, index.size());
- llvm::Value* in_bounds = ir_builder_->getInt1(true);
+ llvm::Value* in_bounds = b_->getInt1(true);
for (size_t i = 0; i < index.size(); ++i) {
- llvm::Value* stridden_index = ir_builder_->CreateNSWMul(
+ llvm::Value* stridden_index = b_->CreateNSWMul(
index[i], index_typed_const(window.dimensions(i).stride()));
- input_index[i] = ir_builder_->CreateNSWSub(
- ir_builder_->CreateNSWAdd(stridden_index, window_index[i]),
+ input_index[i] = b_->CreateNSWSub(
+ b_->CreateNSWAdd(stridden_index, window_index[i]),
index_typed_const(window.dimensions(i).padding_low()));
// We must check whether 0 ≤ input_index[i] < bound, as otherwise
@@ -411,16 +421,16 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator(
// comparison is equivalent to the unsigned comparison
// input_index[i] < bound, as a negative value wraps to a large
// positive value.
- in_bounds = ir_builder_->CreateAnd(
+ in_bounds = b_->CreateAnd(
in_bounds,
- ir_builder_->CreateICmpULT(
+ b_->CreateICmpULT(
input_index[i],
index_typed_const(operand->shape().dimensions(i))));
}
llvm_ir::LlvmIfData if_data =
- llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_);
- SetToFirstInsertPoint(if_data.true_block, ir_builder_);
+ llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", b_);
+ SetToFirstInsertPoint(if_data.true_block, b_);
// We are not in pad, so do the computation.
TF_ASSIGN_OR_RETURN(llvm::Value * input_value,
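
Two pieces of index arithmetic in this hunk are worth spelling out in scalar form: the mapping from an output position and window offset back into the padded input, and the single unsigned compare that replaces the pair of signed bounds tests. A minimal sketch, assuming 64-bit indices:

#include <cstdint>

// input_index = output_index * stride + window_index - padding_low,
// mirroring the CreateNSWMul / CreateNSWAdd / CreateNSWSub chain above.
int64_t WindowToInputIndex(int64_t output_index, int64_t window_index,
                           int64_t stride, int64_t padding_low) {
  return output_index * stride + window_index - padding_low;
}

// For a non-negative bound, 0 <= i && i < bound collapses to one unsigned
// comparison: a negative i wraps to a value far above any valid bound.
bool InBounds(int64_t i, int64_t bound) {
  return static_cast<uint64_t>(i) < static_cast<uint64_t>(bound);
}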
@@ -428,26 +438,26 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator(
TF_ASSIGN_OR_RETURN(
llvm::Value * accum_value,
compute_nested_(*hlo->to_apply(),
- {ir_builder_->CreateLoad(accum_ptr), input_value}));
- ir_builder_->CreateStore(accum_value, accum_ptr);
+ {b_->CreateLoad(accum_ptr), input_value}));
+ b_->CreateStore(accum_value, accum_ptr);
- SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), ir_builder_);
- return ir_builder_->CreateLoad(accum_ptr);
+ SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b_);
+ return b_->CreateLoad(accum_ptr);
};
case HloOpcode::kReduce:
return [=, &operand_to_generator](
const IrArray::Index& output_index) -> StatusOr<llvm::Value*> {
const HloInstruction* operand = hlo->operand(0);
llvm::Value* accum_ptr =
- ir_builder()->CreateAlloca(llvm_ir::PrimitiveTypeToIrType(
+ b()->CreateAlloca(llvm_ir::PrimitiveTypeToIrType(
hlo->shape().element_type(), module_));
llvm::Type* index_type = output_index.GetType();
TF_ASSIGN_OR_RETURN(llvm::Value * init_value,
operand_to_generator.at(hlo->operand(1))(
IrArray::Index(index_type)));
- ir_builder()->CreateStore(init_value, accum_ptr);
+ b()->CreateStore(init_value, accum_ptr);
- llvm_ir::ForLoopNest loops(IrName(hlo), ir_builder_, index_type);
+ llvm_ir::ForLoopNest loops(IrName(hlo), b_, index_type);
IrArray::Index input_index = loops.AddLoopsForShapeOnDimensions(
operand->shape(), hlo->dimensions(), "reduction_dim");
if (!ShapeUtil::IsScalar(hlo->shape())) {
@@ -462,18 +472,17 @@ llvm_ir::ElementGenerator GpuElementalIrEmitter::MakeElementGenerator(
CHECK_EQ(output_index.size(), j);
}
- SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), ir_builder());
+ SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), b());
TF_ASSIGN_OR_RETURN(
llvm::Value * input_value,
operand_to_generator.at(hlo->operand(0))(input_index));
TF_ASSIGN_OR_RETURN(
llvm::Value * accum_value,
- compute_nested_(
- *hlo->to_apply(),
- {ir_builder()->CreateLoad(accum_ptr), input_value}));
- ir_builder()->CreateStore(accum_value, accum_ptr);
- SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), ir_builder());
- return ir_builder()->CreateLoad(accum_ptr);
+ compute_nested_(*hlo->to_apply(),
+ {b()->CreateLoad(accum_ptr), input_value}));
+ b()->CreateStore(accum_value, accum_ptr);
+ SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), b());
+ return b()->CreateLoad(accum_ptr);
};
default:
return ElementalIrEmitter::MakeElementGenerator(hlo,
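
Both the kReduceWindow and kReduce generators above follow the same accumulator pattern: store the init value, fold every contributing element through the nested reduction computation, then load the result. A scalar sketch, with apply standing in for the computation invoked via compute_nested_:

#include <functional>
#include <vector>

float Reduce(const std::vector<float>& elements, float init_value,
             const std::function<float(float, float)>& apply) {
  float accum = init_value;        // CreateStore(init_value, accum_ptr)
  for (float input : elements) {   // ForLoopNest over the reduction dims
    accum = apply(accum, input);   // compute_nested_(*hlo->to_apply(), ...)
  }
  return accum;                    // CreateLoad(accum_ptr)
}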