aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/compiler/xla/service/elemental_ir_emitter.cc95
-rw-r--r--tensorflow/compiler/xla/service/elemental_ir_emitter.h15
-rw-r--r--tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc6
-rw-r--r--tensorflow/docs_src/performance/xla/operation_semantics.md4
4 files changed, 112 insertions, 8 deletions
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index 4b19aa5df9..215af562a5 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -1100,6 +1100,95 @@ static llvm::Value* SaturateShiftIfNecessary(llvm::IRBuilder<>* b,
return b->CreateSelect(shift_amt_in_range, shift_result, saturated_value);
}
+llvm::Value* ElementalIrEmitter::GetOne(llvm::Type* type) const {
+ return llvm::ConstantInt::get(llvm::cast<llvm::IntegerType>(type), 1);
+}
+
+llvm::Value* ElementalIrEmitter::GetZero(llvm::Type* type) const {
+ return llvm::ConstantInt::get(llvm::cast<llvm::IntegerType>(type), 0);
+}
+
+llvm::Value* ElementalIrEmitter::GetIntSMin(llvm::Type* type) const {
+ auto* integer_type = llvm::cast<llvm::IntegerType>(type);
+ return llvm::ConstantInt::get(integer_type, llvm::APInt::getSignedMinValue(
+ integer_type->getBitWidth()));
+}
+
+llvm::Value* ElementalIrEmitter::GetMinusOne(llvm::Type* type) const {
+ auto* integer_type = llvm::cast<llvm::IntegerType>(type);
+ return llvm::ConstantInt::get(
+ integer_type, llvm::APInt::getAllOnesValue(integer_type->getBitWidth()));
+}
+
+llvm::Value* ElementalIrEmitter::IsZero(llvm::Value* v) const {
+ return b_->CreateICmpEQ(v, llvm::ConstantInt::get(v->getType(), 0));
+}
+
+llvm::Value* ElementalIrEmitter::IsIntMinDivisionOverflow(
+ llvm::Value* lhs, llvm::Value* rhs) const {
+ return b_->CreateAnd(b_->CreateICmpEQ(lhs, GetIntSMin(lhs->getType())),
+ b_->CreateICmpEQ(rhs, GetMinusOne(rhs->getType())));
+}
+
+llvm::Value* ElementalIrEmitter::Select(llvm::Value* cond, llvm::Value* if_true,
+ llvm::Value* if_false) const {
+ return b_->CreateSelect(cond, if_true, if_false);
+}
+
+llvm::Value* ElementalIrEmitter::EmitIntegerDivide(llvm::Value* lhs,
+ llvm::Value* rhs,
+ bool is_signed) const {
+ // Integer division overflow behavior:
+ //
+ // X / 0 == -1
+ // INT_SMIN /s -1 = INT_SMIN
+
+ if (!is_signed) {
+ llvm::Value* udiv_is_unsafe = IsZero(rhs);
+ llvm::Value* safe_rhs = Select(udiv_is_unsafe, GetOne(lhs->getType()), rhs);
+ llvm::Value* safe_div = b_->CreateUDiv(lhs, safe_rhs);
+ return Select(udiv_is_unsafe, GetMinusOne(lhs->getType()), safe_div);
+ }
+
+ llvm::Value* has_zero_divisor = IsZero(rhs);
+ llvm::Value* has_int_min_overflow = IsIntMinDivisionOverflow(lhs, rhs);
+ llvm::Value* sdiv_is_unsafe =
+ b_->CreateOr(has_int_min_overflow, has_zero_divisor);
+ llvm::Value* safe_rhs = Select(sdiv_is_unsafe, GetOne(lhs->getType()), rhs);
+ llvm::Value* safe_div = b_->CreateSDiv(lhs, safe_rhs);
+
+ return Select(
+ has_zero_divisor, GetMinusOne(lhs->getType()),
+ Select(has_int_min_overflow, GetIntSMin(lhs->getType()), safe_div));
+}
+
+llvm::Value* ElementalIrEmitter::EmitIntegerRemainder(llvm::Value* lhs,
+ llvm::Value* rhs,
+ bool is_signed) const {
+ // Integer remainder overflow behavior:
+ //
+ // X % 0 == X
+ // INT_SMIN %s -1 = 0
+
+ if (!is_signed) {
+ llvm::Value* urem_is_unsafe = IsZero(rhs);
+ llvm::Value* safe_rhs = Select(urem_is_unsafe, GetOne(lhs->getType()), rhs);
+ llvm::Value* safe_rem = b_->CreateURem(lhs, safe_rhs);
+ return Select(urem_is_unsafe, lhs, safe_rem);
+ }
+
+ llvm::Value* has_zero_divisor = IsZero(rhs);
+ llvm::Value* has_int_min_overflow = IsIntMinDivisionOverflow(lhs, rhs);
+ llvm::Value* srem_is_unsafe =
+ b_->CreateOr(has_int_min_overflow, has_zero_divisor);
+ llvm::Value* safe_rhs = Select(srem_is_unsafe, GetOne(lhs->getType()), rhs);
+ llvm::Value* safe_rem = b_->CreateSRem(lhs, safe_rhs);
+
+ return Select(
+ has_zero_divisor, lhs,
+ Select(has_int_min_overflow, GetZero(lhs->getType()), safe_rem));
+}
+
StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp(
const HloInstruction* op, llvm::Value* lhs_value, llvm::Value* rhs_value,
bool is_signed) const {
@@ -1112,11 +1201,9 @@ StatusOr<llvm::Value*> ElementalIrEmitter::EmitIntegerBinaryOp(
case HloOpcode::kMultiply:
return b_->CreateMul(lhs_value, rhs_value);
case HloOpcode::kDivide:
- return is_signed ? b_->CreateSDiv(lhs_value, rhs_value)
- : b_->CreateUDiv(lhs_value, rhs_value);
+ return EmitIntegerDivide(lhs_value, rhs_value, is_signed);
case HloOpcode::kRemainder:
- return is_signed ? b_->CreateSRem(lhs_value, rhs_value)
- : b_->CreateURem(lhs_value, rhs_value);
+ return EmitIntegerRemainder(lhs_value, rhs_value, is_signed);
case HloOpcode::kEq:
return llvm_ir::EmitComparison(llvm::CmpInst::ICMP_EQ, lhs_value,
rhs_value, b_);
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.h b/tensorflow/compiler/xla/service/elemental_ir_emitter.h
index 1598a4dd85..c037b98929 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.h
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.h
@@ -65,6 +65,21 @@ class ElementalIrEmitter {
virtual StatusOr<llvm::Value*> EmitComplexUnaryOp(
const HloInstruction* op, llvm::Value* operand_value) const;
+ llvm::Value* IsZero(llvm::Value* v) const;
+ llvm::Value* IsIntMinDivisionOverflow(llvm::Value* lhs,
+ llvm::Value* rhs) const;
+ llvm::Value* GetZero(llvm::Type* type) const;
+ llvm::Value* GetOne(llvm::Type* type) const;
+ llvm::Value* GetIntSMin(llvm::Type* type) const;
+ llvm::Value* GetMinusOne(llvm::Type* type) const;
+ llvm::Value* Select(llvm::Value* cond, llvm::Value* if_true,
+ llvm::Value* if_false) const;
+
+ llvm::Value* EmitIntegerDivide(llvm::Value* lhs, llvm::Value* rhs,
+ bool is_signed) const;
+ llvm::Value* EmitIntegerRemainder(llvm::Value* lhs, llvm::Value* rhs,
+ bool is_signed) const;
+
virtual StatusOr<llvm::Value*> EmitIntegerBinaryOp(const HloInstruction* op,
llvm::Value* lhs_value,
llvm::Value* rhs_value,
diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
index 316ab26a1f..84c5b6e549 100644
--- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
+++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
@@ -498,8 +498,7 @@ XLA_TEST_F(IntegerDivideOpTest, DivS32s) {
TestDivRem<int32>(dividends, divisors, quotients, remainders);
}
-XLA_TEST_F(IntegerDivideOpTest,
- DISABLED_ON_CPU(DISABLED_ON_GPU(SignedOverflow))) {
+XLA_TEST_F(IntegerDivideOpTest, SignedOverflow) {
std::vector<int32> dividends = {5, INT32_MIN}, divisors = {0, -1},
quotients = {-1, INT32_MIN}, remainders = {5, 0};
@@ -529,8 +528,7 @@ XLA_TEST_F(IntegerDivideOpTest, DivU32s) {
TestDivRem<uint32>(dividends, divisors, quotients, remainders);
}
-XLA_TEST_F(IntegerDivideOpTest,
- DISABLED_ON_CPU(DISABLED_ON_GPU(UnsignedOverflow))) {
+XLA_TEST_F(IntegerDivideOpTest, UnsignedOverflow) {
std::vector<int32> dividends = {5}, divisors = {0}, quotients = {-1},
remainders = {5};
diff --git a/tensorflow/docs_src/performance/xla/operation_semantics.md b/tensorflow/docs_src/performance/xla/operation_semantics.md
index 2de30d1b3d..c23a7ad9e2 100644
--- a/tensorflow/docs_src/performance/xla/operation_semantics.md
+++ b/tensorflow/docs_src/performance/xla/operation_semantics.md
@@ -1036,6 +1036,10 @@ different ranks are *not* supported, unless one of the operands is a scalar.
When `Op` is `Rem`, the sign of the result is taken from the dividend, and the
absolute value of the result is always less than the divisor's absolute value.
+Integer division overflow (signed/unsigned division/remainder by zero or signed
+divison/remainder of `INT_SMIN` with `-1`) produces an implementation defined
+value.
+
An alternative variant with different-rank broadcasting support exists for these
operations: