author     Benjamin Kramer <kramerb@google.com>             2018-10-09 13:32:24 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-10-09 13:40:43 -0700
commit     5d9a7fdf4f02c2db487a03e7ad2d520f8847c4e3 (patch)
tree       a77d90f9328b7e0e859a15ab3b5d765774954b5a
parent     9989788be25c846d087ac70b76cf78759a209a3e (diff)
[XLA:GPU] Add an implementation of scatter for GPU
This simple implementation has a kernel that runs on every element of the updates tensor, figures out the right indices to perform the update, and applies it with an atomic operation.

Currently we emit a CAS for plain (i.e. non-add) updates, which is inefficient. Also, TuplePointsToAnalysis doesn't know that it should alias the operand and output buffers of a scatter, which would avoid a copy.

PiperOrigin-RevId: 216412467
-rw-r--r--  tensorflow/compiler/xla/service/gpu/BUILD                    |   1
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc   | 141
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h    |   1
-rw-r--r--  tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc        |   3
-rw-r--r--  tensorflow/compiler/xla/service/layout_assignment.cc         |   2
5 files changed, 143 insertions(+), 5 deletions(-)
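Editor's note on the CAS path the message mentions: an arbitrary update computation (scatter->to_apply()) cannot be expressed as a single atomic instruction, so the emitter retries it in a compare-and-swap loop. Below is a minimal, hypothetical C++ analogue of that loop, not the emitted LLVM IR; the names AtomicApply and Combiner are invented for illustration, and the real emitter builds the equivalent IR in EmitAtomicOperationForNestedComputation (see the diff below). It assumes a 32-bit float payload bitcast through uint32_t.

#include <atomic>
#include <cstdint>
#include <cstring>

// Atomically replaces *output with combine(old_value, update). The combiner
// is arbitrary, so the update is retried in a compare-exchange loop.
template <typename Combiner>
void AtomicApply(std::atomic<std::uint32_t>* output, float update,
                 Combiner combine) {
  std::uint32_t expected = output->load();
  while (true) {
    float old_value;
    std::memcpy(&old_value, &expected, sizeof(old_value));
    const float new_value = combine(old_value, update);
    std::uint32_t desired;
    std::memcpy(&desired, &new_value, sizeof(desired));
    // On failure, compare_exchange_weak reloads `expected` with the value
    // another thread just wrote, and the loop recomputes from it.
    if (output->compare_exchange_weak(expected, desired)) break;
  }
}

For an add combiner a native atomic add can be used instead, which is why the CAS fallback for non-add updates is called out as inefficient.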
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 350fd32537..0144d59097 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -705,7 +705,6 @@ cc_library(
"//tensorflow/compiler/xla/service:llvm_compiler",
"//tensorflow/compiler/xla/service:reduce_precision_insertion",
"//tensorflow/compiler/xla/service:reshape_mover",
- "//tensorflow/compiler/xla/service:scatter_expander",
"//tensorflow/compiler/xla/service:transpose_folding",
"//tensorflow/compiler/xla/service:tuple_simplifier",
"//tensorflow/compiler/xla/service:while_loop_constant_sinking",
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index c792dd2ddb..bef7a55301 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -1958,6 +1958,147 @@ Status IrEmitterUnnested::HandleRng(HloInstruction* rng) {
return Status::OK();
}

+Status IrEmitterUnnested::HandleScatter(HloInstruction* scatter) {
+ const HloInstruction* operand = scatter->operand(0);
+ const HloInstruction* scatter_indices = scatter->operand(1);
+ const HloInstruction* updates = scatter->operand(2);
+ const ScatterDimensionNumbers& dim_numbers =
+ scatter->scatter_dimension_numbers();
+ CHECK(ShapeUtil::Equal(scatter->shape(), operand->shape()));
+
+ std::vector<std::unique_ptr<Thunk>> thunks;
+
+ // Copy the operand into the output if it's not the same buffer already.
+ auto operand_buffer = GetAllocationSlice(*operand);
+ auto destination_buffer = GetAllocationSlice(*scatter);
+ if (operand_buffer != destination_buffer) {
+ thunks.push_back(absl::make_unique<DeviceToDeviceCopyThunk>(
+ /*source_address=*/operand_buffer,
+ /*destination_buffer=*/destination_buffer,
+ /*mem_size=*/ShapeUtil::ByteSizeOf(operand->shape()), scatter));
+ }
+
+ auto loop_body_emitter = [&](const IrArray::Index& index) -> Status {
+ std::vector<llvm::Value*> raw_window_multidim;
+ std::vector<llvm::Value*> input_scatter_multidim;
+ std::vector<int64> raw_window_bounds;
+
+ // Partition the index into window indices and scatter indices.
+ for (int64 i = 0, e = index.size(); i != e; ++i) {
+ // For window indices, also remember the window size; this comes in handy
+ // later.
+ if (absl::c_binary_search(dim_numbers.update_window_dims(), i)) {
+ raw_window_multidim.push_back(index[i]);
+ raw_window_bounds.push_back(updates->shape().dimensions(i));
+ } else {
+ input_scatter_multidim.push_back(index[i]);
+ }
+ }
+ DCHECK_EQ(raw_window_multidim.size(),
+ dim_numbers.update_window_dims_size());
+
+ // Apply inserted_window_dims to the window dimensions.
+ int64 raw_window_multidim_idx = 0;
+ std::vector<llvm::Value*> input_window_multidim;
+ std::vector<int64> input_window_bounds;
+ for (int64 i = 0, e = ShapeUtil::Rank(operand->shape()); i != e; ++i) {
+ if (absl::c_binary_search(dim_numbers.inserted_window_dims(), i)) {
+ input_window_bounds.push_back(1); // Trivial dimension.
+ input_window_multidim.push_back(index.GetConstantWithIndexType(0));
+ } else {
+ input_window_bounds.push_back(
+ raw_window_bounds[raw_window_multidim_idx]);
+ input_window_multidim.push_back(
+ raw_window_multidim[raw_window_multidim_idx]);
+ ++raw_window_multidim_idx;
+ }
+ }
+ DCHECK_EQ(input_window_multidim.size(), ShapeUtil::Rank(operand->shape()));
+
+ // Insert a size-1 dimension at the end if index_vector_dim requests one.
+ Shape scatter_indices_shape = scatter_indices->shape();
+ if (dim_numbers.index_vector_dim() ==
+ ShapeUtil::Rank(scatter_indices_shape)) {
+ scatter_indices_shape.add_dimensions(1);
+ scatter_indices_shape.mutable_layout()->add_minor_to_major(
+ dim_numbers.index_vector_dim());
+ }
+ llvm_ir::IrArray scatter_indices_reshaped =
+ GetIrArray(*scatter_indices, *scatter)
+ .CastToShape(scatter_indices_shape, &b_);
+
+ // Now load the indices corresponding to the current window from
+ // scatter_indices.
+ llvm_ir::IrArray::Index raw_scatter_index_index(input_scatter_multidim,
+ index.GetType());
+ raw_scatter_index_index.InsertAt(dim_numbers.index_vector_dim(), nullptr);
+ llvm::Value* is_in_bounds = b_.getTrue();
+ for (int64 i = 0, e = dim_numbers.scatter_dims_to_operand_dims_size();
+ i != e; ++i) {
+ // Our index is stored along index_vector_dim; insert that into the lookup
+ // index into scatter_indices.
+ raw_scatter_index_index[dim_numbers.index_vector_dim()] =
+ raw_scatter_index_index.GetConstantWithIndexType(i);
+
+ int64 operand_dim = dim_numbers.scatter_dims_to_operand_dims(i);
+ llvm::Value* loaded_scatter_index =
+ scatter_indices_reshaped.EmitReadArrayElement(raw_scatter_index_index,
+ &b_, "scatter_index");
+ // And add the index to our window index. This yields the output index.
+ llvm::Value* dim_offset =
+ Add(input_window_multidim[operand_dim],
+ IntCast(loaded_scatter_index, index.GetType(),
+ /*isSigned=*/true));
+ input_window_multidim[operand_dim] = dim_offset;
+
+ // Also do the bounds check now.
+ int64 max_index = operand->shape().dimensions(operand_dim) -
+ input_window_bounds[operand_dim] + 1;
+ // is_in_bounds = dim_offset >= 0 && dim_offset < dim_size-window_size+1
+ // --> dim_offset u< dim_size-window_size+1
+ is_in_bounds =
+ And(is_in_bounds,
+ ICmpULT(dim_offset, index.GetConstantWithIndexType(max_index)));
+ }
+
+ llvm_ir::LlvmIfData if_window_in_bounds_data = llvm_ir::EmitIfThenElse(
+ is_in_bounds, "scatter.in_bounds", &b_, /*emit_else=*/false);
+ llvm_ir::SetToFirstInsertPoint(if_window_in_bounds_data.true_block, &b_);
+ // All done, now just read the update value for the current window and
+ // combine it into the calculated location in the output atomically.
+ llvm_ir::IrArray::Index input_window_index(input_window_multidim,
+ index.GetType());
+ llvm::Value* input_address =
+ GetIrArray(*updates, *scatter).EmitArrayElementAddress(index, &b_);
+ llvm::Value* output_address =
+ GetIrArray(*scatter, *scatter)
+ .EmitArrayElementAddress(input_window_index, &b_);
+ return EmitAtomicOperationForNestedComputation(
+ *scatter->to_apply(), output_address, input_address);
+ };
+
+ // Launch a kernel that reads every element in the updates tensor. We could
+ // also do one kernel per window instead if bounds checks turn out to be a
+ // bottleneck.
+ thunks.push_back(BuildKernelThunk(
+ scatter,
+ /*implements_whole_instruction=*/operand_buffer == destination_buffer));
+
+ LaunchDimensions launch_dimensions = CalculateLaunchDimensions(
+ updates->shape(), ir_emitter_context_->device_description());
+ UpdateLaunchDimensions(launch_dimensions,
+ static_cast<KernelThunk*>(thunks.back().get()),
+ ir_emitter_context_->llvm_module());
+
+ thunk_sequence_->emplace_back(
+ absl::make_unique<SequentialThunk>(std::move(thunks), scatter));
+ return ParallelLoopEmitter(loop_body_emitter, updates->shape(),
+ launch_dimensions, &b_)
+ .EmitLoop(IrName(scatter),
+ GetIndexTypeForKernel(scatter, launch_dimensions.launch_bound(),
+ &b_));
+}
+
Status IrEmitterUnnested::HandleSelect(HloInstruction* select) {
thunk_sequence_->push_back(
BuildKernelThunk(select, /*implements_whole_instruction=*/true));
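
Editor's note on the bounds check emitted above: per the comment in the loop body, dim_offset >= 0 && dim_offset < dim_size-window_size+1 folds into a single unsigned comparison, because a negative offset wraps around to a huge unsigned value that is never below the limit. A standalone C++ sketch of the trick (the names WindowInBounds, dim_offset, and max_index are illustrative, not from the patch):

#include <cstdint>

// dim_offset >= 0 && dim_offset < max_index collapses into one unsigned
// compare: a negative dim_offset reinterprets as a value above any valid
// index, so the single check rejects it too.
bool WindowInBounds(std::int64_t dim_offset, std::int64_t max_index) {
  return static_cast<std::uint64_t>(dim_offset) <
         static_cast<std::uint64_t>(max_index);
}

Saving one comparison per operand dimension matters here because the check runs once per element of the updates tensor.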
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
index bd5db72051..2e36e7235b 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.h
@@ -76,6 +76,7 @@ class IrEmitterUnnested : public IrEmitter {
Status HandleInfeed(HloInstruction* xla_infeed) override;
Status HandleOutfeed(HloInstruction* outfeed) override;
Status HandleRng(HloInstruction* random) override;
+ Status HandleScatter(HloInstruction* scatter) override;
Status HandleSelect(HloInstruction* select) override;
Status HandleSort(HloInstruction* sort) override;
Status HandleTupleSelect(HloInstruction* tuple_select) override;
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
index ac6c2c5565..5409f65589 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
@@ -75,7 +75,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/reduce_precision_insertion.h"
#include "tensorflow/compiler/xla/service/reshape_mover.h"
-#include "tensorflow/compiler/xla/service/scatter_expander.h"
#include "tensorflow/compiler/xla/service/transpose_folding.h"
#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
#include "tensorflow/compiler/xla/service/while_loop_constant_sinking.h"
@@ -176,8 +175,6 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec,
// elimination has to come after that pass.
pipeline.AddPass<ZeroSizedHloElimination>();

- pipeline.AddPass<ScatterExpander>();
-
pass.AddPass<AlgebraicSimplifier>(
/*is_layout_sensitive=*/false,
[](const Shape&, const Shape&) { return false; });
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index ad65b147c1..2cf5fc94ac 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -1908,6 +1908,7 @@ bool LayoutAssignment::InstructionCanChangeLayout(
case HloOpcode::kRemainder:
case HloOpcode::kReverse:
case HloOpcode::kRoundNearestAfz:
+ case HloOpcode::kScatter:
case HloOpcode::kSelect:
case HloOpcode::kSelectAndScatter:
case HloOpcode::kShiftLeft:
@@ -1946,7 +1947,6 @@ bool LayoutAssignment::InstructionCanChangeLayout(
case HloOpcode::kReduce:
case HloOpcode::kReshape:
case HloOpcode::kRng:
- case HloOpcode::kScatter:
case HloOpcode::kSend:
case HloOpcode::kSendDone:
case HloOpcode::kAfterAll: