diff options
4 files changed, 112 insertions, 8 deletions
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index 4b19aa5df9..215af562a5 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -1100,6 +1100,95 @@ static llvm::Value* SaturateShiftIfNecessary(llvm::IRBuilder<>* b, return b->CreateSelect(shift_amt_in_range, shift_result, saturated_value); } +llvm::Value* ElementalIrEmitter::GetOne(llvm::Type* type) const { + return llvm::ConstantInt::get(llvm::cast<llvm::IntegerType>(type), 1); +} + +llvm::Value* ElementalIrEmitter::GetZero(llvm::Type* type) const { + return llvm::ConstantInt::get(llvm::cast<llvm::IntegerType>(type), 0); +} + +llvm::Value* ElementalIrEmitter::GetIntSMin(llvm::Type* type) const { + auto* integer_type = llvm::cast<llvm::IntegerType>(type); + return llvm::ConstantInt::get(integer_type, llvm::APInt::getSignedMinValue( + integer_type->getBitWidth())); +} + +llvm::Value* ElementalIrEmitter::GetMinusOne(llvm::Type* type) const { + auto* integer_type = llvm::cast<llvm::IntegerType>(type); + return llvm::ConstantInt::get( + integer_type, llvm::APInt::getAllOnesValue(integer_type->getBitWidth())); +} + +llvm::Value* ElementalIrEmitter::IsZero(llvm::Value* v) const { + return b_->CreateICmpEQ(v, llvm::ConstantInt::get(v->getType(), 0)); +} + +llvm::Value* ElementalIrEmitter::IsIntMinDivisionOverflow( + llvm::Value* lhs, llvm::Value* rhs) const { + return b_->CreateAnd(b_->CreateICmpEQ(lhs, GetIntSMin(lhs->getType())), + b_->CreateICmpEQ(rhs, GetMinusOne(rhs->getType()))); +} + +llvm::Value* ElementalIrEmitter::Select(llvm::Value* cond, llvm::Value* if_true, + llvm::Value* if_false) const { + return b_->CreateSelect(cond, if_true, if_false); +} + +llvm::Value* ElementalIrEmitter::EmitIntegerDivide(llvm::Value* lhs, + llvm::Value* rhs, + bool is_signed) const { + // Integer division overflow behavior: + // + // X / 0 == -1 + // INT_SMIN /s -1 = INT_SMIN + + if (!is_signed) { + llvm::Value* udiv_is_unsafe = IsZero(rhs); + llvm::Value* safe_rhs = Select(udiv_is_unsafe, GetOne(lhs->getType()), rhs); + llvm::Value* safe_div = b_->CreateUDiv(lhs, safe_rhs); + return Select(udiv_is_unsafe, GetMinusOne(lhs->getType()), safe_div); + } + + llvm::Value* has_zero_divisor = IsZero(rhs); + llvm::Value* has_int_min_overflow = IsIntMinDivisionOverflow(lhs, rhs); + llvm::Value* sdiv_is_unsafe = + b_->CreateOr(has_int_min_overflow, has_zero_divisor); + llvm::Value* safe_rhs = Select(sdiv_is_unsafe, GetOne(lhs->getType()), rhs); + llvm::Value* safe_div = b_->CreateSDiv(lhs, safe_rhs); + + return Select( + has_zero_divisor, GetMinusOne(lhs->getType()), + Select(has_int_min_overflow, GetIntSMin(lhs->getType()), safe_div)); +} + +llvm::Value* ElementalIrEmitter::EmitIntegerRemainder(llvm::Value* lhs, + llvm::Value* rhs, + bool is_signed) const { + // Integer remainder overflow behavior: + // + // X % 0 == X + // INT_SMIN %s -1 = 0 + + if (!is_signed) { + llvm::Value* urem_is_unsafe = IsZero(rhs); + llvm::Value* safe_rhs = Select(urem_is_unsafe, GetOne(lhs->getType()), rhs); + llvm::Value* safe_rem = b_->CreateURem(lhs, safe_rhs); + return Select(urem_is_unsafe, lhs, safe_rem); + } + + llvm::Value* has_zero_divisor = IsZero(rhs); + llvm::Value* has_int_min_overflow = IsIntMinDivisionOverflow(lhs, rhs); + llvm::Value* srem_is_unsafe = + b_->CreateOr(has_int_min_overflow, has_zero_divisor); + llvm::Value* safe_rhs = Select(srem_is_unsafe, GetOne(lhs->getType()), rhs); + llvm::Value* safe_rem = b_->CreateSRem(lhs, safe_rhs); + + return Select( + has_zero_divisor, lhs, + Select(has_int_min_overflow, GetZero(lhs->getType()), safe_rem)); +} + StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp( const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value, bool is_signed) const { @@ -1112,11 +1201,9 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp( case HloOpcode::kMultiply: return b_->CreateMul(lhs_value, rhs_value); case HloOpcode::kDivide: - return is_signed ? b_->CreateSDiv(lhs_value, rhs_value) - : b_->CreateUDiv(lhs_value, rhs_value); + return EmitIntegerDivide(lhs_value, rhs_value, is_signed); case HloOpcode::kRemainder: - return is_signed ? b_->CreateSRem(lhs_value, rhs_value) - : b_->CreateURem(lhs_value, rhs_value); + return EmitIntegerRemainder(lhs_value, rhs_value, is_signed); case HloOpcode::kEq: return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_EQ, lhs_value, rhs_value, b_); diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h index 1598a4dd85..c037b98929 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h @@ -65,6 +65,21 @@ class ElementalIrEmitter { virtual StatusOr<llvm::Value*> EmitComplexUnaryOp( const HloInstruction* op, llvm::Value* operand_value) const; + llvm::Value* IsZero(llvm::Value* v) const; + llvm::Value* IsIntMinDivisionOverflow(llvm::Value* lhs, + llvm::Value* rhs) const; + llvm::Value* GetZero(llvm::Type* type) const; + llvm::Value* GetOne(llvm::Type* type) const; + llvm::Value* GetIntSMin(llvm::Type* type) const; + llvm::Value* GetMinusOne(llvm::Type* type) const; + llvm::Value* Select(llvm::Value* cond, llvm::Value* if_true, + llvm::Value* if_false) const; + + llvm::Value* EmitIntegerDivide(llvm::Value* lhs, llvm::Value* rhs, + bool is_signed) const; + llvm::Value* EmitIntegerRemainder(llvm::Value* lhs, llvm::Value* rhs, + bool is_signed) const; + virtual StatusOr<llvm::Value*> EmitIntegerBinaryOp(const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value, diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc index 316ab26a1f..84c5b6e549 100644 --- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc +++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc @@ -498,8 +498,7 @@ XLA_TEST_F(IntegerDivideOpTest, DivS32s) { TestDivRem<int32>(dividends, divisors, quotients, remainders); } -XLA_TEST_F(IntegerDivideOpTest, - DISABLED_ON_CPU(DISABLED_ON_GPU(SignedOverflow))) { +XLA_TEST_F(IntegerDivideOpTest, SignedOverflow) { std::vector<int32> dividends = {5, INT32_MIN}, divisors = {0, -1}, quotients = {-1, INT32_MIN}, remainders = {5, 0}; @@ -529,8 +528,7 @@ XLA_TEST_F(IntegerDivideOpTest, DivU32s) { TestDivRem<uint32>(dividends, divisors, quotients, remainders); } -XLA_TEST_F(IntegerDivideOpTest, - DISABLED_ON_CPU(DISABLED_ON_GPU(UnsignedOverflow))) { +XLA_TEST_F(IntegerDivideOpTest, UnsignedOverflow) { std::vector<int32> dividends = {5}, divisors = {0}, quotients = {-1}, remainders = {5}; diff --git a/tensorflow/docs_src/performance/xla/operation_semantics.md b/tensorflow/docs_src/performance/xla/operation_semantics.md index 2de30d1b3d..c23a7ad9e2 100644 --- a/tensorflow/docs_src/performance/xla/operation_semantics.md +++ b/tensorflow/docs_src/performance/xla/operation_semantics.md @@ -1036,6 +1036,10 @@ different ranks are *not* supported, unless one of the operands is a scalar. When `Op` is `Rem`, the sign of the result is taken from the dividend, and the absolute value of the result is always less than the divisor's absolute value. +Integer division overflow (signed/unsigned division/remainder by zero or signed +divison/remainder of `INT_SMIN` with `-1`) produces an implementation defined +value. + An alternative variant with different-rank broadcasting support exists for these operations: |