about summary refs log tree commit diff homepage
path: root/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
diff options
context:
space:
mode:
author: Bixia Zheng <bixia@google.com> 2018-09-06 15:54:55 -0700
committer: TensorFlower Gardener <gardener@tensorflow.org> 2018-09-06 15:59:08 -0700
commit 1e9cbcc208a85c467b3db7fbbcba681aa012c607 (patch)
tree e1870917555b9f939b5f279e3f3d10415e36e5bd /tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
parent a07440916c1e603b3634ba5e4844597b967d5e55 (diff)
[XLA:GPU] Refactor some code for fusion output handling.
Move the routine ConstructIrArrayForOutputs to class IrEmitter so that it can be used by both IrEmitterNested and IrEmitterUnnested. Move the code that stores the address of each individual output of a multiple-output fusion into the fusion's tuple buffer to an overloaded version of the routine llvm_ir::EmitTuple, so that we can reduce code duplication. PiperOrigin-RevId: 211884483
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc')
-rw-r--r-- tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc 54
1 file changed, 14 insertions(+), 40 deletions(-)
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 389a98facb..0c7623fd79 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -2819,10 +2819,7 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk(
}
// For multioutput fusion, we need to emit each operand and the root.
- std::vector<IrArray> output_arrays;
- for (int64 i = 0; i < ShapeUtil::TupleElementCount(hlo.shape()); ++i) {
- output_arrays.push_back(GetIrArray(hlo, hlo, {i}));
- }
+ std::vector<IrArray> output_arrays = ConstructIrArrayForOutputs(hlo);
TF_RETURN_IF_ERROR(
ParallelLoopEmitter(element_generator, output_arrays, launch_dimensions,
&b_, unroll_factor)
@@ -2830,12 +2827,9 @@ Status IrEmitterUnnested::EmitTargetElementLoopInThunk(
GetIndexTypeForKernel(
&hlo, launch_dimensions.launch_bound(), &b_)));
- std::vector<llvm::Value*> tuple_operand_ptrs;
- for (int64 i = 0; i < output_arrays.size(); ++i) {
- tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer());
- }
b_.SetInsertPoint(b_.GetInsertBlock()->getTerminator());
- llvm_ir::EmitTuple(GetIrArray(hlo, hlo), tuple_operand_ptrs, &b_, module_);
+ llvm_ir::EmitTuple(GetIrArray(hlo, hlo), output_arrays, &b_, module_);
+
return Status::OK();
}
@@ -2847,29 +2841,14 @@ Status IrEmitterUnnested::EmitTargetElementLoop(
static_cast<KernelThunk*>(LastThunk()));
}
-int IrEmitterUnnested::ConstructIrArrayForOutputs(
- const HloInstruction& hlo, std::vector<IrArray>* output_arrays) {
- int64 num_outputs = 1;
- if (hlo.IsMultiOutputFusion()) {
- num_outputs = ShapeUtil::TupleElementCount(hlo.shape());
- output_arrays->reserve(num_outputs);
- for (int64 i = 0; i < num_outputs; ++i) {
- output_arrays->push_back(GetIrArray(hlo, hlo, {i}));
- }
- } else {
- output_arrays->push_back(GetIrArray(hlo, hlo));
- }
- return num_outputs;
-}
-
-int IrEmitterUnnested::ConstructIrArrayForInputs(
- const HloInstruction& hlo, std::vector<IrArray>* param_arrays) {
- int64 num_params = hlo.operands().size();
- param_arrays->reserve(num_params);
+std::vector<IrArray> IrEmitterUnnested::ConstructIrArrayForInputs(
+ const HloInstruction& hlo) {
+ std::vector<IrArray> param_arrays;
+ param_arrays.reserve(hlo.operands().size());
for (const HloInstruction* param : hlo.operands()) {
- param_arrays->push_back(GetIrArray(*param, hlo));
+ param_arrays.push_back(GetIrArray(*param, hlo));
}
- return num_params;
+ return param_arrays;
}
int IrEmitterUnnested::ConstructOutputReducedShapeAndCastOutputIrArrayToShape(
@@ -3050,10 +3029,10 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
constexpr int64 kThreadsPerTile = kTileSize * kNumRows;
// Construct IrArrays for the inputs and outputs.
- std::vector<IrArray> output_arrays;
- int64 num_outputs = ConstructIrArrayForOutputs(*hlo, &output_arrays);
- std::vector<IrArray> param_arrays;
- int64 num_params = ConstructIrArrayForInputs(*hlo, &param_arrays);
+ std::vector<IrArray> output_arrays = ConstructIrArrayForOutputs(*hlo);
+ int64 num_outputs = output_arrays.size();
+ std::vector<IrArray> param_arrays = ConstructIrArrayForInputs(*hlo);
+ int64 num_params = param_arrays.size();
// Allocate shared memory buffers to store the tiled inputs.
std::vector<llvm::Value*> param_shmem_buffers(num_params, nullptr);
@@ -3251,12 +3230,7 @@ LaunchDimensions IrEmitterUnnested::EmitHlo021Tile(
// For multioutput fusion, emit a tuple with all the individual outputs.
if (hlo->IsMultiOutputFusion()) {
- std::vector<llvm::Value*> tuple_operand_ptrs;
- for (int64 i = 0; i < output_arrays.size(); ++i) {
- tuple_operand_ptrs.push_back(output_arrays[i].GetBasePointer());
- }
- llvm_ir::EmitTuple(GetIrArray(*hlo, *hlo), tuple_operand_ptrs, &b_,
- module_);
+ llvm_ir::EmitTuple(GetIrArray(*hlo, *hlo), output_arrays, &b_, module_);
}
return launch_dimensions;