diff options
author | 2017-06-13 17:49:59 -0700 | |
---|---|---|
committer | 2017-06-13 21:50:28 -0700 | |
commit | 75145524ffed31bb749a49c3dba4518590767cf6 (patch) | |
tree | 01b4e2dfc7e9653b7f7eee02448b83f5fb83ffab | |
parent | 6bf093ca016a2b1caf40e30f7cd73809ab3257f4 (diff) |
Don't split basic blocks that lack a terminator.
This applies to ElementalIrEmitter::MakeRngElementGenerator the same idiom
that ForLoop::Emit already uses to handle this situation.
PiperOrigin-RevId: 158925513
-rw-r--r-- | tensorflow/compiler/xla/service/elemental_ir_emitter.cc | 35 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tests/prng_test.cc | 11 |
2 files changed, 37 insertions, 9 deletions
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc index bea1da4044..dbc65e80eb 100644 --- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc +++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc @@ -588,20 +588,37 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeRngElementGenerator( llvm::Intrinsic::ctlz, {r, ir_builder_->getInt1(1)}, {param_ir_type}, ir_builder_); auto in_block = ir_builder_->GetInsertBlock(); - auto body_block = in_block->splitBasicBlock( - ir_builder_->GetInsertPoint(), "rng_body"); - SetToFirstInsertPoint(body_block, ir_builder_); - auto out_block = body_block->splitBasicBlock( - ir_builder_->GetInsertPoint(), "rng_out"); + + // A terminator should be present iff we're emitting code + // into the middle (as opposed to the end) of a basic block. + CHECK_EQ(ir_builder_->GetInsertPoint() == in_block->end(), + in_block->getTerminator() == nullptr); + + llvm::BasicBlock* body_block; + llvm::BasicBlock* out_block; + + if (ir_builder_->GetInsertPoint() == in_block->end()) { + body_block = + llvm_ir::CreateBasicBlock(nullptr, "rng_body", ir_builder_); + out_block = + llvm_ir::CreateBasicBlock(nullptr, "rng_out", ir_builder_); + llvm::BranchInst::Create(body_block, in_block); + } else { + body_block = in_block->splitBasicBlock( + ir_builder_->GetInsertPoint(), "rng_body"); + out_block = body_block->splitBasicBlock( + ir_builder_->GetInsertPoint(), "rng_out"); + body_block->getTerminator()->eraseFromParent(); + } + SetToFirstInsertPoint(body_block, ir_builder_); auto random = ir_builder_->CreateAnd( ir_builder_->CreateZExtOrTrunc(get_next_i64(), param_ir_type), ir_builder_->CreateLShr(llvm::ConstantInt::get(param_ir_type, ~0), leading_zeros)); - llvm::ReplaceInstWithInst( - body_block->getTerminator(), - llvm::BranchInst::Create(out_block, body_block, - ir_builder_->CreateICmpULT(random, r))); + llvm::BranchInst::Create(out_block, body_block, + 
ir_builder_->CreateICmpULT(random, r), + body_block); SetToFirstInsertPoint(out_block, ir_builder_); return ir_builder_->CreateAdd( p, ir_builder_->CreateSelect( diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc index 5117478bfd..b77b8e2ee3 100644 --- a/tensorflow/compiler/xla/tests/prng_test.cc +++ b/tensorflow/compiler/xla/tests/prng_test.cc @@ -273,6 +273,17 @@ XLA_TEST_F(PrngTest, TenValuesN01) { // TODO(b/25995601): Test that resultant values are reasonable } +XLA_TEST_F(PrngTest, RngUniformCrash) { + ComputationBuilder builder(client_, TestName()); + + // This used to crash XLA during LLVM IR generation for CPUs. + auto rng_uniform = builder.RngUniform(builder.ConstantR0<int32>(0), + builder.ConstantR0<int32>(1000 * 1000), + ShapeUtil::MakeShape(S32, {})); + SetSeed(0); + ExecuteAndTransferOrDie(&builder, /*arguments=*/{}); +} + } // namespace } // namespace xla |