aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar Benjamin Kramer <kramerb@google.com>2018-10-09 16:52:56 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-09 17:02:47 -0700
commit65b7d0b2f84c334327a295bf41bc06c7f6b8ffe5 (patch)
tree7602939f063340ec4a98ffe5f8179fff7e3c1bd5
parentd4526cf9d1d58cbe480e7d2b8199620e0e9f0572 (diff)
[XLA:GPU] Elide the SequentialThunk when emitting scatter with no copy
We have a 1-element thunk sequence if we're not copying. That's still two thunks and hlo profiling gets confused if it sees two thunks for the same instruction and one of them claims to be the whole instruction. PiperOrigin-RevId: 216448063
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc14
1 files changed, 9 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index bef7a55301..09486d291a 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -2080,9 +2080,9 @@ Status IrEmitterUnnested::HandleScatter(HloInstruction* scatter) {
// Launch a kernel that reads every element in the updates tensor. We could
// also do one kernel per window instead if bounds checks turn out to be a
// bottleneck.
- thunks.push_back(BuildKernelThunk(
- scatter,
- /*implements_whole_instruction=*/operand_buffer == destination_buffer));
+ thunks.push_back(
+ BuildKernelThunk(scatter,
+ /*implements_whole_instruction=*/thunks.empty()));
LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
updates->shape(), ir_emitter_context_->device_description());
@@ -2090,8 +2090,12 @@ Status IrEmitterUnnested::HandleScatter(HloInstruction* scatter) {
static_cast<KernelThunk*>(thunks.back().get()),
ir_emitter_context_->llvm_module());
- thunk_sequence_->emplace_back(
- absl::make_unique<SequentialThunk>(std::move(thunks), scatter));
+ if (thunks.size() == 1) {
+ thunk_sequence_->push_back(std::move(thunks[0]));
+ } else {
+ thunk_sequence_->emplace_back(
+ absl::make_unique<SequentialThunk>(std::move(thunks), scatter));
+ }
return ParallelLoopEmitter(loop_body_emitter, updates->shape(),
launch_dimensions, &b_)
.EmitLoop(IrName(scatter),