diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc | 14 |
1 files changed, 9 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index bef7a55301..09486d291a 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -2080,9 +2080,9 @@ Status IrEmitterUnnested::HandleScatter(HloInstruction* scatter) { // Launch a kernel that reads every element in the updates tensor. We could // also do one kernel per window instead if bounds checks turn out to be a // bottleneck. - thunks.push_back(BuildKernelThunk( - scatter, - /*implements_whole_instruction=*/operand_buffer == destination_buffer)); + thunks.push_back( + BuildKernelThunk(scatter, + /*implements_whole_instruction=*/thunks.empty())); LaunchDimensions launch_dimensions = CalculateLaunchDimensions( updates->shape(), ir_emitter_context_->device_description()); @@ -2090,8 +2090,12 @@ Status IrEmitterUnnested::HandleScatter(HloInstruction* scatter) { static_cast<KernelThunk*>(thunks.back().get()), ir_emitter_context_->llvm_module()); - thunk_sequence_->emplace_back( - absl::make_unique<SequentialThunk>(std::move(thunks), scatter)); + if (thunks.size() == 1) { + thunk_sequence_->push_back(std::move(thunks[0])); + } else { + thunk_sequence_->emplace_back( + absl::make_unique<SequentialThunk>(std::move(thunks), scatter)); + } return ParallelLoopEmitter(loop_body_emitter, updates->shape(), launch_dimensions, &b_) .EmitLoop(IrName(scatter), |