about summary refs log tree commit diff homepage
diff options
context:
space:
mode:
-rw-r--r-- tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc 14
1 files changed, 9 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index bef7a55301..09486d291a 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -2080,9 +2080,9 @@ Status IrEmitterUnnested::HandleScatter(HloInstruction* scatter) {
// Launch a kernel that reads every element in the updates tensor. We could
// also do one kernel per window instead if bounds checks turn out to be a
// bottleneck.
- thunks.push_back(BuildKernelThunk(
- scatter,
- /*implements_whole_instruction=*/operand_buffer == destination_buffer));
+ thunks.push_back(
+ BuildKernelThunk(scatter,
+ /*implements_whole_instruction=*/thunks.empty()));
LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
updates->shape(), ir_emitter_context_->device_description());
@@ -2090,8 +2090,12 @@ Status IrEmitterUnnested::HandleScatter(HloInstruction* scatter) {
static_cast<KernelThunk*>(thunks.back().get()),
ir_emitter_context_->llvm_module());
- thunk_sequence_->emplace_back(
- absl::make_unique<SequentialThunk>(std::move(thunks), scatter));
+ if (thunks.size() == 1) {
+ thunk_sequence_->push_back(std::move(thunks[0]));
+ } else {
+ thunk_sequence_->emplace_back(
+ absl::make_unique<SequentialThunk>(std::move(thunks), scatter));
+ }
return ParallelLoopEmitter(loop_body_emitter, updates->shape(),
launch_dimensions, &b_)
.EmitLoop(IrName(scatter),