author    | Benjamin Kramer <kramerb@google.com>            | 2018-10-09 13:32:24 -0700
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-09 13:40:43 -0700
commit    | 5d9a7fdf4f02c2db487a03e7ad2d520f8847c4e3
tree      | a77d90f9328b7e0e859a15ab3b5d765774954b5a
parent    | 9989788be25c846d087ac70b76cf78759a209a3e
[XLA:GPU] Add an implementation of scatter for GPU
The implementation is simple: a kernel runs on every element of the updates tensor,
figures out the right indices to perform the update, and applies it with an
atomic operation.
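
For intuition only, here is a minimal standalone CUDA sketch of that scheme for a 1-D operand with an add combiner. Everything in it (names, shapes, types) is hypothetical illustration; the actual kernel is generated through LLVM IR by the HandleScatter code in the diff below.

```cuda
// Illustrative only: one thread per element of `updates` resolves its
// destination from `scatter_indices`, drops out-of-bounds updates, and
// applies the value with an atomic add.
__global__ void ScatterAddSketch(float* output, long long output_size,
                                 const int* scatter_indices,
                                 const float* updates,
                                 long long num_updates) {
  long long i =
      static_cast<long long>(blockIdx.x) * blockDim.x + threadIdx.x;
  if (i >= num_updates) return;

  long long dest = scatter_indices[i];  // where this update element lands
  // Negative or too-large indices are silently dropped (bounds check).
  if (static_cast<unsigned long long>(dest) <
      static_cast<unsigned long long>(output_size)) {
    atomicAdd(&output[dest], updates[i]);  // combiner == add
  }
}
```

Launched with one thread per update element, e.g. `ScatterAddSketch<<<(num_updates + 255) / 256, 256>>>(...)`.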
Currently we emit a CAS for plain (i.e. non-add) updates, which is inefficient.
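
Roughly, that fallback looks like the sketch below (hypothetical helper, written here for a multiply combiner): the location is re-read and the combiner re-applied until the compare-and-swap observes no intervening write, which is why it is slower than a native atomic add.

```cuda
// Illustrative only: emulate an atomic read-modify-write for a combiner that
// has no hardware atomic, by spinning on atomicCAS over the raw bits.
__device__ void AtomicMulF32Sketch(float* address, float update) {
  unsigned int* bits = reinterpret_cast<unsigned int*>(address);
  unsigned int observed = *bits;
  unsigned int assumed;
  do {
    assumed = observed;
    // Apply the combiner to the value we think is stored there.
    float desired = __uint_as_float(assumed) * update;
    // Publish it only if nobody changed the location in the meantime.
    observed = atomicCAS(bits, assumed, __float_as_uint(desired));
  } while (observed != assumed);  // another thread raced us: retry
}
```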
Also, TuplePointsToAnalysis doesn't yet know that it should alias the operand and
output buffers of a scatter; teaching it to do so would avoid a copy.
PiperOrigin-RevId: 216412467
5 files changed, 143 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 350fd32537..0144d59097 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -705,7 +705,6 @@ cc_library(
         "//tensorflow/compiler/xla/service:llvm_compiler",
         "//tensorflow/compiler/xla/service:reduce_precision_insertion",
         "//tensorflow/compiler/xla/service:reshape_mover",
-        "//tensorflow/compiler/xla/service:scatter_expander",
         "//tensorflow/compiler/xla/service:transpose_folding",
         "//tensorflow/compiler/xla/service:tuple_simplifier",
         "//tensorflow/compiler/xla/service:while_loop_constant_sinking",
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index c792dd2ddb..bef7a55301 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -1958,6 +1958,147 @@ Status IrEmitterUnnested::HandleRng(HloInstruction* rng) {
   return Status::OK();
 }
 
+Status IrEmitterUnnested::HandleScatter(HloInstruction* scatter) {
+  const HloInstruction* operand = scatter->operand(0);
+  const HloInstruction* scatter_indices = scatter->operand(1);
+  const HloInstruction* updates = scatter->operand(2);
+  const ScatterDimensionNumbers& dim_numbers =
+      scatter->scatter_dimension_numbers();
+  CHECK(ShapeUtil::Equal(scatter->shape(), operand->shape()));
+
+  std::vector<std::unique_ptr<Thunk>> thunks;
+
+  // Copy the operand into the output if it's not the same buffer already.
+  auto operand_buffer = GetAllocationSlice(*operand);
+  auto destination_buffer = GetAllocationSlice(*scatter);
+  if (operand_buffer != destination_buffer) {
+    thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
+        /*source_address=*/operand_buffer,
+        /*destination_buffer=*/destination_buffer,
+        /*mem_size=*/ShapeUtil::ByteSizeOf(operand->shape()), scatter));
+  }
+
+  auto loop_body_emitter = [&](const IrArray::Index& index) -> Status {
+    std::vector<llvm::Value*> raw_window_multidim;
+    std::vector<llvm::Value*> input_scatter_multidim;
+    std::vector<int64> raw_window_bounds;
+
+    // Partition the index into window indices and scatter indices.
+    for (int64 i = 0, e = index.size(); i != e; ++i) {
+      // For window indices also remember the window size, this comes in handy
+      // later.
+      if (absl::c_binary_search(dim_numbers.update_window_dims(), i)) {
+        raw_window_multidim.push_back(index[i]);
+        raw_window_bounds.push_back(updates->shape().dimensions(i));
+      } else {
+        input_scatter_multidim.push_back(index[i]);
+      }
+    }
+    DCHECK_EQ(raw_window_multidim.size(),
+              dim_numbers.update_window_dims_size());
+
+    // Apply inserted_window_dims to the window dimensions.
+    int64 raw_window_multidim_idx = 0;
+    std::vector<llvm::Value*> input_window_multidim;
+    std::vector<int64> input_window_bounds;
+    for (int64 i = 0, e = ShapeUtil::Rank(operand->shape()); i != e; ++i) {
+      if (absl::c_binary_search(dim_numbers.inserted_window_dims(), i)) {
+        input_window_bounds.push_back(1);  // Trivial dimension.
+        input_window_multidim.push_back(index.GetConstantWithIndexType(0));
+      } else {
+        input_window_bounds.push_back(
+            raw_window_bounds[raw_window_multidim_idx]);
+        input_window_multidim.push_back(
+            raw_window_multidim[raw_window_multidim_idx]);
+        ++raw_window_multidim_idx;
+      }
+    }
+    DCHECK_EQ(input_window_multidim.size(), ShapeUtil::Rank(operand->shape()));
+
+    // Insert a 1 dimension at the end if index_vector_dim requests one.
+    Shape scatter_indices_shape = scatter_indices->shape();
+    if (dim_numbers.index_vector_dim() ==
+        ShapeUtil::Rank(scatter_indices_shape)) {
+      scatter_indices_shape.add_dimensions(1);
+      scatter_indices_shape.mutable_layout()->add_minor_to_major(
+          dim_numbers.index_vector_dim());
+    }
+    llvm_ir::IrArray scatter_indices_reshaped =
+        GetIrArray(*scatter_indices, *scatter)
+            .CastToShape(scatter_indices_shape, &b_);
+
+    // Now load the indices corresponding to the current window from
+    // scatter_indices.
+    llvm_ir::IrArray::Index raw_scatter_index_index(input_scatter_multidim,
+                                                    index.GetType());
+    raw_scatter_index_index.InsertAt(dim_numbers.index_vector_dim(), nullptr);
+    llvm::Value* is_in_bounds = b_.getTrue();
+    for (int64 i = 0, e = dim_numbers.scatter_dims_to_operand_dims_size();
+         i != e; ++i) {
+      // Our index is stored along index_vector_dim, insert that into the lookup
+      // index into scatter_indices.
+      raw_scatter_index_index[dim_numbers.index_vector_dim()] =
+          raw_scatter_index_index.GetConstantWithIndexType(i);
+
+      int64 operand_dim = dim_numbers.scatter_dims_to_operand_dims(i);
+      llvm::Value* loaded_scatter_index =
+          scatter_indices_reshaped.EmitReadArrayElement(raw_scatter_index_index,
+                                                        &b_, "scatter_index");
+      // And add the index to our window index. This yields the output index.
+      llvm::Value* dim_offset =
+          Add(input_window_multidim[operand_dim],
+              IntCast(loaded_scatter_index, index.GetType(),
+                      /*isSigned=*/true));
+      input_window_multidim[operand_dim] = dim_offset;
+
+      // Also do the bounds check now.
+      int64 max_index = operand->shape().dimensions(operand_dim) -
+                        input_window_bounds[operand_dim] + 1;
+      // is_in_bounds = dim_offset >= 0 && dim_offset < dim_size-window_size+1
+      //   --> dim_offset u< dim_size-window_size+1
+      is_in_bounds =
+          And(is_in_bounds,
+              ICmpULT(dim_offset, index.GetConstantWithIndexType(max_index)));
+    }
+
+    llvm_ir::LlvmIfData if_window_in_bounds_data = llvm_ir::EmitIfThenElse(
+        is_in_bounds, "scatter.in_bounds", &b_, /*emit_else=*/false);
+    llvm_ir::SetToFirstInsertPoint(if_window_in_bounds_data.true_block, &b_);
+    // All done, now just read from the calculated input from the window, and do
+    // an atomic store to the calculated location in the output.
+    llvm_ir::IrArray::Index input_window_index(input_window_multidim,
+                                               index.GetType());
+    llvm::Value* input_address =
+        GetIrArray(*updates, *scatter).EmitArrayElementAddress(index, &b_);
+    llvm::Value* output_address =
+        GetIrArray(*scatter, *scatter)
+            .EmitArrayElementAddress(input_window_index, &b_);
+    return EmitAtomicOperationForNestedComputation(
+        *scatter->to_apply(), output_address, input_address);
+  };
+
+  // Launch a kernel that reads every element in the updates tensor. We could
+  // also do one kernel per window instead if bounds checks turn out to be a
+  // bottleneck.
+  thunks.push_back(BuildKernelThunk(
+      scatter,
+      /*implements_whole_instruction=*/operand_buffer == destination_buffer));
+
+  LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
+      updates->shape(), ir_emitter_context_->device_description());
+  UpdateLaunchDimensions(launch_dimensions,
+                         static_cast<KernelThunk*>(thunks.back().get()),
+                         ir_emitter_context_->llvm_module());
+
+  thunk_sequence_->emplace_back(
+      absl::make_unique<SequentialThunk>(std::move(thunks), scatter));
+  return ParallelLoopEmitter(loop_body_emitter, updates->shape(),
+                             launch_dimensions, &b_)
+      .EmitLoop(IrName(scatter),
+                GetIndexTypeForKernel(scatter, launch_dimensions.launch_bound(),
+                                      &b_));
+}
+
 Status IrEmitterUnnested::HandleSelect(HloInstruction* select) {
   thunk_sequence_->push_back(
       BuildKernelThunk(select, /*implements_whole_instruction=*/true));
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
index bd5db72051..2e36e7235b 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
@@ -76,6 +76,7 @@ class IrEmitterUnnested : public IrEmitter {
   Status HandleInfeed(HloInstruction* xla_infeed) override;
   Status HandleOutfeed(HloInstruction* outfeed) override;
   Status HandleRng(HloInstruction* random) override;
+  Status HandleScatter(HloInstruction* scatter) override;
   Status HandleSelect(HloInstruction* select) override;
   Status HandleSort(HloInstruction* sort) override;
   Status HandleTupleSelect(HloInstruction* tuple_select) override;
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
index ac6c2c5565..5409f65589 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
@@ -75,7 +75,6 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
 #include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"
 #include "tensorflow/compiler/xla/service/reshape_mover.h"
-#include "tensorflow/compiler/xla/service/scatter_expander.h"
 #include "tensorflow/compiler/xla/service/transpose_folding.h"
 #include "tensorflow/compiler/xla/service/tuple_simplifier.h"
 #include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h"
@@ -176,8 +175,6 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
       // elimination has to come after that pass.
       pipeline.AddPass<ZeroSizedHloElimination>();
 
-      pipeline.AddPass<ScatterExpander>();
-
       pass.AddPass<AlgebraicSimplifier>(
           /*is_layout_sensitive=*/false,
           [](const Shape&, const Shape&) { return false; });
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index ad65b147c1..2cf5fc94ac 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -1908,6 +1908,7 @@ bool LayoutAssignment::InstructionCanChangeLayout(
     case HloOpcode::kRemainder:
     case HloOpcode::kReverse:
     case HloOpcode::kRoundNearestAfz:
+    case HloOpcode::kScatter:
     case HloOpcode::kSelect:
     case HloOpcode::kSelectAndScatter:
     case HloOpcode::kShiftLeft:
@@ -1946,7 +1947,6 @@ bool LayoutAssignment::InstructionCanChangeLayout(
     case HloOpcode::kReduce:
     case HloOpcode::kReshape:
     case HloOpcode::kRng:
-    case HloOpcode::kScatter:
     case HloOpcode::kSend:
     case HloOpcode::kSendDone:
     case HloOpcode::kAfterAll:
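
As an aside, the bounds check in the emitted loop body above folds `dim_offset >= 0 && dim_offset < dim_size - window_size + 1` into one unsigned comparison, since a negative signed index wraps around to a huge unsigned value. A minimal illustration of that trick (hypothetical helper, assuming a non-negative limit):

```cuda
// Illustrative only: a single unsigned compare covers both bounds,
// provided `limit` itself is non-negative.
__host__ __device__ inline bool InBoundsSketch(long long idx, long long limit) {
  return static_cast<unsigned long long>(idx) <
         static_cast<unsigned long long>(limit);
}
```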