aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc')
-rw-r--r--tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc28
1 files changed, 16 insertions, 12 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
index 69ba91793d..9b6de115ad 100644
--- a/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/gpu/elemental_ir_emitter.cc
@@ -210,11 +210,13 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitPowerOp(
return make_sqrt();
}
- if (hlo_module_config_.debug_options().xla_enable_fast_math() &&
- IsFPLiteralWithValue(rhs, -.5)) {
+ if (IsFPLiteralWithValue(rhs, -.5)) {
VLOG(10) << "emitting pow(A, -.5) as 1/sqrt(A): " << op->ToString();
// LLVM's NVPTX backend knows how to transform 1/sqrt(A) into the NVPTX
// rsqrt.approx instruction.
+ //
+ // TODO(jlebar): Does this happen with fastmath disabled? If not, should
+ // we force-enable it?
TF_ASSIGN_OR_RETURN(auto* sqrt, make_sqrt());
return b_->CreateFDiv(llvm::ConstantFP::get(llvm_ty, 1), sqrt);
}
@@ -274,16 +276,18 @@ StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitAtan2(
StatusOr<llvm::Value*> GpuElementalIrEmitter::EmitTanh(
PrimitiveType prim_type, llvm::Value* value) const {
- // If we don't care much about precision, emit a fast approximation of
- // tanh.
- if (hlo_module_config_.debug_options().xla_enable_fast_math()) {
- // Upcast F16 to F32 if necessary.
- llvm::Type* type = prim_type == F16 ? b_->getFloatTy() : value->getType();
- llvm::Value* input = b_->CreateFPCast(value, type);
- llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b_, input);
- return b_->CreateFPCast(fast_tanh, value->getType());
- }
- return EmitLibdeviceMathCall("__nv_tanh", {value}, {prim_type}, prim_type);
+ // Emit a fast approximation of tanh instead of calling __nv_tanh.
+ // __nv_tanh is particularly bad because it contains branches, thus
+ // preventing LLVM's load-store vectorizer from working its magic across a
+ // function which contains tanh calls.
+ //
+ // This routine isn't numerically precise, but it's good enough for ML.
+
+ // Upcast F16 to F32 if necessary.
+ llvm::Type* type = prim_type == F16 ? b_->getFloatTy() : value->getType();
+ llvm::Value* input = b_->CreateFPCast(value, type);
+ llvm::Value* fast_tanh = llvm_ir::EmitFastTanh(b_, input);
+ return b_->CreateFPCast(fast_tanh, value->getType());
}
llvm::Value* GpuElementalIrEmitter::EmitDeviceFunctionCall(