about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc')
-rw-r--r-- tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc 88
1 files changed, 28 insertions, 60 deletions
diff --git a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc
index 2e5cc96098..cef5e57b0b 100644
--- a/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc
+++ b/tensorflow/compiler/xla/service/cpu/llvm_ir_runtime.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "llvm/IR/Verifier.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "tensorflow/compiler/xla/service/cpu/vector_support_library.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/math_ops.h"
#include "tensorflow/core/lib/core/casts.h"
#include "tensorflow/core/platform/logging.h"
@@ -52,46 +53,14 @@ llvm::Function* EmitVectorF32TanhIfNeeded(llvm::Module* module,
llvm::BasicBlock* vector_tanh_body =
llvm::BasicBlock::Create(*context, "body", vector_tanh_function);
- llvm::IRBuilder<> ir_builder(vector_tanh_body);
+ llvm::IRBuilder<> b(vector_tanh_body);
llvm::FastMathFlags fast_math_flags;
- fast_math_flags.setFast();
- ir_builder.setFastMathFlags(fast_math_flags);
-
- VectorSupportLibrary vsl(F32, vector_width, &ir_builder, "tanh_f32");
+ fast_math_flags.setFast(enable_fast_math);
+ b.setFastMathFlags(fast_math_flags);
llvm::Value* input = &*vector_tanh_function->arg_begin();
- CHECK_EQ(input->getType(), vsl.vector_type());
-
- // This implements the same rational interpolant as implemented in Eigen3.
- llvm::Value* input_clamped =
- vsl.Clamp(input, /*low=*/GetIeeeF32(-9.0), /*high=*/GetIeeeF32(9.0));
-
- std::array<float, 7> numerator_coeffs{
- -2.76076847742355e-16f, 2.00018790482477e-13f, -8.60467152213735e-11f,
- 5.12229709037114e-08f, 1.48572235717979e-05f, 6.37261928875436e-04f,
- 4.89352455891786e-03f};
-
- std::array<float, 4> denominator_coeffs{
- 1.19825839466702e-06f, 1.18534705686654e-04f, 2.26843463243900e-03f,
- 4.89352518554385e-03f};
-
- llvm::Value* input_squared = vsl.Mul(input_clamped, input_clamped);
- llvm::Value* numerator = vsl.SplatFloat(GetIeeeF32(numerator_coeffs[0]));
- for (int i = 1; i < numerator_coeffs.size(); i++) {
- numerator =
- vsl.MulAdd(input_squared, numerator, GetIeeeF32(numerator_coeffs[i]));
- }
-
- numerator = vsl.Mul(input_clamped, numerator);
-
- llvm::Value* denominator = vsl.SplatFloat(GetIeeeF32(denominator_coeffs[0]));
- for (int i = 1; i < denominator_coeffs.size(); i++) {
- denominator = vsl.MulAdd(input_squared, denominator,
- GetIeeeF32(denominator_coeffs[i]));
- }
-
- llvm::Value* result = vsl.Div(numerator, denominator);
- ir_builder.CreateRet(result);
+ CHECK_EQ(vector_width, input->getType()->getVectorNumElements());
+ b.CreateRet(llvm_ir::EmitFastTanh(&b, input));
DCHECK(!llvm::verifyFunction(*vector_tanh_function));
return vector_tanh_function;
@@ -113,12 +82,12 @@ llvm::Function* EmitVectorF32ExpIfNeeded(llvm::Module* module,
llvm::BasicBlock* vector_exp_body =
llvm::BasicBlock::Create(*context, "body", vector_exp_function);
- llvm::IRBuilder<> ir_builder(vector_exp_body);
+ llvm::IRBuilder<> b(vector_exp_body);
llvm::FastMathFlags fast_math_flags;
fast_math_flags.setFast();
- ir_builder.setFastMathFlags(fast_math_flags);
+ b.setFastMathFlags(fast_math_flags);
- VectorSupportLibrary vsl(F32, vector_width, &ir_builder, "exp_f32");
+ VectorSupportLibrary vsl(F32, vector_width, &b, "exp_f32");
// This implements the same polynomial approximation as implemented in Eigen3.
@@ -160,21 +129,21 @@ llvm::Function* EmitVectorF32ExpIfNeeded(llvm::Module* module,
// VectorSupportLibrary (intentionally) can't juggle more than one type at a
// time so drop down to IRBuilder for this bit.
llvm::Value* vector_constant_0x7f =
- ir_builder.CreateVectorSplat(vector_width, ir_builder.getInt32(0x7f));
+ b.CreateVectorSplat(vector_width, b.getInt32(0x7f));
llvm::Value* vector_constant_23 =
- ir_builder.CreateVectorSplat(vector_width, ir_builder.getInt32(23));
+ b.CreateVectorSplat(vector_width, b.getInt32(23));
llvm::Type* i32_vector_type =
- llvm::VectorType::get(ir_builder.getInt32Ty(), vector_width);
+ llvm::VectorType::get(b.getInt32Ty(), vector_width);
// fx is clamped so we don't have to worry about it being out of range for
// i32.
- llvm::Value* emm0 = ir_builder.CreateFPToSI(fx, i32_vector_type);
- emm0 = ir_builder.CreateAdd(emm0, vector_constant_0x7f);
- emm0 = ir_builder.CreateShl(emm0, vector_constant_23);
- llvm::Value* emm0_f32 = ir_builder.CreateBitCast(emm0, vsl.vector_type());
+ llvm::Value* emm0 = b.CreateFPToSI(fx, i32_vector_type);
+ emm0 = b.CreateAdd(emm0, vector_constant_0x7f);
+ emm0 = b.CreateShl(emm0, vector_constant_23);
+ llvm::Value* emm0_f32 = b.CreateBitCast(emm0, vsl.vector_type());
llvm::Value* result = vsl.Max(vsl.Mul(y, emm0_f32), input);
- ir_builder.CreateRet(result);
+ b.CreateRet(result);
DCHECK(!llvm::verifyFunction(*vector_exp_function));
return vector_exp_function;
@@ -196,13 +165,13 @@ llvm::Function* EmitVectorF32LogIfNeeded(llvm::Module* module,
llvm::BasicBlock* vector_log_body =
llvm::BasicBlock::Create(*context, "body", vector_log_function);
- llvm::IRBuilder<> ir_builder(vector_log_body);
+ llvm::IRBuilder<> b(vector_log_body);
llvm::FastMathFlags fast_math_flags;
fast_math_flags.setFast();
- ir_builder.setFastMathFlags(fast_math_flags);
+ b.setFastMathFlags(fast_math_flags);
llvm::Value* input = &*vector_log_function->arg_begin();
- VectorSupportLibrary vsl(F32, vector_width, &ir_builder, "log_f32");
+ VectorSupportLibrary vsl(F32, vector_width, &b, "log_f32");
const llvm::APFloat half = GetIeeeF32(0.5);
const llvm::APFloat one = GetIeeeF32(1.0);
@@ -238,22 +207,21 @@ llvm::Function* EmitVectorF32LogIfNeeded(llvm::Module* module,
// VectorSupportLibrary (intentionally) can't juggle more than one type at a
// time so drop down to IRBuilder for this bit.
llvm::Value* vector_constant_0x7f =
- ir_builder.CreateVectorSplat(vector_width, ir_builder.getInt32(0x7f));
+ b.CreateVectorSplat(vector_width, b.getInt32(0x7f));
llvm::Value* vector_constant_23 =
- ir_builder.CreateVectorSplat(vector_width, ir_builder.getInt32(23));
+ b.CreateVectorSplat(vector_width, b.getInt32(23));
llvm::Type* i32_vector_type =
- llvm::VectorType::get(ir_builder.getInt32Ty(), vector_width);
+ llvm::VectorType::get(b.getInt32Ty(), vector_width);
- llvm::Value* emm0 = ir_builder.CreateLShr(
- ir_builder.CreateBitCast(input, i32_vector_type), vector_constant_23);
+ llvm::Value* emm0 =
+ b.CreateLShr(b.CreateBitCast(input, i32_vector_type), vector_constant_23);
// Keep only the fractional part.
input = vsl.FloatAnd(input, inv_mant_mask);
input = vsl.FloatOr(input, half);
- emm0 = ir_builder.CreateSub(emm0, vector_constant_0x7f);
- llvm::Value* e =
- vsl.Add(one, ir_builder.CreateSIToFP(emm0, vsl.vector_type()));
+ emm0 = b.CreateSub(emm0, vector_constant_0x7f);
+ llvm::Value* e = vsl.Add(one, b.CreateSIToFP(emm0, vsl.vector_type()));
// part2:
// if( x < SQRTHF ) {
@@ -294,7 +262,7 @@ llvm::Function* EmitVectorF32LogIfNeeded(llvm::Module* module,
llvm::Value* or_rhs = vsl.FloatAnd(iszero_mask, minus_inf);
llvm::Value* result = vsl.FloatOr(or_lhs, or_rhs);
- ir_builder.CreateRet(result);
+ b.CreateRet(result);
DCHECK(!llvm::verifyFunction(*vector_log_function));
return vector_log_function;