aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/elemental_ir_emitter.cc')
-rw-r--r--tensorflow/compiler/xla/service/elemental_ir_emitter.cc1149
1 files changed, 552 insertions, 597 deletions
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index ce0951bbe1..47ed6162ed 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -61,13 +61,13 @@ int64 GlobalRandomValue() {
llvm::Value* EmitReducePrecisionFloat(llvm::Value* x, int64 exponent_bits,
int64 mantissa_bits,
- llvm::IRBuilder<>* ir_builder) {
+ llvm::IRBuilder<>* b) {
// Integer and float types for casting and constant generation.
llvm::Type* float_type = x->getType();
- llvm::IntegerType* int_type = ir_builder->getInt32Ty();
+ llvm::IntegerType* int_type = b->getInt32Ty();
// Cast the input value to an integer for bitwise manipulation.
- llvm::Value* x_as_int = ir_builder->CreateBitCast(x, int_type);
+ llvm::Value* x_as_int = b->CreateBitCast(x, int_type);
if (mantissa_bits < 23) {
// Last remaining mantissa bit.
@@ -77,22 +77,22 @@ llvm::Value* EmitReducePrecisionFloat(llvm::Value* x, int64 exponent_bits,
// equal to a base value of 0111... plus one bit if the last remaining
// mantissa bit is 1.
const uint32_t base_rounding_bias = (last_mantissa_bit_mask >> 1) - 1;
- llvm::Value* x_last_mantissa_bit = ir_builder->CreateLShr(
- ir_builder->CreateAnd(
- x_as_int, llvm::ConstantInt::get(int_type, last_mantissa_bit_mask)),
+ llvm::Value* x_last_mantissa_bit = b->CreateLShr(
+ b->CreateAnd(x_as_int,
+ llvm::ConstantInt::get(int_type, last_mantissa_bit_mask)),
(23 - mantissa_bits));
- llvm::Value* x_rounding_bias = ir_builder->CreateAdd(
- x_last_mantissa_bit,
- llvm::ConstantInt::get(int_type, base_rounding_bias));
+ llvm::Value* x_rounding_bias =
+ b->CreateAdd(x_last_mantissa_bit,
+ llvm::ConstantInt::get(int_type, base_rounding_bias));
// Add rounding bias, and mask out truncated bits. Note that the case
// where adding the rounding bias overflows into the exponent bits is
// correct; the non-masked mantissa bits will all be zero, and the
// exponent will be incremented by one.
const uint32_t truncation_mask = ~(last_mantissa_bit_mask - 1);
- x_as_int = ir_builder->CreateAdd(x_as_int, x_rounding_bias);
- x_as_int = ir_builder->CreateAnd(
- x_as_int, llvm::ConstantInt::get(int_type, truncation_mask));
+ x_as_int = b->CreateAdd(x_as_int, x_rounding_bias);
+ x_as_int = b->CreateAnd(x_as_int,
+ llvm::ConstantInt::get(int_type, truncation_mask));
}
if (exponent_bits < 8) {
@@ -120,29 +120,29 @@ llvm::Value* EmitReducePrecisionFloat(llvm::Value* x, int64 exponent_bits,
f32_exponent_bias - reduced_exponent_bias;
// Do we overflow or underflow?
- llvm::Value* x_exponent = ir_builder->CreateAnd(
+ llvm::Value* x_exponent = b->CreateAnd(
x_as_int, llvm::ConstantInt::get(int_type, f32_exp_bits_mask));
- llvm::Value* x_overflows = ir_builder->CreateICmpUGT(
+ llvm::Value* x_overflows = b->CreateICmpUGT(
x_exponent,
llvm::ConstantInt::get(int_type, reduced_max_exponent << 23));
- llvm::Value* x_underflows = ir_builder->CreateICmpULE(
+ llvm::Value* x_underflows = b->CreateICmpULE(
x_exponent,
llvm::ConstantInt::get(int_type, reduced_min_exponent << 23));
// Compute appropriately-signed values of zero and infinity.
- llvm::Value* x_signed_zero = ir_builder->CreateAnd(
+ llvm::Value* x_signed_zero = b->CreateAnd(
x_as_int, llvm::ConstantInt::get(int_type, f32_sign_bit_mask));
- llvm::Value* x_signed_inf = ir_builder->CreateOr(
+ llvm::Value* x_signed_inf = b->CreateOr(
x_signed_zero, llvm::ConstantInt::get(int_type, f32_exp_bits_mask));
// Force to zero or infinity if overflow or underflow. (Note that this
// truncates all denormal values to zero, rather than rounding them.)
- x_as_int = ir_builder->CreateSelect(x_overflows, x_signed_inf, x_as_int);
- x_as_int = ir_builder->CreateSelect(x_underflows, x_signed_zero, x_as_int);
+ x_as_int = b->CreateSelect(x_overflows, x_signed_inf, x_as_int);
+ x_as_int = b->CreateSelect(x_underflows, x_signed_zero, x_as_int);
}
// Cast the result back to a floating-point type.
- llvm::Value* result = ir_builder->CreateBitCast(x_as_int, float_type);
+ llvm::Value* result = b->CreateBitCast(x_as_int, float_type);
// Correct result for NaN inputs.
//
@@ -154,53 +154,49 @@ llvm::Value* EmitReducePrecisionFloat(llvm::Value* x, int64 exponent_bits,
//
// If the fast-math flags are set to assume no NaNs, the comparison is likely
// to be optimized away, so there's no point in even emitting it.
- if (!ir_builder->getFastMathFlags().noNaNs()) {
- llvm::Value* x_is_nan = ir_builder->CreateFCmpUNO(x, x);
+ if (!b->getFastMathFlags().noNaNs()) {
+ llvm::Value* x_is_nan = b->CreateFCmpUNO(x, x);
if (mantissa_bits > 0) {
- result = ir_builder->CreateSelect(x_is_nan, x, result);
+ result = b->CreateSelect(x_is_nan, x, result);
} else {
- result = ir_builder->CreateSelect(
+ result = b->CreateSelect(
x_is_nan, llvm::ConstantFP::getInfinity(float_type), result);
}
}
return result;
}
-llvm::Value* EmitF32ToBF16(llvm::Value* f32_value,
- llvm::IRBuilder<>* ir_builder) {
+llvm::Value* EmitF32ToBF16(llvm::Value* f32_value, llvm::IRBuilder<>* b) {
auto reduced_precision = EmitReducePrecisionFloat(
f32_value,
/*exponent_bits=*/primitive_util::kBFloat16ExponentBits,
- /*mantissa_bits=*/primitive_util::kBFloat16MantissaBits, ir_builder);
- auto as_int32 =
- ir_builder->CreateBitCast(reduced_precision, ir_builder->getInt32Ty());
- auto shifted = ir_builder->CreateLShr(as_int32, 16);
- auto truncated = ir_builder->CreateTrunc(shifted, ir_builder->getInt16Ty());
- return ir_builder->CreateBitCast(truncated, ir_builder->getInt16Ty());
+ /*mantissa_bits=*/primitive_util::kBFloat16MantissaBits, b);
+ auto as_int32 = b->CreateBitCast(reduced_precision, b->getInt32Ty());
+ auto shifted = b->CreateLShr(as_int32, 16);
+ auto truncated = b->CreateTrunc(shifted, b->getInt16Ty());
+ return b->CreateBitCast(truncated, b->getInt16Ty());
}
-llvm::Value* EmitBF16ToF32(llvm::Value* bf16_value,
- llvm::IRBuilder<>* ir_builder) {
- auto as_int16 =
- ir_builder->CreateBitCast(bf16_value, ir_builder->getInt16Ty());
- auto as_int32 = ir_builder->CreateZExt(as_int16, ir_builder->getInt32Ty());
- auto shifted = ir_builder->CreateShl(as_int32, 16);
- return ir_builder->CreateBitCast(shifted, ir_builder->getFloatTy());
+llvm::Value* EmitBF16ToF32(llvm::Value* bf16_value, llvm::IRBuilder<>* b) {
+ auto as_int16 = b->CreateBitCast(bf16_value, b->getInt16Ty());
+ auto as_int32 = b->CreateZExt(as_int16, b->getInt32Ty());
+ auto shifted = b->CreateShl(as_int32, 16);
+ return b->CreateBitCast(shifted, b->getFloatTy());
}
llvm::Value* EmitIntegralToFloating(llvm::Value* integer_value,
PrimitiveType from_type,
PrimitiveType to_type, llvm::Module* module,
- llvm::IRBuilder<>* ir_builder) {
+ llvm::IRBuilder<>* b) {
if (primitive_util::IsSignedIntegralType(from_type)) {
- return ir_builder->CreateSIToFP(
- integer_value, llvm_ir::PrimitiveTypeToIrType(to_type, module));
+ return b->CreateSIToFP(integer_value,
+ llvm_ir::PrimitiveTypeToIrType(to_type, module));
} else {
CHECK(primitive_util::IsUnsignedIntegralType(from_type) ||
from_type == PRED);
- return ir_builder->CreateUIToFP(
- integer_value, llvm_ir::PrimitiveTypeToIrType(to_type, module));
+ return b->CreateUIToFP(integer_value,
+ llvm_ir::PrimitiveTypeToIrType(to_type, module));
}
}
@@ -226,39 +222,43 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
case HloOpcode::kConvert: {
PrimitiveType from_type = op->operand(0)->shape().element_type();
PrimitiveType to_type = op->shape().element_type();
- CHECK(primitive_util::IsIntegralType(from_type) || from_type == PRED);
+ CHECK(primitive_util::IsIntegralType(from_type) || from_type == PRED)
+ << from_type;
if (from_type == to_type) {
return operand_value;
}
+ if (to_type == PRED) {
+ return b_->CreateZExt(
+ b_->CreateICmpNE(operand_value, llvm::ConstantInt::get(
+ operand_value->getType(), 0)),
+ llvm_ir::PrimitiveTypeToIrType(PRED, module_));
+ }
if (primitive_util::IsIntegralType(to_type)) {
- return ir_builder_->CreateIntCast(
+ return b_->CreateIntCast(
operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_),
primitive_util::IsSignedIntegralType(from_type));
}
if (primitive_util::IsFloatingPointType(to_type)) {
if (to_type == BF16) {
- return EmitF32ToBF16(
- EmitIntegralToFloating(operand_value, from_type, F32, module_,
- ir_builder_),
- ir_builder_);
+ return EmitF32ToBF16(EmitIntegralToFloating(operand_value, from_type,
+ F32, module_, b_),
+ b_);
}
return EmitIntegralToFloating(operand_value, from_type, to_type,
- module_, ir_builder_);
+ module_, b_);
}
if (primitive_util::IsComplexType(to_type)) {
auto to_ir_component_type = llvm_ir::PrimitiveTypeToIrType(
primitive_util::ComplexComponentType(to_type), module_);
if (primitive_util::IsSignedIntegralType(from_type)) {
return EmitComposeComplex(
- op,
- ir_builder_->CreateSIToFP(operand_value, to_ir_component_type),
+ op, b_->CreateSIToFP(operand_value, to_ir_component_type),
nullptr);
}
if (primitive_util::IsUnsignedIntegralType(from_type) ||
from_type == PRED) {
return EmitComposeComplex(
- op,
- ir_builder_->CreateUIToFP(operand_value, to_ir_component_type),
+ op, b_->CreateUIToFP(operand_value, to_ir_component_type),
nullptr);
}
}
@@ -275,7 +275,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
}
if (primitive_util::BitWidth(from_type) ==
primitive_util::BitWidth(to_type)) {
- return ir_builder_->CreateBitCast(
+ return b_->CreateBitCast(
operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
}
return InvalidArgument(
@@ -293,18 +293,18 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
auto type =
llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
auto zero = llvm::ConstantInt::get(type, 0);
- auto cmp = ir_builder_->CreateICmpSGE(operand_value, zero);
- return ir_builder_->CreateSelect(cmp, operand_value,
- ir_builder_->CreateNeg(operand_value));
+ auto cmp = b_->CreateICmpSGE(operand_value, zero);
+ return b_->CreateSelect(cmp, operand_value,
+ b_->CreateNeg(operand_value));
} else {
return operand_value;
}
}
case HloOpcode::kClz: {
- auto is_zero_undef = ir_builder_->getFalse();
- return llvm_ir::EmitCallToIntrinsic(
- llvm::Intrinsic::ctlz, {operand_value, is_zero_undef},
- {operand_value->getType()}, ir_builder_);
+ auto is_zero_undef = b_->getFalse();
+ return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::ctlz,
+ {operand_value, is_zero_undef},
+ {operand_value->getType()}, b_);
}
case HloOpcode::kSign: {
bool is_signed =
@@ -312,31 +312,28 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerUnaryOp(
auto type =
llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
auto zero = llvm::ConstantInt::get(type, 0);
- auto cmp = ir_builder_->CreateICmpEQ(operand_value, zero);
+ auto cmp = b_->CreateICmpEQ(operand_value, zero);
if (is_signed) {
- auto ashr = ir_builder_->CreateAShr(operand_value,
- type->getIntegerBitWidth() - 1);
- return ir_builder_->CreateSelect(cmp, zero,
- ir_builder_->CreateOr(ashr, 1));
+ auto ashr =
+ b_->CreateAShr(operand_value, type->getIntegerBitWidth() - 1);
+ return b_->CreateSelect(cmp, zero, b_->CreateOr(ashr, 1));
} else {
- return ir_builder_->CreateSelect(cmp, zero,
- llvm::ConstantInt::get(type, 1));
+ return b_->CreateSelect(cmp, zero, llvm::ConstantInt::get(type, 1));
}
}
case HloOpcode::kNegate:
- return ir_builder_->CreateNeg(operand_value);
+ return b_->CreateNeg(operand_value);
case HloOpcode::kNot: {
auto type = op->shape().element_type();
if (type == PRED) {
// It is not sufficient to just call CreateNot() here because a PRED
// is represented as an i8 and the truth value is stored only in the
// bottom bit.
- return ir_builder_->CreateZExt(
- ir_builder_->CreateNot(ir_builder_->CreateTrunc(
- operand_value, ir_builder_->getInt1Ty())),
+ return b_->CreateZExt(
+ b_->CreateNot(b_->CreateTrunc(operand_value, b_->getInt1Ty())),
llvm_ir::PrimitiveTypeToIrType(PRED, module_));
} else if (primitive_util::IsIntegralType(type)) {
- return ir_builder_->CreateNot(operand_value);
+ return b_->CreateNot(operand_value);
}
return Unimplemented("unary op Not is not defined for type '%d'", type);
}
@@ -352,7 +349,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
case HloOpcode::kConvert: {
PrimitiveType from_type = op->operand(0)->shape().element_type();
PrimitiveType to_type = op->shape().element_type();
- CHECK(primitive_util::IsFloatingPointType(from_type));
+ CHECK(primitive_util::IsFloatingPointType(from_type)) << from_type;
if (from_type == to_type) {
return operand_value;
}
@@ -364,32 +361,38 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
}
return EmitComposeComplex(
op,
- ir_builder_->CreateFPCast(
- operand_value,
- llvm_ir::PrimitiveTypeToIrType(to_component_type, module_)),
+ b_->CreateFPCast(operand_value, llvm_ir::PrimitiveTypeToIrType(
+ to_component_type, module_)),
nullptr);
}
if (from_type == BF16) {
TF_RET_CHECK(to_type != BF16);
- operand_value = EmitBF16ToF32(operand_value, ir_builder_);
+ operand_value = EmitBF16ToF32(operand_value, b_);
from_type = F32;
if (from_type == to_type) {
return operand_value;
}
}
if (from_type == F32 && to_type == BF16) {
- return EmitF32ToBF16(operand_value, ir_builder_);
+ return EmitF32ToBF16(operand_value, b_);
+ }
+ if (to_type == PRED) {
+ return b_->CreateZExt(
+ b_->CreateFCmpUNE(
+ operand_value,
+ llvm::ConstantFP::get(operand_value->getType(), 0.0)),
+ llvm_ir::PrimitiveTypeToIrType(PRED, module_));
}
if (primitive_util::IsFloatingPointType(to_type)) {
- return ir_builder_->CreateFPCast(
+ return b_->CreateFPCast(
operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
}
if (primitive_util::IsSignedIntegralType(to_type)) {
- return ir_builder_->CreateFPToSI(
+ return b_->CreateFPToSI(
operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
}
if (primitive_util::IsUnsignedIntegralType(to_type)) {
- return ir_builder_->CreateFPToUI(
+ return b_->CreateFPToUI(
operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
}
return Unimplemented("unhandled conversion operation: %s => %s",
@@ -405,7 +408,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
}
if (primitive_util::BitWidth(from_type) ==
primitive_util::BitWidth(to_type)) {
- return ir_builder_->CreateBitCast(
+ return b_->CreateBitCast(
operand_value, llvm_ir::PrimitiveTypeToIrType(to_type, module_));
}
return InvalidArgument(
@@ -429,45 +432,49 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatUnaryOp(
case HloOpcode::kSin:
return EmitSin(op->shape().element_type(), operand_value);
case HloOpcode::kFloor:
- return llvm_ir::EmitCallToIntrinsic(
- llvm::Intrinsic::floor, {operand_value}, {operand_value->getType()},
- ir_builder_);
+ return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::floor,
+ {operand_value},
+ {operand_value->getType()}, b_);
case HloOpcode::kCeil:
- return llvm_ir::EmitCallToIntrinsic(
- llvm::Intrinsic::ceil, {operand_value}, {operand_value->getType()},
- ir_builder_);
+ return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::ceil,
+ {operand_value},
+ {operand_value->getType()}, b_);
case HloOpcode::kAbs:
- return llvm_ir::EmitCallToIntrinsic(
- llvm::Intrinsic::fabs, {operand_value}, {operand_value->getType()},
- ir_builder_);
+ return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs,
+ {operand_value},
+ {operand_value->getType()}, b_);
case HloOpcode::kRoundNearestAfz:
- return llvm_ir::EmitCallToIntrinsic(
- llvm::Intrinsic::round, {operand_value}, {operand_value->getType()},
- ir_builder_);
+ return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::round,
+ {operand_value},
+ {operand_value->getType()}, b_);
case HloOpcode::kSign: {
// TODO(b/32151903): Ensure consistent sign behavior for -0.0.
auto type = operand_value->getType();
auto zero = llvm::ConstantFP::get(type, 0.0);
- auto oeq = ir_builder_->CreateFCmpOEQ(operand_value, zero);
- auto olt = ir_builder_->CreateFCmpOLT(operand_value, zero);
- return ir_builder_->CreateSelect(
+ auto oeq = b_->CreateFCmpOEQ(operand_value, zero);
+ auto olt = b_->CreateFCmpOLT(operand_value, zero);
+ return b_->CreateSelect(
oeq, zero,
- ir_builder_->CreateSelect(olt, llvm::ConstantFP::get(type, -1.0),
- llvm::ConstantFP::get(type, 1.0)));
+ b_->CreateSelect(olt, llvm::ConstantFP::get(type, -1.0),
+ llvm::ConstantFP::get(type, 1.0)));
}
case HloOpcode::kIsFinite: {
// abs(x) o!= inf, this works because the comparison returns false if
// either operand is NaN.
auto type = operand_value->getType();
auto abs_value = llvm_ir::EmitCallToIntrinsic(
- llvm::Intrinsic::fabs, {operand_value}, {type}, ir_builder_);
+ llvm::Intrinsic::fabs, {operand_value}, {type}, b_);
auto infinity = llvm::ConstantFP::getInfinity(type);
- auto not_infinite = ir_builder_->CreateFCmpONE(abs_value, infinity);
- return ir_builder_->CreateZExt(
- not_infinite, llvm_ir::PrimitiveTypeToIrType(PRED, module_));
+ auto not_infinite = b_->CreateFCmpONE(abs_value, infinity);
+ return b_->CreateZExt(not_infinite,
+ llvm_ir::PrimitiveTypeToIrType(PRED, module_));
}
case HloOpcode::kNegate:
- return ir_builder_->CreateFNeg(operand_value);
+ return b_->CreateFNeg(operand_value);
+ case HloOpcode::kReal:
+ return operand_value;
+ case HloOpcode::kImag:
+ return llvm::ConstantFP::get(operand_value->getType(), 0.0);
default:
return Unimplemented("unary floating-point op '%s'",
HloOpcodeString(op->opcode()).c_str());
@@ -487,13 +494,12 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
auto a = EmitExtractReal(operand_value);
auto b = EmitExtractImag(operand_value);
llvm::Type* llvm_ty = a->getType();
- auto sum_sq = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(a, a),
- ir_builder_->CreateFMul(b, b));
+ auto sum_sq = b_->CreateFAdd(b_->CreateFMul(a, a), b_->CreateFMul(b, b));
TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog(component_type, sum_sq));
TF_ASSIGN_OR_RETURN(auto angle, EmitAtan2(component_type, b, a));
auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5);
- return EmitComposeComplex(
- op, ir_builder_->CreateFMul(one_half, log_sum_sq), angle);
+ return EmitComposeComplex(op, b_->CreateFMul(one_half, log_sum_sq),
+ angle);
}
case HloOpcode::kLog1p: {
// log1p(a+bi) = .5*log((a+1)^2+b^2) + i*atan2(b, a + 1)
@@ -501,15 +507,14 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
auto b = EmitExtractImag(operand_value);
llvm::Type* llvm_ty = a->getType();
auto one = llvm::ConstantFP::get(llvm_ty, 1.0);
- auto a_plus_one = ir_builder_->CreateFAdd(a, one);
- auto sum_sq = ir_builder_->CreateFAdd(
- ir_builder_->CreateFMul(a_plus_one, a_plus_one),
- ir_builder_->CreateFMul(b, b));
+ auto a_plus_one = b_->CreateFAdd(a, one);
+ auto sum_sq = b_->CreateFAdd(b_->CreateFMul(a_plus_one, a_plus_one),
+ b_->CreateFMul(b, b));
TF_ASSIGN_OR_RETURN(auto log_sum_sq, EmitLog(component_type, sum_sq));
TF_ASSIGN_OR_RETURN(auto angle, EmitAtan2(component_type, b, a_plus_one));
auto one_half = llvm::ConstantFP::get(llvm_ty, 0.5);
- return EmitComposeComplex(
- op, ir_builder_->CreateFMul(one_half, log_sum_sq), angle);
+ return EmitComposeComplex(op, b_->CreateFMul(one_half, log_sum_sq),
+ angle);
}
case HloOpcode::kConvert: {
PrimitiveType from_type = op->operand(0)->shape().element_type();
@@ -523,12 +528,11 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
primitive_util::ComplexComponentType(to_type);
auto to_ir_component_type =
llvm_ir::PrimitiveTypeToIrType(to_component_type, module_);
- return EmitComposeComplex(
- op,
- ir_builder_->CreateFPCast(EmitExtractReal(operand_value),
- to_ir_component_type),
- ir_builder_->CreateFPCast(EmitExtractImag(operand_value),
- to_ir_component_type));
+ return EmitComposeComplex(op,
+ b_->CreateFPCast(EmitExtractReal(operand_value),
+ to_ir_component_type),
+ b_->CreateFPCast(EmitExtractImag(operand_value),
+ to_ir_component_type));
}
case HloOpcode::kExp: {
// e^(a+bi) = e^a*(cos(b)+sin(b)i)
@@ -538,8 +542,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
auto cos_b, EmitCos(component_type, EmitExtractImag(operand_value)));
TF_ASSIGN_OR_RETURN(
auto sin_b, EmitSin(component_type, EmitExtractImag(operand_value)));
- return EmitComposeComplex(op, ir_builder_->CreateFMul(exp_a, cos_b),
- ir_builder_->CreateFMul(exp_a, sin_b));
+ return EmitComposeComplex(op, b_->CreateFMul(exp_a, cos_b),
+ b_->CreateFMul(exp_a, sin_b));
}
case HloOpcode::kExpm1: {
// e^(a+bi)-1 = (e^a*cos(b)-1)+e^a*sin(b)i
@@ -550,9 +554,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
TF_ASSIGN_OR_RETURN(
auto sin_b, EmitSin(component_type, EmitExtractImag(operand_value)));
auto one = llvm::ConstantFP::get(exp_a->getType(), 1.0);
- auto real_result =
- ir_builder_->CreateFSub(ir_builder_->CreateFMul(exp_a, cos_b), one);
- auto imag_result = ir_builder_->CreateFMul(exp_a, sin_b);
+ auto real_result = b_->CreateFSub(b_->CreateFMul(exp_a, cos_b), one);
+ auto imag_result = b_->CreateFMul(exp_a, sin_b);
return EmitComposeComplex(op, real_result, imag_result);
}
case HloOpcode::kCos: {
@@ -567,18 +570,14 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
auto b = EmitExtractImag(operand_value);
auto type = a->getType();
TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b));
- auto half_exp_b =
- ir_builder_->CreateFMul(llvm::ConstantFP::get(type, 0.5), exp_b);
+ auto half_exp_b = b_->CreateFMul(llvm::ConstantFP::get(type, 0.5), exp_b);
auto half_exp_neg_b =
- ir_builder_->CreateFDiv(llvm::ConstantFP::get(type, 0.5), exp_b);
+ b_->CreateFDiv(llvm::ConstantFP::get(type, 0.5), exp_b);
TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a));
TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a));
return EmitComposeComplex(
- op,
- ir_builder_->CreateFMul(
- cos_a, ir_builder_->CreateFAdd(half_exp_neg_b, half_exp_b)),
- ir_builder_->CreateFMul(
- sin_a, ir_builder_->CreateFSub(half_exp_neg_b, half_exp_b)));
+ op, b_->CreateFMul(cos_a, b_->CreateFAdd(half_exp_neg_b, half_exp_b)),
+ b_->CreateFMul(sin_a, b_->CreateFSub(half_exp_neg_b, half_exp_b)));
}
case HloOpcode::kSin: {
// sin(z) = .5i(e^(-iz) - e^(iz))
@@ -594,18 +593,14 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
auto b = EmitExtractImag(operand_value);
auto type = a->getType();
TF_ASSIGN_OR_RETURN(auto exp_b, EmitExp(component_type, b));
- auto half_exp_b =
- ir_builder_->CreateFMul(llvm::ConstantFP::get(type, 0.5), exp_b);
+ auto half_exp_b = b_->CreateFMul(llvm::ConstantFP::get(type, 0.5), exp_b);
auto half_exp_neg_b =
- ir_builder_->CreateFDiv(llvm::ConstantFP::get(type, 0.5), exp_b);
+ b_->CreateFDiv(llvm::ConstantFP::get(type, 0.5), exp_b);
TF_ASSIGN_OR_RETURN(auto cos_a, EmitCos(component_type, a));
TF_ASSIGN_OR_RETURN(auto sin_a, EmitSin(component_type, a));
return EmitComposeComplex(
- op,
- ir_builder_->CreateFMul(
- sin_a, ir_builder_->CreateFAdd(half_exp_b, half_exp_neg_b)),
- ir_builder_->CreateFMul(
- cos_a, ir_builder_->CreateFSub(half_exp_b, half_exp_neg_b)));
+ op, b_->CreateFMul(sin_a, b_->CreateFAdd(half_exp_b, half_exp_neg_b)),
+ b_->CreateFMul(cos_a, b_->CreateFSub(half_exp_b, half_exp_neg_b)));
}
case HloOpcode::kTanh: {
/*
@@ -633,64 +628,61 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexUnaryOp(
TF_ASSIGN_OR_RETURN(auto exp_a, EmitExp(component_type, a));
TF_ASSIGN_OR_RETURN(auto cos_b, EmitCos(component_type, b));
TF_ASSIGN_OR_RETURN(auto sin_b, EmitSin(component_type, b));
- auto exp_neg_a = ir_builder_->CreateFDiv(
- llvm::ConstantFP::get(exp_a->getType(), 1), exp_a);
- auto exp_2a_minus_exp_neg_2a = ir_builder_->CreateFSub(
- ir_builder_->CreateFMul(exp_a, exp_a),
- ir_builder_->CreateFMul(exp_neg_a, exp_neg_a));
- auto cos_b_sq = ir_builder_->CreateFMul(cos_b, cos_b);
- auto sin_b_sq = ir_builder_->CreateFMul(sin_b, sin_b);
- auto real_num = ir_builder_->CreateFAdd(
- ir_builder_->CreateFMul(cos_b_sq, exp_2a_minus_exp_neg_2a),
- ir_builder_->CreateFMul(sin_b_sq, exp_2a_minus_exp_neg_2a));
- auto cos_b_sin_b = ir_builder_->CreateFMul(cos_b, sin_b);
- auto exp_a_plus_exp_neg_a = ir_builder_->CreateFAdd(exp_a, exp_neg_a);
+ auto exp_neg_a =
+ b_->CreateFDiv(llvm::ConstantFP::get(exp_a->getType(), 1), exp_a);
+ auto exp_2a_minus_exp_neg_2a = b_->CreateFSub(
+ b_->CreateFMul(exp_a, exp_a), b_->CreateFMul(exp_neg_a, exp_neg_a));
+ auto cos_b_sq = b_->CreateFMul(cos_b, cos_b);
+ auto sin_b_sq = b_->CreateFMul(sin_b, sin_b);
+ auto real_num =
+ b_->CreateFAdd(b_->CreateFMul(cos_b_sq, exp_2a_minus_exp_neg_2a),
+ b_->CreateFMul(sin_b_sq, exp_2a_minus_exp_neg_2a));
+ auto cos_b_sin_b = b_->CreateFMul(cos_b, sin_b);
+ auto exp_a_plus_exp_neg_a = b_->CreateFAdd(exp_a, exp_neg_a);
auto exp_a_plus_exp_neg_a_sq =
- ir_builder_->CreateFMul(exp_a_plus_exp_neg_a, exp_a_plus_exp_neg_a);
- auto exp_a_minus_exp_neg_a = ir_builder_->CreateFSub(exp_a, exp_neg_a);
+ b_->CreateFMul(exp_a_plus_exp_neg_a, exp_a_plus_exp_neg_a);
+ auto exp_a_minus_exp_neg_a = b_->CreateFSub(exp_a, exp_neg_a);
auto exp_a_minus_exp_neg_a_sq =
- ir_builder_->CreateFMul(exp_a_minus_exp_neg_a, exp_a_minus_exp_neg_a);
- auto imag_num = ir_builder_->CreateFMul(
- cos_b_sin_b, ir_builder_->CreateFSub(exp_a_plus_exp_neg_a_sq,
- exp_a_minus_exp_neg_a_sq));
- auto denom = ir_builder_->CreateFAdd(
- ir_builder_->CreateFMul(cos_b_sq, exp_a_plus_exp_neg_a_sq),
- ir_builder_->CreateFMul(sin_b_sq, exp_a_minus_exp_neg_a_sq));
- return EmitComposeComplex(op, ir_builder_->CreateFDiv(real_num, denom),
- ir_builder_->CreateFDiv(imag_num, denom));
+ b_->CreateFMul(exp_a_minus_exp_neg_a, exp_a_minus_exp_neg_a);
+ auto imag_num = b_->CreateFMul(
+ cos_b_sin_b,
+ b_->CreateFSub(exp_a_plus_exp_neg_a_sq, exp_a_minus_exp_neg_a_sq));
+ auto denom =
+ b_->CreateFAdd(b_->CreateFMul(cos_b_sq, exp_a_plus_exp_neg_a_sq),
+ b_->CreateFMul(sin_b_sq, exp_a_minus_exp_neg_a_sq));
+ return EmitComposeComplex(op, b_->CreateFDiv(real_num, denom),
+ b_->CreateFDiv(imag_num, denom));
}
case HloOpcode::kAbs: {
- auto sum_sq = ir_builder_->CreateFAdd(
- ir_builder_->CreateFMul(EmitExtractReal(operand_value),
- EmitExtractReal(operand_value)),
- ir_builder_->CreateFMul(EmitExtractImag(operand_value),
- EmitExtractImag(operand_value)));
+ auto sum_sq =
+ b_->CreateFAdd(b_->CreateFMul(EmitExtractReal(operand_value),
+ EmitExtractReal(operand_value)),
+ b_->CreateFMul(EmitExtractImag(operand_value),
+ EmitExtractImag(operand_value)));
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sqrt, {sum_sq},
- {sum_sq->getType()}, ir_builder_);
+ {sum_sq->getType()}, b_);
}
case HloOpcode::kSign: { // Sign(c) = c / |c|
- auto sum_sq = ir_builder_->CreateFAdd(
- ir_builder_->CreateFMul(EmitExtractReal(operand_value),
- EmitExtractReal(operand_value)),
- ir_builder_->CreateFMul(EmitExtractImag(operand_value),
- EmitExtractImag(operand_value)));
+ auto sum_sq =
+ b_->CreateFAdd(b_->CreateFMul(EmitExtractReal(operand_value),
+ EmitExtractReal(operand_value)),
+ b_->CreateFMul(EmitExtractImag(operand_value),
+ EmitExtractImag(operand_value)));
auto cplx_abs = llvm_ir::EmitCallToIntrinsic(
- llvm::Intrinsic::sqrt, {sum_sq}, {sum_sq->getType()}, ir_builder_);
+ llvm::Intrinsic::sqrt, {sum_sq}, {sum_sq->getType()}, b_);
auto type = cplx_abs->getType();
auto zero = llvm::ConstantFP::get(type, 0.0);
- auto oeq = ir_builder_->CreateFCmpOEQ(cplx_abs, zero);
- return ir_builder_->CreateSelect(
+ auto oeq = b_->CreateFCmpOEQ(cplx_abs, zero);
+ return b_->CreateSelect(
oeq, EmitComposeComplex(op, zero, zero),
EmitComposeComplex(
- op,
- ir_builder_->CreateFDiv(EmitExtractReal(operand_value), cplx_abs),
- ir_builder_->CreateFDiv(EmitExtractImag(operand_value),
- cplx_abs)));
+ op, b_->CreateFDiv(EmitExtractReal(operand_value), cplx_abs),
+ b_->CreateFDiv(EmitExtractImag(operand_value), cplx_abs)));
}
case HloOpcode::kNegate:
- return EmitComposeComplex(
- op, ir_builder_->CreateFNeg(EmitExtractReal(operand_value)),
- ir_builder_->CreateFNeg(EmitExtractImag(operand_value)));
+ return EmitComposeComplex(op,
+ b_->CreateFNeg(EmitExtractReal(operand_value)),
+ b_->CreateFNeg(EmitExtractImag(operand_value)));
case HloOpcode::kReal:
return EmitExtractReal(operand_value);
case HloOpcode::kImag:
@@ -724,15 +716,15 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp(
case HloOpcode::kComplex:
return EmitComposeComplex(op, lhs_value, rhs_value);
case HloOpcode::kAdd:
- return ir_builder_->CreateFAdd(lhs_value, rhs_value);
+ return b_->CreateFAdd(lhs_value, rhs_value);
case HloOpcode::kSubtract:
- return ir_builder_->CreateFSub(lhs_value, rhs_value);
+ return b_->CreateFSub(lhs_value, rhs_value);
case HloOpcode::kMultiply:
- return ir_builder_->CreateFMul(lhs_value, rhs_value);
+ return b_->CreateFMul(lhs_value, rhs_value);
case HloOpcode::kDivide:
- return ir_builder_->CreateFDiv(lhs_value, rhs_value);
+ return b_->CreateFDiv(lhs_value, rhs_value);
case HloOpcode::kRemainder:
- return ir_builder_->CreateFRem(lhs_value, rhs_value);
+ return b_->CreateFRem(lhs_value, rhs_value);
// LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered
// comparisons always return false when one of the operands is NaN, whereas
// unordered comparisons return true.
@@ -742,22 +734,22 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitFloatBinaryOp(
// matches C++'s semantics.
case HloOpcode::kEq:
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ, lhs_value,
- rhs_value, ir_builder_);
+ rhs_value, b_);
case HloOpcode::kNe:
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE, lhs_value,
- rhs_value, ir_builder_);
+ rhs_value, b_);
case HloOpcode::kLt:
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLT, lhs_value,
- rhs_value, ir_builder_);
+ rhs_value, b_);
case HloOpcode::kGt:
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGT, lhs_value,
- rhs_value, ir_builder_);
+ rhs_value, b_);
case HloOpcode::kLe:
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OLE, lhs_value,
- rhs_value, ir_builder_);
+ rhs_value, b_);
case HloOpcode::kGe:
return llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OGE, lhs_value,
- rhs_value, ir_builder_);
+ rhs_value, b_);
case HloOpcode::kMaximum:
return EmitFloatMax(lhs_value, rhs_value);
@@ -778,64 +770,56 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexBinaryOp(
llvm::Value* rhs_value) const {
switch (op->opcode()) {
case HloOpcode::kAdd:
- return EmitComposeComplex(
- op,
- ir_builder_->CreateFAdd(EmitExtractReal(lhs_value),
- EmitExtractReal(rhs_value)),
- ir_builder_->CreateFAdd(EmitExtractImag(lhs_value),
- EmitExtractImag(rhs_value)));
+ return EmitComposeComplex(op,
+ b_->CreateFAdd(EmitExtractReal(lhs_value),
+ EmitExtractReal(rhs_value)),
+ b_->CreateFAdd(EmitExtractImag(lhs_value),
+ EmitExtractImag(rhs_value)));
case HloOpcode::kSubtract:
- return EmitComposeComplex(
- op,
- ir_builder_->CreateFSub(EmitExtractReal(lhs_value),
- EmitExtractReal(rhs_value)),
- ir_builder_->CreateFSub(EmitExtractImag(lhs_value),
- EmitExtractImag(rhs_value)));
+ return EmitComposeComplex(op,
+ b_->CreateFSub(EmitExtractReal(lhs_value),
+ EmitExtractReal(rhs_value)),
+ b_->CreateFSub(EmitExtractImag(lhs_value),
+ EmitExtractImag(rhs_value)));
case HloOpcode::kMultiply:
return EmitComposeComplex(
op,
- ir_builder_->CreateFSub(
- ir_builder_->CreateFMul(EmitExtractReal(lhs_value),
- EmitExtractReal(rhs_value)),
- ir_builder_->CreateFMul(EmitExtractImag(lhs_value),
- EmitExtractImag(rhs_value))),
- ir_builder_->CreateFAdd(
- ir_builder_->CreateFMul(EmitExtractReal(lhs_value),
- EmitExtractImag(rhs_value)),
- ir_builder_->CreateFMul(EmitExtractImag(lhs_value),
- EmitExtractReal(rhs_value))));
+ b_->CreateFSub(b_->CreateFMul(EmitExtractReal(lhs_value),
+ EmitExtractReal(rhs_value)),
+ b_->CreateFMul(EmitExtractImag(lhs_value),
+ EmitExtractImag(rhs_value))),
+ b_->CreateFAdd(b_->CreateFMul(EmitExtractReal(lhs_value),
+ EmitExtractImag(rhs_value)),
+ b_->CreateFMul(EmitExtractImag(lhs_value),
+ EmitExtractReal(rhs_value))));
case HloOpcode::kDivide: {
// (a+bi) / (c+di) = ((a+bi)(c-di)) / ((c+di)(c-di))
// = ((ac + bd) + (bc - ad)i) / (c^2 + d^2)
- auto rhs_sum_sq = ir_builder_->CreateFAdd(
- ir_builder_->CreateFMul(EmitExtractReal(rhs_value),
- EmitExtractReal(rhs_value)),
- ir_builder_->CreateFMul(EmitExtractImag(rhs_value),
- EmitExtractImag(rhs_value)));
+ auto rhs_sum_sq =
+ b_->CreateFAdd(b_->CreateFMul(EmitExtractReal(rhs_value),
+ EmitExtractReal(rhs_value)),
+ b_->CreateFMul(EmitExtractImag(rhs_value),
+ EmitExtractImag(rhs_value)));
auto type = rhs_sum_sq->getType();
auto zero = llvm::ConstantFP::get(type, 0.0);
- auto oeq = ir_builder_->CreateFCmpOEQ(rhs_sum_sq, zero);
- auto real_inf_or_nan =
- ir_builder_->CreateFDiv(EmitExtractReal(lhs_value), zero);
- auto imag_inf_or_nan =
- ir_builder_->CreateFDiv(EmitExtractImag(lhs_value), zero);
- return ir_builder_->CreateSelect(
+ auto oeq = b_->CreateFCmpOEQ(rhs_sum_sq, zero);
+ auto real_inf_or_nan = b_->CreateFDiv(EmitExtractReal(lhs_value), zero);
+ auto imag_inf_or_nan = b_->CreateFDiv(EmitExtractImag(lhs_value), zero);
+ return b_->CreateSelect(
oeq, EmitComposeComplex(op, real_inf_or_nan, imag_inf_or_nan),
EmitComposeComplex(
op,
- ir_builder_->CreateFDiv(
- ir_builder_->CreateFAdd(
- ir_builder_->CreateFMul(EmitExtractReal(lhs_value),
- EmitExtractReal(rhs_value)),
- ir_builder_->CreateFMul(EmitExtractImag(lhs_value),
- EmitExtractImag(rhs_value))),
+ b_->CreateFDiv(
+ b_->CreateFAdd(b_->CreateFMul(EmitExtractReal(lhs_value),
+ EmitExtractReal(rhs_value)),
+ b_->CreateFMul(EmitExtractImag(lhs_value),
+ EmitExtractImag(rhs_value))),
rhs_sum_sq),
- ir_builder_->CreateFDiv(
- ir_builder_->CreateFSub(
- ir_builder_->CreateFMul(EmitExtractImag(lhs_value),
- EmitExtractReal(rhs_value)),
- ir_builder_->CreateFMul(EmitExtractReal(lhs_value),
- EmitExtractImag(rhs_value))),
+ b_->CreateFDiv(
+ b_->CreateFSub(b_->CreateFMul(EmitExtractImag(lhs_value),
+ EmitExtractReal(rhs_value)),
+ b_->CreateFMul(EmitExtractReal(lhs_value),
+ EmitExtractImag(rhs_value))),
rhs_sum_sq)));
}
// LLVM comparisons can be "unordered" (U) or "ordered" (O) -- ordered
@@ -846,21 +830,21 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexBinaryOp(
// unordered comparison. This makes x != y equivalent to !(x == y), and
// matches C++'s semantics.
case HloOpcode::kEq:
- return ir_builder_->CreateAnd(
+ return b_->CreateAnd(
llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ,
EmitExtractReal(lhs_value),
- EmitExtractReal(rhs_value), ir_builder_),
+ EmitExtractReal(rhs_value), b_),
llvm_ir::EmitComparison(llvm::CmpInst::FCMP_OEQ,
EmitExtractImag(lhs_value),
- EmitExtractImag(rhs_value), ir_builder_));
+ EmitExtractImag(rhs_value), b_));
case HloOpcode::kNe:
- return ir_builder_->CreateOr(
+ return b_->CreateOr(
llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE,
EmitExtractReal(lhs_value),
- EmitExtractReal(rhs_value), ir_builder_),
+ EmitExtractReal(rhs_value), b_),
llvm_ir::EmitComparison(llvm::CmpInst::FCMP_UNE,
EmitExtractImag(lhs_value),
- EmitExtractImag(rhs_value), ir_builder_));
+ EmitExtractImag(rhs_value), b_));
case HloOpcode::kPower: {
// (a+bi)^(c+di) =
@@ -872,29 +856,26 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexBinaryOp(
auto b = EmitExtractImag(lhs_value);
auto c = EmitExtractReal(rhs_value);
auto d = EmitExtractImag(rhs_value);
- auto aa_p_bb = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(a, a),
- ir_builder_->CreateFMul(b, b));
+ auto aa_p_bb = b_->CreateFAdd(b_->CreateFMul(a, a), b_->CreateFMul(b, b));
auto one_half = llvm::ConstantFP::get(a->getType(), 0.5);
- auto half_c = ir_builder_->CreateFMul(one_half, c);
+ auto half_c = b_->CreateFMul(one_half, c);
TF_ASSIGN_OR_RETURN(auto aa_p_bb_to_half_c,
EmitPow(component_type, aa_p_bb, half_c));
- auto neg_d = ir_builder_->CreateFNeg(d);
+ auto neg_d = b_->CreateFNeg(d);
TF_ASSIGN_OR_RETURN(auto arg_lhs, EmitAtan2(component_type, b, a));
- auto neg_d_arg_lhs = ir_builder_->CreateFMul(neg_d, arg_lhs);
+ auto neg_d_arg_lhs = b_->CreateFMul(neg_d, arg_lhs);
TF_ASSIGN_OR_RETURN(auto e_to_neg_d_arg_lhs,
EmitExp(component_type, neg_d_arg_lhs));
- auto coeff =
- ir_builder_->CreateFMul(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs);
+ auto coeff = b_->CreateFMul(aa_p_bb_to_half_c, e_to_neg_d_arg_lhs);
TF_ASSIGN_OR_RETURN(auto ln_aa_p_bb, EmitLog(component_type, aa_p_bb));
- auto half_d = ir_builder_->CreateFMul(one_half, d);
- auto q =
- ir_builder_->CreateFAdd(ir_builder_->CreateFMul(c, arg_lhs),
- ir_builder_->CreateFMul(half_d, ln_aa_p_bb));
+ auto half_d = b_->CreateFMul(one_half, d);
+ auto q = b_->CreateFAdd(b_->CreateFMul(c, arg_lhs),
+ b_->CreateFMul(half_d, ln_aa_p_bb));
TF_ASSIGN_OR_RETURN(auto cos_q, EmitCos(component_type, q));
TF_ASSIGN_OR_RETURN(auto sin_q, EmitSin(component_type, q));
- return EmitComposeComplex(op, ir_builder_->CreateFMul(coeff, cos_q),
- ir_builder_->CreateFMul(coeff, sin_q));
+ return EmitComposeComplex(op, b_->CreateFMul(coeff, cos_q),
+ b_->CreateFMul(coeff, sin_q));
}
default:
return Unimplemented("binary complex op '%s'",
@@ -904,12 +885,12 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitComplexBinaryOp(
llvm::Value* ElementalIrEmitter::EmitFloatMax(llvm::Value* lhs_value,
llvm::Value* rhs_value) const {
- return llvm_ir::EmitFloatMax(lhs_value, rhs_value, ir_builder_);
+ return llvm_ir::EmitFloatMax(lhs_value, rhs_value, b_);
}
llvm::Value* ElementalIrEmitter::EmitFloatMin(llvm::Value* lhs_value,
llvm::Value* rhs_value) const {
- return llvm_ir::EmitFloatMin(lhs_value, rhs_value, ir_builder_);
+ return llvm_ir::EmitFloatMin(lhs_value, rhs_value, b_);
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type,
@@ -921,15 +902,14 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type,
"type F32.");
}
auto getFloat = [&](const float f) {
- return llvm::ConstantFP::get(ir_builder_->getFloatTy(), f);
+ return llvm::ConstantFP::get(b_->getFloatTy(), f);
};
auto multiply_add = [&](tensorflow::gtl::ArraySlice<float> coefficients,
llvm::Value* w) {
llvm::Value* p = getFloat(coefficients.front());
coefficients.pop_front();
for (float coefficient : coefficients) {
- p = ir_builder_->CreateFAdd(ir_builder_->CreateFMul(p, w),
- getFloat(coefficient));
+ p = b_->CreateFAdd(b_->CreateFMul(p, w), getFloat(coefficient));
}
return p;
};
@@ -947,50 +927,48 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfInv(PrimitiveType prim_type,
// }
// return p*x
llvm::Function* logf_fn = llvm::Intrinsic::getDeclaration(
- module_, llvm::Intrinsic::log, {ir_builder_->getFloatTy()});
+ module_, llvm::Intrinsic::log, {b_->getFloatTy()});
- llvm::Value* w = ir_builder_->CreateFNeg(ir_builder_->CreateCall(
- logf_fn,
- {ir_builder_->CreateFMul(ir_builder_->CreateFSub(getFloat(1.0f), x),
- ir_builder_->CreateFAdd(getFloat(1.0f), x))}));
+ llvm::Value* w = b_->CreateFNeg(b_->CreateCall(
+ logf_fn, {b_->CreateFMul(b_->CreateFSub(getFloat(1.0f), x),
+ b_->CreateFAdd(getFloat(1.0f), x))}));
- llvm::Value* p_addr = llvm_ir::EmitAllocaAtFunctionEntry(
- ir_builder_->getFloatTy(), "p.addr", ir_builder_);
+ llvm::Value* p_addr =
+ llvm_ir::EmitAllocaAtFunctionEntry(b_->getFloatTy(), "p.addr", b_);
- llvm_ir::LlvmIfData if_data =
- llvm_ir::EmitIfThenElse(ir_builder_->CreateFCmpOLT(w, getFloat(5.0f)),
- "w_less_than_five", ir_builder_);
+ llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
+ b_->CreateFCmpOLT(w, getFloat(5.0f)), "w_less_than_five", b_);
// Handle true BB.
- SetToFirstInsertPoint(if_data.true_block, ir_builder_);
+ SetToFirstInsertPoint(if_data.true_block, b_);
{
- llvm::Value* lw = ir_builder_->CreateFSub(w, getFloat(2.5f));
+ llvm::Value* lw = b_->CreateFSub(w, getFloat(2.5f));
tensorflow::gtl::ArraySlice<float> lq{
2.81022636e-08f, 3.43273939e-07f, -3.5233877e-06f,
-4.39150654e-06f, 0.00021858087f, -0.00125372503f,
-0.00417768164f, 0.246640727f, 1.50140941f};
llvm::Value* p = multiply_add(lq, lw);
- ir_builder_->CreateStore(p, p_addr);
+ b_->CreateStore(p, p_addr);
}
// Handle false BB.
- SetToFirstInsertPoint(if_data.false_block, ir_builder_);
+ SetToFirstInsertPoint(if_data.false_block, b_);
{
llvm::Function* sqrtf_fn = llvm::Intrinsic::getDeclaration(
- module_, llvm::Intrinsic::sqrt, {ir_builder_->getFloatTy()});
+ module_, llvm::Intrinsic::sqrt, {b_->getFloatTy()});
- llvm::Value* gw = ir_builder_->CreateFSub(
- ir_builder_->CreateCall(sqrtf_fn, {w}), getFloat(3.0f));
+ llvm::Value* gw =
+ b_->CreateFSub(b_->CreateCall(sqrtf_fn, {w}), getFloat(3.0f));
tensorflow::gtl::ArraySlice<float> gq{
-0.000200214257f, 0.000100950558f, 0.00134934322f,
-0.00367342844f, 0.00573950773f, -0.0076224613f,
0.00943887047f, 1.00167406f, 2.83297682f};
llvm::Value* p = multiply_add(gq, gw);
- ir_builder_->CreateStore(p, p_addr);
+ b_->CreateStore(p, p_addr);
}
- SetToFirstInsertPoint(if_data.after_block, ir_builder_);
- llvm::Value* p = ir_builder_->CreateLoad(p_addr);
- return ir_builder_->CreateFMul(p, x);
+ SetToFirstInsertPoint(if_data.after_block, b_);
+ llvm::Value* p = b_->CreateLoad(p_addr);
+ return b_->CreateFMul(p, x);
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfcInv(
@@ -998,13 +976,13 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitErfcInv(
// Compute erfcinv(value) by calculating erfinv(1.0 - value).
auto type = llvm_ir::PrimitiveTypeToIrType(prim_type, module_);
auto one = llvm::ConstantFP::get(type, 1.0);
- return EmitErfInv(prim_type, ir_builder_->CreateFSub(one, value));
+ return EmitErfInv(prim_type, b_->CreateFSub(one, value));
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitLog(PrimitiveType prim_type,
llvm::Value* value) const {
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::log, {value},
- {value->getType()}, ir_builder_);
+ {value->getType()}, b_);
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitLog1p(PrimitiveType prim_type,
@@ -1016,35 +994,34 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitLog1p(PrimitiveType prim_type,
// When x is large, the naive evaluation of ln(x + 1) is more
// accurate than the Taylor series.
TF_ASSIGN_OR_RETURN(auto for_large_x,
- EmitLog(prim_type, ir_builder_->CreateFAdd(x, one)));
+ EmitLog(prim_type, b_->CreateFAdd(x, one)));
// The Taylor series for ln(x+1) is x - x^2/2 - x^3/3 + ….
- auto for_small_x = ir_builder_->CreateFMul(
- ir_builder_->CreateFAdd(ir_builder_->CreateFMul(negative_half, x), one),
- x);
+ auto for_small_x =
+ b_->CreateFMul(b_->CreateFAdd(b_->CreateFMul(negative_half, x), one), x);
const auto kAntilogarithmIsSmallThreshold = 1e-4;
- auto abs_x = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value},
- {type}, ir_builder_);
- auto x_is_small = ir_builder_->CreateFCmpOLT(
+ auto abs_x =
+ llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, {type}, b_);
+ auto x_is_small = b_->CreateFCmpOLT(
abs_x, llvm::ConstantFP::get(type, kAntilogarithmIsSmallThreshold));
- return ir_builder_->CreateSelect(x_is_small, for_small_x, for_large_x);
+ return b_->CreateSelect(x_is_small, for_small_x, for_large_x);
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitSin(PrimitiveType prim_type,
llvm::Value* value) const {
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::sin, {value},
- {value->getType()}, ir_builder_);
+ {value->getType()}, b_);
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitCos(PrimitiveType prim_type,
llvm::Value* value) const {
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::cos, {value},
- {value->getType()}, ir_builder_);
+ {value->getType()}, b_);
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitExp(PrimitiveType prim_type,
llvm::Value* value) const {
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::exp, {value},
- {value->getType()}, ir_builder_);
+ {value->getType()}, b_);
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitExpm1(PrimitiveType prim_type,
@@ -1056,25 +1033,25 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitExpm1(PrimitiveType prim_type,
// When the exponent is large, the naive evaluation of e^(x) - 1 is more
// accurate than the Taylor series.
TF_ASSIGN_OR_RETURN(auto exp_x, EmitExp(prim_type, value));
- auto for_large_x = ir_builder_->CreateFSub(exp_x, one);
+ auto for_large_x = b_->CreateFSub(exp_x, one);
// The Taylor series for exp(x) is 1 + x + x^2/2 + x^3/6 + ….
// We want exp(x)-1 which is x + x^2/2 + x^3/6 + ….
- auto x_squared = ir_builder_->CreateFAdd(x, x);
- auto x_squared_over_two = ir_builder_->CreateFMul(x_squared, half);
- auto for_small_x = ir_builder_->CreateFAdd(x, x_squared_over_two);
+ auto x_squared = b_->CreateFAdd(x, x);
+ auto x_squared_over_two = b_->CreateFMul(x_squared, half);
+ auto for_small_x = b_->CreateFAdd(x, x_squared_over_two);
const auto kExponentIsSmallThreshold = 1e-5;
- auto abs_x = llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value},
- {type}, ir_builder_);
- auto x_is_small = ir_builder_->CreateFCmpOLT(
+ auto abs_x =
+ llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::fabs, {value}, {type}, b_);
+ auto x_is_small = b_->CreateFCmpOLT(
abs_x, llvm::ConstantFP::get(type, kExponentIsSmallThreshold));
- return ir_builder_->CreateSelect(x_is_small, for_small_x, for_large_x);
+ return b_->CreateSelect(x_is_small, for_small_x, for_large_x);
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitPow(PrimitiveType prim_type,
llvm::Value* lhs,
llvm::Value* rhs) const {
return llvm_ir::EmitCallToIntrinsic(llvm::Intrinsic::pow, {lhs, rhs},
- {lhs->getType()}, ir_builder_);
+ {lhs->getType()}, b_);
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitAtan2(PrimitiveType prim_type,
@@ -1089,11 +1066,10 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitReducePrecision(
return Unimplemented("reduce-precision only implemented for F32");
}
return EmitReducePrecisionFloat(x, /*exponent_bits=*/hlo->exponent_bits(),
- /*mantissa_bits=*/hlo->mantissa_bits(),
- ir_builder_);
+ /*mantissa_bits=*/hlo->mantissa_bits(), b_);
}
-static llvm::Value* SaturateShiftIfNecessary(llvm::IRBuilder<>* ir_builder,
+static llvm::Value* SaturateShiftIfNecessary(llvm::IRBuilder<>* b,
llvm::Value* lhs, llvm::Value* rhs,
llvm::Value* shift_result,
bool saturate_to_sign_bit) {
@@ -1106,15 +1082,14 @@ static llvm::Value* SaturateShiftIfNecessary(llvm::IRBuilder<>* ir_builder,
llvm::ConstantInt* minus_one = llvm::ConstantInt::get(integer_type, -1);
llvm::Value* saturated_value;
if (saturate_to_sign_bit) {
- saturated_value = ir_builder->CreateSelect(
- ir_builder->CreateICmpSLT(lhs, zero), minus_one, zero);
+ saturated_value =
+ b->CreateSelect(b->CreateICmpSLT(lhs, zero), minus_one, zero);
} else {
saturated_value = zero;
}
llvm::Value* shift_amt_in_range =
- ir_builder->CreateICmpULT(rhs, integer_bitsize_constant, "shft.chk");
- return ir_builder->CreateSelect(shift_amt_in_range, shift_result,
- saturated_value);
+ b->CreateICmpULT(rhs, integer_bitsize_constant, "shft.chk");
+ return b->CreateSelect(shift_amt_in_range, shift_result, saturated_value);
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp(
@@ -1123,49 +1098,49 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp(
switch (op->opcode()) {
// TODO(jingyue): add the "nsw" attribute for signed types.
case HloOpcode::kAdd:
- return ir_builder_->CreateAdd(lhs_value, rhs_value);
+ return b_->CreateAdd(lhs_value, rhs_value);
case HloOpcode::kSubtract:
- return ir_builder_->CreateSub(lhs_value, rhs_value);
+ return b_->CreateSub(lhs_value, rhs_value);
case HloOpcode::kMultiply:
- return ir_builder_->CreateMul(lhs_value, rhs_value);
+ return b_->CreateMul(lhs_value, rhs_value);
case HloOpcode::kDivide:
- return is_signed ? ir_builder_->CreateSDiv(lhs_value, rhs_value)
- : ir_builder_->CreateUDiv(lhs_value, rhs_value);
+ return is_signed ? b_->CreateSDiv(lhs_value, rhs_value)
+ : b_->CreateUDiv(lhs_value, rhs_value);
case HloOpcode::kRemainder:
- return is_signed ? ir_builder_->CreateSRem(lhs_value, rhs_value)
- : ir_builder_->CreateURem(lhs_value, rhs_value);
+ return is_signed ? b_->CreateSRem(lhs_value, rhs_value)
+ : b_->CreateURem(lhs_value, rhs_value);
case HloOpcode::kEq:
return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_EQ, lhs_value,
- rhs_value, ir_builder_);
+ rhs_value, b_);
case HloOpcode::kNe:
return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_NE, lhs_value,
- rhs_value, ir_builder_);
+ rhs_value, b_);
case HloOpcode::kLt:
return llvm_ir::EmitComparison(
is_signed ? llvm::CmpInst::ICMP_SLT : llvm::CmpInst::ICMP_ULT,
- lhs_value, rhs_value, ir_builder_);
+ lhs_value, rhs_value, b_);
case HloOpcode::kGt:
return llvm_ir::EmitComparison(
is_signed ? llvm::CmpInst::ICMP_SGT : llvm::CmpInst::ICMP_UGT,
- lhs_value, rhs_value, ir_builder_);
+ lhs_value, rhs_value, b_);
case HloOpcode::kLe:
return llvm_ir::EmitComparison(
is_signed ? llvm::CmpInst::ICMP_SLE : llvm::CmpInst::ICMP_ULE,
- lhs_value, rhs_value, ir_builder_);
+ lhs_value, rhs_value, b_);
case HloOpcode::kGe:
return llvm_ir::EmitComparison(
is_signed ? llvm::CmpInst::ICMP_SGE : llvm::CmpInst::ICMP_UGE,
- lhs_value, rhs_value, ir_builder_);
+ lhs_value, rhs_value, b_);
case HloOpcode::kMinimum:
return EmitIntegralMin(lhs_value, rhs_value, is_signed);
case HloOpcode::kMaximum:
return EmitIntegralMax(lhs_value, rhs_value, is_signed);
case HloOpcode::kAnd:
- return ir_builder_->CreateAnd(lhs_value, rhs_value);
+ return b_->CreateAnd(lhs_value, rhs_value);
case HloOpcode::kOr:
- return ir_builder_->CreateOr(lhs_value, rhs_value);
+ return b_->CreateOr(lhs_value, rhs_value);
case HloOpcode::kXor:
- return ir_builder_->CreateXor(lhs_value, rhs_value);
+ return b_->CreateXor(lhs_value, rhs_value);
// Shifting out bits >= the number of bits in the type being shifted
// produces a poison value in LLVM which is basically "deferred undefined
@@ -1173,20 +1148,17 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp(
// UB. We replace the poison value with a constant to avoid this deferred
// UB.
case HloOpcode::kShiftRightArithmetic:
- return SaturateShiftIfNecessary(
- ir_builder_, lhs_value, rhs_value,
- ir_builder_->CreateAShr(lhs_value, rhs_value),
- /*saturate_to_sign_bit=*/true);
+ return SaturateShiftIfNecessary(b_, lhs_value, rhs_value,
+ b_->CreateAShr(lhs_value, rhs_value),
+ /*saturate_to_sign_bit=*/true);
case HloOpcode::kShiftLeft:
- return SaturateShiftIfNecessary(
- ir_builder_, lhs_value, rhs_value,
- ir_builder_->CreateShl(lhs_value, rhs_value),
- /*saturate_to_sign_bit=*/false);
+ return SaturateShiftIfNecessary(b_, lhs_value, rhs_value,
+ b_->CreateShl(lhs_value, rhs_value),
+ /*saturate_to_sign_bit=*/false);
case HloOpcode::kShiftRightLogical:
- return SaturateShiftIfNecessary(
- ir_builder_, lhs_value, rhs_value,
- ir_builder_->CreateLShr(lhs_value, rhs_value),
- /*saturate_to_sign_bit=*/false);
+ return SaturateShiftIfNecessary(b_, lhs_value, rhs_value,
+ b_->CreateLShr(lhs_value, rhs_value),
+ /*saturate_to_sign_bit=*/false);
default:
return Unimplemented("binary integer op '%s'",
HloOpcodeString(op->opcode()).c_str());
@@ -1196,21 +1168,19 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp(
llvm::Value* ElementalIrEmitter::EmitIntegralMax(llvm::Value* lhs_value,
llvm::Value* rhs_value,
bool is_signed) const {
- return ir_builder_->CreateSelect(
- ir_builder_->CreateICmp(
- is_signed ? llvm::ICmpInst::ICMP_SGE : llvm::ICmpInst::ICMP_UGE,
- lhs_value, rhs_value),
- lhs_value, rhs_value);
+ return b_->CreateSelect(b_->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SGE
+ : llvm::ICmpInst::ICMP_UGE,
+ lhs_value, rhs_value),
+ lhs_value, rhs_value);
}
llvm::Value* ElementalIrEmitter::EmitIntegralMin(llvm::Value* lhs_value,
llvm::Value* rhs_value,
bool is_signed) const {
- return ir_builder_->CreateSelect(
- ir_builder_->CreateICmp(
- is_signed ? llvm::ICmpInst::ICMP_SLE : llvm::ICmpInst::ICMP_ULE,
- lhs_value, rhs_value),
- lhs_value, rhs_value);
+ return b_->CreateSelect(b_->CreateICmp(is_signed ? llvm::ICmpInst::ICMP_SLE
+ : llvm::ICmpInst::ICMP_ULE,
+ lhs_value, rhs_value),
+ lhs_value, rhs_value);
}
llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex(
@@ -1227,7 +1197,14 @@ llvm_ir::IrArray::Index ElementalIrEmitter::ElementwiseSourceIndex(
// If no implicit broadcast is needed for this operand, returns the target
// index as the source index.
- if (ShapeUtil::CompatibleIgnoringElementType(operand_shape, hlo.shape())) {
+ //
+ // `IrArray::Index` may contain a physical linear which we can propagate to
+ // our operand only if our layouts match. "only if" is a bit strong since
+ // e.g. we can still forward the linear index if the operand shape is
+ // [5,1,1,5]{3,2,1,0} and the HLO shape is[5,1,1,5]{3,1,2,0}, but those cases
+ // are probably not worth handling here for now.
+ if (ShapeUtil::CompatibleIgnoringElementType(operand_shape, hlo.shape()) &&
+ LayoutUtil::Equal(operand_shape.layout(), hlo.shape().layout())) {
return target_index;
}
@@ -1256,10 +1233,10 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeRngElementGenerator(
// Same values as PCG library
// https://github.com/imneme/pcg-c/blob/master/include/pcg_variants.h
- llvm::Value* multiplier = ir_builder_->getInt(
- llvm::APInt(128, {0x4385DF649FCCF645, 0x2360ED051FC65DA4}));
- llvm::Value* increment = ir_builder_->getInt(
- llvm::APInt(128, {0x14057B7EF767814F, 0x5851F42D4C957F2D}));
+ llvm::Value* multiplier =
+ b_->getInt(llvm::APInt(128, {0x4385DF649FCCF645, 0x2360ED051FC65DA4}));
+ llvm::Value* increment =
+ b_->getInt(llvm::APInt(128, {0x14057B7EF767814F, 0x5851F42D4C957F2D}));
auto random_value_from_hlo = [hlo]() {
const HloModule* module =
@@ -1280,10 +1257,10 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeRngElementGenerator(
// values.
llvm::GlobalVariable* state_ptr0 = new llvm::GlobalVariable(
/*M=*/*module_,
- /*Ty=*/ir_builder_->getInt64Ty(),
+ /*Ty=*/b_->getInt64Ty(),
/*isConstant=*/false,
/*Linkage=*/llvm::GlobalValue::PrivateLinkage,
- /*Initializer=*/ir_builder_->getInt64(random_value_from_hlo()),
+ /*Initializer=*/b_->getInt64(random_value_from_hlo()),
/*Name=*/"state_ptr0");
// When the module config seed is 0, the expected result of a prng is a random
@@ -1294,17 +1271,16 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeRngElementGenerator(
: GlobalRandomValue();
llvm::GlobalVariable* state_ptr1 = new llvm::GlobalVariable(
/*M=*/*module_,
- /*Ty=*/ir_builder_->getInt64Ty(),
+ /*Ty=*/b_->getInt64Ty(),
/*isConstant=*/false,
/*Linkage=*/llvm::GlobalValue::PrivateLinkage,
- /*Initializer=*/ir_builder_->getInt64(graph_seed),
+ /*Initializer=*/b_->getInt64(graph_seed),
/*Name=*/"state_ptr1");
// We want each thread to use its own stream, so we modify the increment per
// thread. We want the increment to remain odd, so we shift the thread id left
// 1 and add it to the increment.
- increment = ir_builder_->CreateAdd(increment,
- ir_builder_->CreateShl(EmitThreadId(), 1));
+ increment = b_->CreateAdd(increment, b_->CreateShl(EmitThreadId(), 1));
// PCG-XSL-RR algorithm
// http://www.pcg-random.org/pdf/toms-oneill-pcg-family-v1.02.pdf
@@ -1312,38 +1288,29 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeRngElementGenerator(
// return uint64_t(state ^ (state >> 64))) >>> (state >> 122)
// where ">>>" is bitwise rotation
auto get_next_i64 = [=]() {
- llvm::Value* state0 = ir_builder_->CreateZExtOrTrunc(
- ir_builder_->CreateLoad(state_ptr0, "state0"),
- ir_builder_->getInt128Ty());
- llvm::Value* state1 = ir_builder_->CreateShl(
- ir_builder_->CreateZExtOrTrunc(
- ir_builder_->CreateLoad(state_ptr1, "state1"),
- ir_builder_->getInt128Ty()),
+ llvm::Value* state0 = b_->CreateZExtOrTrunc(
+ b_->CreateLoad(state_ptr0, "state0"), b_->getInt128Ty());
+ llvm::Value* state1 = b_->CreateShl(
+ b_->CreateZExtOrTrunc(b_->CreateLoad(state_ptr1, "state1"),
+ b_->getInt128Ty()),
64);
- llvm::Value* state = ir_builder_->CreateOr(state0, state1);
- llvm::Value* updated = ir_builder_->CreateAdd(
- ir_builder_->CreateMul(state, multiplier), increment);
- ir_builder_->CreateStore(
- ir_builder_->CreateTrunc(updated, ir_builder_->getInt64Ty()),
- state_ptr0);
- ir_builder_->CreateStore(
- ir_builder_->CreateTrunc(ir_builder_->CreateLShr(updated, 64),
- ir_builder_->getInt64Ty()),
+ llvm::Value* state = b_->CreateOr(state0, state1);
+ llvm::Value* updated =
+ b_->CreateAdd(b_->CreateMul(state, multiplier), increment);
+ b_->CreateStore(b_->CreateTrunc(updated, b_->getInt64Ty()), state_ptr0);
+ b_->CreateStore(
+ b_->CreateTrunc(b_->CreateLShr(updated, 64), b_->getInt64Ty()),
state_ptr1);
return llvm_ir::CreateRor(
- ir_builder_->CreateTrunc(
- ir_builder_->CreateXor(state, ir_builder_->CreateLShr(state, 64)),
- ir_builder_->getInt64Ty()),
- ir_builder_->CreateTrunc(ir_builder_->CreateLShr(state, 122),
- ir_builder_->getInt64Ty()),
- ir_builder_);
+ b_->CreateTrunc(b_->CreateXor(state, b_->CreateLShr(state, 64)),
+ b_->getInt64Ty()),
+ b_->CreateTrunc(b_->CreateLShr(state, 122), b_->getInt64Ty()), b_);
};
auto get_next_uniform_float = [=]() {
- return ir_builder_->CreateFDiv(
- ir_builder_->CreateUIToFP(get_next_i64(), param_ir_type),
- llvm::ConstantFP::get(param_ir_type, 0x1p64));
+ return b_->CreateFDiv(b_->CreateUIToFP(get_next_i64(), param_ir_type),
+ llvm::ConstantFP::get(param_ir_type, 0x1p64));
};
return [=](const llvm_ir::IrArray::Index& index) -> StatusOr<llvm::Value*> {
@@ -1354,52 +1321,50 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeRngElementGenerator(
TF_ASSIGN_OR_RETURN(llvm::Value * q,
operand_to_generator.at(hlo->operand(1))(index));
if (primitive_util::IsFloatingPointType(param_prim_type)) {
- return ir_builder_->CreateFAdd(
- ir_builder_->CreateFMul(ir_builder_->CreateFSub(q, p),
- get_next_uniform_float()),
+ return b_->CreateFAdd(
+ b_->CreateFMul(b_->CreateFSub(q, p), get_next_uniform_float()),
p);
} else {
- auto r = ir_builder_->CreateSub(q, p);
+ auto r = b_->CreateSub(q, p);
auto leading_zeros = llvm_ir::EmitCallToIntrinsic(
- llvm::Intrinsic::ctlz, {r, ir_builder_->getInt1(true)},
- {param_ir_type}, ir_builder_);
- auto in_block = ir_builder_->GetInsertBlock();
+ llvm::Intrinsic::ctlz, {r, b_->getInt1(true)}, {param_ir_type},
+ b_);
+ auto in_block = b_->GetInsertBlock();
// A terminator should be present iff we're emitting code
// into the middle (as opposed to the end) of a basic block.
- CHECK_EQ(ir_builder_->GetInsertPoint() == in_block->end(),
+ CHECK_EQ(b_->GetInsertPoint() == in_block->end(),
in_block->getTerminator() == nullptr);
llvm::BasicBlock* body_block;
llvm::BasicBlock* out_block;
- if (ir_builder_->GetInsertPoint() == in_block->end()) {
- body_block = llvm_ir::CreateBasicBlock(
- nullptr, IrName(hlo, "rng_body"), ir_builder_);
- out_block = llvm_ir::CreateBasicBlock(
- nullptr, IrName(hlo, "rng_out"), ir_builder_);
+ if (b_->GetInsertPoint() == in_block->end()) {
+ body_block =
+ llvm_ir::CreateBasicBlock(nullptr, IrName(hlo, "rng_body"), b_);
+ out_block =
+ llvm_ir::CreateBasicBlock(nullptr, IrName(hlo, "rng_out"), b_);
llvm::BranchInst::Create(body_block, in_block);
} else {
- body_block = in_block->splitBasicBlock(
- ir_builder_->GetInsertPoint(), "rng_body");
- out_block = body_block->splitBasicBlock(
- ir_builder_->GetInsertPoint(), "rng_out");
+ body_block =
+ in_block->splitBasicBlock(b_->GetInsertPoint(), "rng_body");
+ out_block =
+ body_block->splitBasicBlock(b_->GetInsertPoint(), "rng_out");
body_block->getTerminator()->eraseFromParent();
}
- SetToFirstInsertPoint(body_block, ir_builder_);
- auto random = ir_builder_->CreateAnd(
- ir_builder_->CreateZExtOrTrunc(get_next_i64(), param_ir_type),
- ir_builder_->CreateLShr(llvm::ConstantInt::get(param_ir_type, ~0),
- leading_zeros));
+ SetToFirstInsertPoint(body_block, b_);
+ auto random = b_->CreateAnd(
+ b_->CreateZExtOrTrunc(get_next_i64(), param_ir_type),
+ b_->CreateLShr(llvm::ConstantInt::get(param_ir_type, ~0),
+ leading_zeros));
llvm::BranchInst::Create(out_block, body_block,
- ir_builder_->CreateICmpULT(random, r),
- body_block);
- SetToFirstInsertPoint(out_block, ir_builder_);
- return ir_builder_->CreateAdd(
- p, ir_builder_->CreateSelect(
- ir_builder_->CreateICmpEQ(p, q),
- llvm::ConstantInt::get(param_ir_type, 0), random));
+ b_->CreateICmpULT(random, r), body_block);
+ SetToFirstInsertPoint(out_block, b_);
+ return b_->CreateAdd(
+ p, b_->CreateSelect(b_->CreateICmpEQ(p, q),
+ llvm::ConstantInt::get(param_ir_type, 0),
+ random));
}
}
case RNG_NORMAL: {
@@ -1409,11 +1374,11 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeRngElementGenerator(
operand_to_generator.at(hlo->operand(1))(index));
TF_ASSIGN_OR_RETURN(
llvm::Value * r,
- EmitErfcInv(param_prim_type,
- ir_builder_->CreateFMul(
- llvm::ConstantFP::get(param_ir_type, 2.0),
- get_next_uniform_float())));
- return ir_builder_->CreateFAdd(ir_builder_->CreateFMul(r, s), m);
+ EmitErfcInv(
+ param_prim_type,
+ b_->CreateFMul(llvm::ConstantFP::get(param_ir_type, 2.0),
+ get_next_uniform_float())));
+ return b_->CreateFAdd(b_->CreateFMul(r, s), m);
}
default:
return InvalidArgument(
@@ -1436,9 +1401,8 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalSelect(
TF_ASSIGN_OR_RETURN(llvm::Value * on_false_value,
operand_to_generator.at(hlo->operand(2))(
ElementwiseSourceIndex(index, *hlo, 2)));
- return ir_builder_->CreateSelect(
- ir_builder_->CreateTrunc(pred_value, ir_builder_->getInt1Ty()),
- on_true_value, on_false_value);
+ return b_->CreateSelect(b_->CreateTrunc(pred_value, b_->getInt1Ty()),
+ on_true_value, on_false_value);
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalClamp(
@@ -1474,64 +1438,62 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalConcatenate(
const int64 concat_dim = hlo->dimensions(0);
auto source_index = target_index;
- llvm::BasicBlock* init_block = ir_builder_->GetInsertBlock();
+ llvm::BasicBlock* init_block = b_->GetInsertBlock();
// A terminator should be present iff we're emitting code
// into the middle (as opposed to the end) of a basic block.
- CHECK_EQ(ir_builder_->GetInsertPoint() == init_block->end(),
+ CHECK_EQ(b_->GetInsertPoint() == init_block->end(),
init_block->getTerminator() == nullptr);
llvm::BasicBlock* exit_block;
- if (ir_builder_->GetInsertPoint() == init_block->end()) {
+ if (b_->GetInsertPoint() == init_block->end()) {
exit_block = llvm_ir::CreateBasicBlock(
- /*insert_before=*/nullptr, IrName(hlo, "merge"), ir_builder_);
+ /*insert_before=*/nullptr, IrName(hlo, "merge"), b_);
} else {
- exit_block = init_block->splitBasicBlock(ir_builder_->GetInsertPoint(),
+ exit_block = init_block->splitBasicBlock(b_->GetInsertPoint(),
AsStringRef(IrName(hlo, "merge")));
init_block->getTerminator()->eraseFromParent();
}
- llvm_ir::SetToFirstInsertPoint(exit_block, ir_builder_);
- llvm::PHINode* output = ir_builder_->CreatePHI(
+ llvm_ir::SetToFirstInsertPoint(exit_block, b_);
+ llvm::PHINode* output = b_->CreatePHI(
llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_),
hlo->operands().size());
- auto prior_insert_point = ir_builder_->GetInsertPoint();
+ auto prior_insert_point = b_->GetInsertPoint();
- ir_builder_->SetInsertPoint(init_block);
+ b_->SetInsertPoint(init_block);
for (int64 operand_idx = 0; operand_idx < hlo->operand_count();
++operand_idx) {
const HloInstruction* operand = hlo->operand(operand_idx);
auto true_block = llvm_ir::CreateBasicBlock(
- exit_block, StrCat("concat_index_from_operand", operand_idx),
- ir_builder_);
+ exit_block, StrCat("concat_index_from_operand", operand_idx), b_);
auto false_block = llvm_ir::CreateBasicBlock(
- exit_block, StrCat("concat_index_not_from_operand", operand_idx),
- ir_builder_);
+ exit_block, StrCat("concat_index_not_from_operand", operand_idx), b_);
auto concat_dim_size =
llvm::ConstantInt::get(source_index[concat_dim]->getType(),
operand->shape().dimensions(concat_dim));
- ir_builder_->CreateCondBr(
- ir_builder_->CreateICmpULT(source_index[concat_dim], concat_dim_size),
+ b_->CreateCondBr(
+ b_->CreateICmpULT(source_index[concat_dim], concat_dim_size),
true_block, false_block);
// Create the terminator of the true block before calling operand
// generators, because they require non-degenerate basic blocks.
- ir_builder_->SetInsertPoint(
+ b_->SetInsertPoint(
llvm::BranchInst::Create(exit_block, /*InsertAtEnd=*/true_block));
TF_ASSIGN_OR_RETURN(llvm::Value * value,
operand_to_generator.at(operand)(source_index));
- output->addIncoming(value, ir_builder_->GetInsertBlock());
+ output->addIncoming(value, b_->GetInsertBlock());
// Subtract the size of the concat dimension of the current operand
// from the source index.
- ir_builder_->SetInsertPoint(false_block);
+ b_->SetInsertPoint(false_block);
source_index[concat_dim] =
- ir_builder_->CreateSub(source_index[concat_dim], concat_dim_size);
+ b_->CreateSub(source_index[concat_dim], concat_dim_size);
}
- ir_builder_->CreateUnreachable();
- ir_builder_->SetInsertPoint(exit_block, prior_insert_point);
+ b_->CreateUnreachable();
+ b_->SetInsertPoint(exit_block, prior_insert_point);
return output;
}
@@ -1555,22 +1517,16 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicSlice(
// Clamp the start index so that the sliced portion fits in the operand:
// start_index = clamp(start_index, 0, operand_dim_size - output_dim_size)
+ start_index_value = b_->CreateSExtOrTrunc(start_index_value, index_type);
+ int64 largest_valid_start_index =
+ input_hlo->shape().dimensions(i) - hlo->shape().dimensions(i);
+ CHECK_GE(largest_valid_start_index, 0);
- // TODO(b/74360564): This is implementation defined behavior, but is
- // currently respected by all implementations. Change this if we ever decide
- // to oficially document different behavior.
- start_index_value =
- ir_builder_->CreateSExtOrTrunc(start_index_value, index_type);
- llvm::Value* operand_dim_size =
- index_typed_const(input_hlo->shape().dimensions(i));
- llvm::Value* output_dim_size =
- index_typed_const(hlo->shape().dimensions(i));
-
+ bool is_signed = ShapeUtil::ElementIsSigned(hlo->operand(1)->shape());
start_index_value = EmitIntegralMin(
- ir_builder_->CreateSub(operand_dim_size, output_dim_size),
- EmitIntegralMax(index_typed_const(0), start_index_value,
- /*is_signed=*/true),
- /*is_signed=*/true);
+ index_typed_const(largest_valid_start_index),
+ EmitIntegralMax(index_typed_const(0), start_index_value, is_signed),
+ is_signed);
start_index_value->setName(
AsStringRef(IrName(hlo, StrCat("start_idx", i))));
@@ -1581,7 +1537,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicSlice(
for (int64 i = 0; i < rank; ++i) {
// Emit IR which computes:
// input_index = start_index + offset_index
- input_index[i] = ir_builder_->CreateAdd(slice_start_index[i], index[i]);
+ input_index[i] = b_->CreateAdd(slice_start_index[i], index[i]);
}
return operand_to_generator.at(input_hlo)(input_index);
}
@@ -1603,19 +1559,22 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalGather(
llvm::Type* index_type = index.GetType();
// This is the index into `operand` that holds the element we want to
- // generate. This index "unsafe" as in the components in here may be
- // out of bounds.
- IrArray::Index unsafe_operand_index(index_type);
-
- // First copy in the window indices to unsafe_operand_index.
- for (int64 i = 0, e = operand_shape.dimensions_size(),
- unsafe_operand_index_dim = 0;
+ // generate.
+ IrArray::Index operand_index(index_type);
+
+ // First copy in the window indices to operand_index. Also collect a mapping
+ // from operand dimension to output window dimension. Elided window dimensions
+ // map to -1.
+ std::vector<int64> operand_to_output_dim(operand_shape.dimensions_size(), -1);
+ for (int64 i = 0, e = operand_shape.dimensions_size(), operand_index_dim = 0;
i < e; i++) {
if (c_binary_search(dim_numbers.elided_window_dims(), i)) {
- unsafe_operand_index.push_back(index.GetConstantWithIndexType(0));
+ operand_index.push_back(index.GetConstantWithIndexType(0));
} else {
- unsafe_operand_index.push_back(
- index[dim_numbers.output_window_dims(unsafe_operand_index_dim++)]);
+ int64 output_window_dim =
+ dim_numbers.output_window_dims(operand_index_dim++);
+ operand_to_output_dim[i] = output_window_dim;
+ operand_index.push_back(index[output_window_dim]);
}
}
@@ -1634,20 +1593,40 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalGather(
}
}
- auto add_to_unsafe_operand_index = [&](llvm::Value* index_component,
- int64 dim) {
+ auto add_to_operand_index = [&](llvm::Value* index_component, int64 dim) {
llvm::Value* gather_dim_component_extended =
- ir_builder_->CreateSExtOrTrunc(index_component, index_type);
- unsafe_operand_index[dim_numbers.gather_dims_to_operand_dims(dim)] =
- ir_builder_->CreateAdd(
- unsafe_operand_index[dim_numbers.gather_dims_to_operand_dims(dim)],
- gather_dim_component_extended);
+ b_->CreateSExtOrTrunc(index_component, index_type);
+ int64 operand_dim = dim_numbers.gather_dims_to_operand_dims(dim);
+ int64 output_dim = operand_to_output_dim[operand_dim];
+ // If 'output_dim' is -1, it means 'operand_dim' is an elided window dim.
+ // This means we set the iteration index to 0, so for the purpose of the
+ // following calculations we can consider the output dimension size to be 1.
+ int64 output_dim_size =
+ output_dim == -1 ? 1 : output_shape.dimensions(output_dim);
+ int64 largest_valid_start_index =
+ operand_shape.dimensions(operand_dim) - output_dim_size;
+ CHECK_GE(largest_valid_start_index, 0);
+
+ // Clamp the gather index so that the gather region fits in the operand.
+ // gather_dim_component_extended_inbound =
+ // clamp(gather_dim_component_extended, 0, largest_valid_start_index);
+
+ // TODO(b/111078873): This is implementation defined behavior.
+ bool is_signed = ShapeUtil::ElementIsSigned(indices_shape);
+ auto gather_dim_component_extended_inbound = EmitIntegralMin(
+ index.GetConstantWithIndexType(largest_valid_start_index),
+ EmitIntegralMax(index.GetConstantWithIndexType(0),
+ gather_dim_component_extended, is_signed),
+ is_signed);
+
+ operand_index[operand_dim] = b_->CreateAdd(
+ operand_index[operand_dim], gather_dim_component_extended_inbound);
};
if (indices_shape.dimensions_size() == dim_numbers.index_vector_dim()) {
TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component,
indices_generator(gather_index_index));
- add_to_unsafe_operand_index(gather_dim_component, 0);
+ add_to_operand_index(gather_dim_component, 0);
} else {
int64 index_vector_size =
indices_shape.dimensions(dim_numbers.index_vector_dim());
@@ -1656,18 +1635,10 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalGather(
index.GetConstantWithIndexType(i);
TF_ASSIGN_OR_RETURN(llvm::Value * gather_dim_component,
indices_generator(gather_index_index));
- add_to_unsafe_operand_index(gather_dim_component, i);
+ add_to_operand_index(gather_dim_component, i);
}
}
-
- IrArray::Index safe_operand_index(index_type);
- for (int64 i = 0, e = unsafe_operand_index.size(); i < e; i++) {
- safe_operand_index.push_back(ir_builder_->CreateURem(
- unsafe_operand_index[i],
- index.GetConstantWithIndexType(operand_shape.dimensions(i))));
- }
-
- return operand_generator(safe_operand_index);
+ return operand_generator(operand_index);
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicUpdateSlice(
@@ -1683,7 +1654,7 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicUpdateSlice(
llvm_ir::IrArray::Index slice_limit_index(index.GetType(), rank);
// Slice intersection gathers (ANDs) conditions on all ranks for which
// 'input' is set to 'update'
- llvm::Value* slice_intersection = ir_builder_->getTrue();
+ llvm::Value* slice_intersection = b_->getTrue();
for (int64 i = 0; i < rank; ++i) {
llvm::Type* index_type = index[0]->getType();
@@ -1696,36 +1667,29 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicUpdateSlice(
// Clamp the start index so that the update region fits in the operand.
// start_index = clamp(start_index, 0, input_dim_size - update_dim_size)
-
- // TODO(b/74360564): This is implementation defined behavior, but is
- // currently respected by all implementations. Change this if we ever decide
- // to oficially document different behavior.
- start_index_value =
- ir_builder_->CreateSExtOrTrunc(start_index_value, index_type);
- llvm::Value* input_dim_size =
- index_typed_const(input_hlo->shape().dimensions(i));
+ start_index_value = b_->CreateSExtOrTrunc(start_index_value, index_type);
llvm::Value* update_dim_size =
index_typed_const(update_hlo->shape().dimensions(i));
+ int64 largest_valid_start_index =
+ input_hlo->shape().dimensions(i) - update_hlo->shape().dimensions(i);
+ CHECK_GE(largest_valid_start_index, 0);
- start_index_value =
- EmitIntegralMin(ir_builder_->CreateSub(input_dim_size, update_dim_size),
- EmitIntegralMax(index_typed_const(0), start_index_value,
- /*is_signed=*/true),
- /*is_signed=*/true);
+ bool is_signed = ShapeUtil::ElementIsSigned(start_hlo->shape());
+ start_index_value = EmitIntegralMin(
+ index_typed_const(largest_valid_start_index),
+ EmitIntegralMax(index_typed_const(0), start_index_value, is_signed),
+ is_signed);
start_index_value->setName(
AsStringRef(IrName(hlo, StrCat("start_idx", i))));
slice_start_index[i] = start_index_value;
- slice_limit_index[i] =
- ir_builder_->CreateAdd(slice_start_index[i], update_dim_size);
+ slice_limit_index[i] = b_->CreateAdd(slice_start_index[i], update_dim_size);
- slice_intersection = ir_builder_->CreateAnd(
- slice_intersection,
- ir_builder_->CreateICmpSGE(index[i], slice_start_index[i]),
+ slice_intersection = b_->CreateAnd(
+ slice_intersection, b_->CreateICmpSGE(index[i], slice_start_index[i]),
"slice_intersection");
- slice_intersection = ir_builder_->CreateAnd(
- slice_intersection,
- ir_builder_->CreateICmpSLT(index[i], slice_limit_index[i]),
+ slice_intersection = b_->CreateAnd(
+ slice_intersection, b_->CreateICmpSLT(index[i], slice_limit_index[i]),
"slice_intersection");
}
@@ -1734,29 +1698,29 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDynamicUpdateSlice(
// else -> return data from 'input'.
llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry(
llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_),
- "ret_value_addr", ir_builder_);
- llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
- slice_intersection, "slice_intersection", ir_builder_);
+ "ret_value_addr", b_);
+ llvm_ir::LlvmIfData if_data =
+ llvm_ir::EmitIfThenElse(slice_intersection, "slice_intersection", b_);
// Handle true BB (return data from 'update')
- SetToFirstInsertPoint(if_data.true_block, ir_builder_);
+ SetToFirstInsertPoint(if_data.true_block, b_);
// Compute update index for intersection case.
llvm_ir::IrArray::Index update_index(index.GetType(), rank);
for (int64 i = 0; i < rank; ++i) {
- update_index[i] = ir_builder_->CreateSub(index[i], slice_start_index[i]);
+ update_index[i] = b_->CreateSub(index[i], slice_start_index[i]);
}
TF_ASSIGN_OR_RETURN(llvm::Value * true_value,
operand_to_generator.at(update_hlo)(update_index));
- ir_builder_->CreateStore(true_value, ret_value_addr);
+ b_->CreateStore(true_value, ret_value_addr);
// Handle false BB (return data from 'input')
- SetToFirstInsertPoint(if_data.false_block, ir_builder_);
+ SetToFirstInsertPoint(if_data.false_block, b_);
TF_ASSIGN_OR_RETURN(llvm::Value * false_value,
operand_to_generator.at(input_hlo)(index));
- ir_builder_->CreateStore(false_value, ret_value_addr);
+ b_->CreateStore(false_value, ret_value_addr);
- SetToFirstInsertPoint(if_data.after_block, ir_builder_);
- return ir_builder_->CreateLoad(ret_value_addr);
+ SetToFirstInsertPoint(if_data.after_block, b_);
+ return b_->CreateLoad(ret_value_addr);
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalPad(
@@ -1764,29 +1728,29 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalPad(
const ElementalIrEmitter::HloToElementGeneratorMap& operand_to_generator,
const llvm_ir::IrArray::Index& padded_index) const {
auto index = padded_index;
- llvm::Value* in_bounds = ir_builder_->getTrue();
+ llvm::Value* in_bounds = b_->getTrue();
for (size_t i = 0; i < index.size(); ++i) {
auto index_typed_const = [=](int64 n) {
return llvm::ConstantInt::get(index[i]->getType(), n);
};
const auto& pad_dim = hlo->padding_config().dimensions(i);
- index[i] = ir_builder_->CreateSub(
- index[i], index_typed_const(pad_dim.edge_padding_low()));
- in_bounds = ir_builder_->CreateAnd(
- in_bounds, ir_builder_->CreateICmpSGE(index[i], index_typed_const(0)),
- "in_bounds");
- in_bounds = ir_builder_->CreateAnd(
+ index[i] =
+ b_->CreateSub(index[i], index_typed_const(pad_dim.edge_padding_low()));
+ in_bounds = b_->CreateAnd(in_bounds,
+ b_->CreateICmpSGE(index[i], index_typed_const(0)),
+ "in_bounds");
+ in_bounds = b_->CreateAnd(
in_bounds,
- ir_builder_->CreateICmpEQ(
+ b_->CreateICmpEQ(
index_typed_const(0),
- ir_builder_->CreateURem(
- index[i], index_typed_const(pad_dim.interior_padding() + 1))),
+ b_->CreateURem(index[i],
+ index_typed_const(pad_dim.interior_padding() + 1))),
"in_bounds");
- index[i] = ir_builder_->CreateSDiv(
+ index[i] = b_->CreateSDiv(
index[i], index_typed_const(pad_dim.interior_padding() + 1));
- in_bounds = ir_builder_->CreateAnd(
+ in_bounds = b_->CreateAnd(
in_bounds,
- ir_builder_->CreateICmpSLT(
+ b_->CreateICmpSLT(
index[i],
index_typed_const(hlo->operand(0)->shape().dimensions(i))),
"in_bounds");
@@ -1799,26 +1763,26 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalPad(
// }
llvm::Value* ret_value_addr = llvm_ir::EmitAllocaAtFunctionEntry(
llvm_ir::PrimitiveTypeToIrType(hlo->shape().element_type(), module_),
- "pad_result_addr", ir_builder_);
+ "pad_result_addr", b_);
llvm_ir::LlvmIfData if_data =
- llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", ir_builder_);
- SetToFirstInsertPoint(if_data.true_block, ir_builder_);
+ llvm_ir::EmitIfThenElse(in_bounds, "in_bounds", b_);
+ SetToFirstInsertPoint(if_data.true_block, b_);
TF_ASSIGN_OR_RETURN(llvm::Value * operand_value,
operand_to_generator.at(hlo->operand(0))(index));
- ir_builder_->CreateStore(operand_value, ret_value_addr);
+ b_->CreateStore(operand_value, ret_value_addr);
- SetToFirstInsertPoint(if_data.false_block, ir_builder_);
+ SetToFirstInsertPoint(if_data.false_block, b_);
TF_ASSIGN_OR_RETURN(llvm::Value * padding_value,
operand_to_generator.at(hlo->operand(1))(
IrArray::Index(index.GetType())));
- ir_builder_->CreateStore(padding_value, ret_value_addr);
+ b_->CreateStore(padding_value, ret_value_addr);
- SetToFirstInsertPoint(if_data.after_block, ir_builder_);
+ SetToFirstInsertPoint(if_data.after_block, b_);
// Don't create phi(operand_value, padding_value) here, because invoking
// operand_to_generator may create new basic blocks, making the parent
// of operand_value or padding_value no longer a predecessor of
// if_data.after_block.
- return ir_builder_->CreateLoad(ret_value_addr);
+ return b_->CreateLoad(ret_value_addr);
}
StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDot(
@@ -1842,21 +1806,20 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDot(
return llvm::ConstantInt::get(index_type, c);
};
- std::unique_ptr<llvm_ir::ForLoop> inner_loop =
- llvm_ir::ForLoop::EmitForLoop(IrName(hlo, "inner"), index_typed_const(0),
- index_typed_const(contracted_dim_size),
- index_typed_const(1), ir_builder_);
+ std::unique_ptr<llvm_ir::ForLoop> inner_loop = llvm_ir::ForLoop::EmitForLoop(
+ IrName(hlo, "inner"), index_typed_const(0),
+ index_typed_const(contracted_dim_size), index_typed_const(1), b_);
- SetToFirstInsertPoint(inner_loop->GetPreheaderBasicBlock(), ir_builder_);
+ SetToFirstInsertPoint(inner_loop->GetPreheaderBasicBlock(), b_);
PrimitiveType primitive_type = hlo->shape().element_type();
llvm::Type* primitive_type_llvm =
llvm_ir::PrimitiveTypeToIrType(primitive_type, module_);
- llvm::Value* accumulator_alloca = llvm_ir::EmitAllocaAtFunctionEntry(
- primitive_type_llvm, "dot_acc", ir_builder_);
- ir_builder_->CreateStore(llvm::Constant::getNullValue(primitive_type_llvm),
- accumulator_alloca);
+ llvm::Value* accumulator_alloca =
+ llvm_ir::EmitAllocaAtFunctionEntry(primitive_type_llvm, "dot_acc", b_);
+ b_->CreateStore(llvm::Constant::getNullValue(primitive_type_llvm),
+ accumulator_alloca);
- SetToFirstInsertPoint(inner_loop->GetBodyBasicBlock(), ir_builder_);
+ SetToFirstInsertPoint(inner_loop->GetBodyBasicBlock(), b_);
// This is the inner reduction loop for a dot operation that produces
// one element in the output. If the operands to the dot operation have
@@ -1876,43 +1839,36 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitElementalDot(
}
rhs_index.InsertAt(rhs_contracting_dim, inner_loop->GetIndVarValue());
- llvm::Value* current_accumulator =
- ir_builder_->CreateLoad(accumulator_alloca);
+ llvm::Value* current_accumulator = b_->CreateLoad(accumulator_alloca);
TF_ASSIGN_OR_RETURN(llvm::Value * lhs_value, lhs_generator(lhs_index));
TF_ASSIGN_OR_RETURN(llvm::Value * rhs_value, rhs_generator(rhs_index));
llvm::Value* next_accumulator;
if (primitive_util::IsComplexType(primitive_type)) {
- llvm::Value* product_real = ir_builder_->CreateFSub(
- ir_builder_->CreateFMul(EmitExtractReal(lhs_value),
- EmitExtractReal(rhs_value)),
- ir_builder_->CreateFMul(EmitExtractImag(lhs_value),
- EmitExtractImag(rhs_value)));
- llvm::Value* product_imag = ir_builder_->CreateFAdd(
- ir_builder_->CreateFMul(EmitExtractReal(lhs_value),
- EmitExtractImag(rhs_value)),
- ir_builder_->CreateFMul(EmitExtractImag(lhs_value),
- EmitExtractReal(rhs_value)));
- next_accumulator = ir_builder_->CreateInsertValue(
+ llvm::Value* product_real = b_->CreateFSub(
+ b_->CreateFMul(EmitExtractReal(lhs_value), EmitExtractReal(rhs_value)),
+ b_->CreateFMul(EmitExtractImag(lhs_value), EmitExtractImag(rhs_value)));
+ llvm::Value* product_imag = b_->CreateFAdd(
+ b_->CreateFMul(EmitExtractReal(lhs_value), EmitExtractImag(rhs_value)),
+ b_->CreateFMul(EmitExtractImag(lhs_value), EmitExtractReal(rhs_value)));
+ next_accumulator = b_->CreateInsertValue(
current_accumulator,
- ir_builder_->CreateFAdd(EmitExtractReal(current_accumulator),
- product_real),
+ b_->CreateFAdd(EmitExtractReal(current_accumulator), product_real),
{0});
- next_accumulator = ir_builder_->CreateInsertValue(
+ next_accumulator = b_->CreateInsertValue(
next_accumulator,
- ir_builder_->CreateFAdd(EmitExtractImag(current_accumulator),
- product_imag),
+ b_->CreateFAdd(EmitExtractImag(current_accumulator), product_imag),
{1});
} else if (primitive_util::IsFloatingPointType(primitive_type)) {
- next_accumulator = ir_builder_->CreateFAdd(
- current_accumulator, ir_builder_->CreateFMul(lhs_value, rhs_value));
+ next_accumulator = b_->CreateFAdd(current_accumulator,
+ b_->CreateFMul(lhs_value, rhs_value));
} else {
- next_accumulator = ir_builder_->CreateAdd(
- current_accumulator, ir_builder_->CreateMul(lhs_value, rhs_value));
+ next_accumulator =
+ b_->CreateAdd(current_accumulator, b_->CreateMul(lhs_value, rhs_value));
}
- ir_builder_->CreateStore(next_accumulator, accumulator_alloca);
+ b_->CreateStore(next_accumulator, accumulator_alloca);
- SetToFirstInsertPoint(inner_loop->GetExitBasicBlock(), ir_builder_);
- return ir_builder_->CreateLoad(accumulator_alloca);
+ SetToFirstInsertPoint(inner_loop->GetExitBasicBlock(), b_);
+ return b_->CreateLoad(accumulator_alloca);
}
llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
@@ -2012,7 +1968,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
const HloInstruction* operand = hlo->operand(0);
auto source_index = target_index;
for (int64 dim : hlo->dimensions()) {
- source_index[dim] = ir_builder_->CreateSub(
+ source_index[dim] = b_->CreateSub(
llvm::ConstantInt::get(target_index[dim]->getType(),
hlo->shape().dimensions(dim) - 1),
target_index[dim]);
@@ -2025,16 +1981,16 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
const HloInstruction* operand = hlo->operand(0);
// The `dimensions` member of the broadcast instruction maps from
// input dimensions to output dimensions.
- return operand_to_generator.at(
- operand)(target_index.SourceIndexOfBroadcast(
- hlo->shape(), operand->shape(), hlo->dimensions(), ir_builder_));
+ return operand_to_generator.at(operand)(
+ target_index.SourceIndexOfBroadcast(hlo->shape(), operand->shape(),
+ hlo->dimensions(), b_));
};
case HloOpcode::kSlice:
return [this, hlo, &operand_to_generator](
const IrArray::Index& index) -> StatusOr<llvm::Value*> {
IrArray::Index sliced_index = index.SourceIndexOfSlice(
/*shape=*/hlo->shape(), /*starts=*/hlo->slice_starts(),
- /*strides=*/hlo->slice_strides(), /*builder=*/ir_builder_);
+ /*strides=*/hlo->slice_strides(), /*builder=*/b_);
return operand_to_generator.at(hlo->operand(0))(sliced_index);
};
case HloOpcode::kDynamicSlice:
@@ -2059,24 +2015,23 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
ShapeUtil::ElementsIn(hlo->operand(0)->shape()));
return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
const HloInstruction* operand = hlo->operand(0);
- return operand_to_generator.at(operand)(index.SourceIndexOfBitcast(
- hlo->shape(), operand->shape(), ir_builder_));
+ return operand_to_generator.at(operand)(
+ index.SourceIndexOfBitcast(hlo->shape(), operand->shape(), b_));
};
case HloOpcode::kReshape:
CHECK_EQ(ShapeUtil::ElementsIn(hlo->shape()),
ShapeUtil::ElementsIn(hlo->operand(0)->shape()));
return [this, hlo, &operand_to_generator](const IrArray::Index& index) {
const HloInstruction* operand = hlo->operand(0);
- return operand_to_generator.at(operand)(index.SourceIndexOfReshape(
- hlo->shape(), operand->shape(), ir_builder_));
+ return operand_to_generator.at(operand)(
+ index.SourceIndexOfReshape(hlo->shape(), operand->shape(), b_));
};
case HloOpcode::kTranspose:
return [this, hlo,
&operand_to_generator](const IrArray::Index& target_index) {
return operand_to_generator.at(hlo->operand(0))(
target_index.SourceIndexOfTranspose(
- hlo->shape(), hlo->operand(0)->shape(), hlo->dimensions(),
- ir_builder_));
+ hlo->shape(), hlo->operand(0)->shape(), hlo->dimensions(), b_));
};
case HloOpcode::kRng:
return MakeRngElementGenerator(hlo, operand_to_generator);
@@ -2101,11 +2056,11 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeElementGenerator(
}
llvm::Value* ElementalIrEmitter::EmitExtractReal(llvm::Value* value) const {
- return ir_builder_->CreateExtractValue(value, {0});
+ return b_->CreateExtractValue(value, {0});
}
llvm::Value* ElementalIrEmitter::EmitExtractImag(llvm::Value* value) const {
- return ir_builder_->CreateExtractValue(value, {1});
+ return b_->CreateExtractValue(value, {1});
}
llvm::Value* ElementalIrEmitter::EmitComposeComplex(const HloInstruction* op,
@@ -2113,10 +2068,10 @@ llvm::Value* ElementalIrEmitter::EmitComposeComplex(const HloInstruction* op,
llvm::Value* imag) const {
auto cplx_type =
llvm_ir::PrimitiveTypeToIrType(op->shape().element_type(), module_);
- auto complex = ir_builder_->CreateInsertValue(
+ auto complex = b_->CreateInsertValue(
llvm::ConstantAggregateZero::get(cplx_type), real, {0});
if (imag != nullptr) {
- complex = ir_builder_->CreateInsertValue(complex, imag, {1});
+ complex = b_->CreateInsertValue(complex, imag, {1});
}
return complex;
}