diff options
author | Benjamin Kramer <kramerb@google.com> | 2018-07-10 14:25:10 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-07-10 14:28:32 -0700 |
commit | 0a805f8d9fdf2e16e0866586bdfb9a6151395a85 (patch) | |
tree | a78305f5d5fb75d3109c60eb56db04f3bffc9aae /tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc | |
parent | 9b75e00697ab437d7a8db10584fc2d5c13ccf966 (diff) |
[XLA:GPU] Implement outfeed
The infeed and outfeed managers are really similar but not quite the same; I'm open
to ideas on how to factor them better. OutfeedManager has a much cleaner design
here than InfeedManager currently has; I'll look into cleaning up
InfeedManager in a follow-up.
PiperOrigin-RevId: 204012304
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc | 25 |
1 files changed, 24 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc index 1002963dc2..59edba30e6 100644 --- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc +++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc @@ -48,6 +48,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h" #include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h" #include "tensorflow/compiler/xla/service/gpu/memset_thunk.h" +#include "tensorflow/compiler/xla/service/gpu/outfeed_thunk.h" #include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h" #include "tensorflow/compiler/xla/service/gpu/partition_assignment.h" #include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h" @@ -2028,6 +2029,11 @@ Status IrEmitterUnnested::HandleInfeed(HloInstruction* infeed) { return Status::OK(); } +Status IrEmitterUnnested::HandleOutfeed(HloInstruction* outfeed) { + thunk_sequence_->emplace_back(BuildOutfeedThunk(outfeed)); + return Status::OK(); +} + // Figures out how to access the buffers for all subshapes of hlo's operands and // for hlo itself (i.e. all the buffers produced by HLO). 
// @@ -2275,7 +2281,7 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildInfeedThunk( ShapeTree<BufferAllocation::Slice> slices(inst->shape()); slices.ForEachMutableElement( - [this, inst](const ShapeIndex& index, BufferAllocation::Slice* slice) { + [&](const ShapeIndex& index, BufferAllocation::Slice* slice) { *slice = ir_emitter_context_->buffer_assignment() .GetUniqueSlice(inst, index) .ConsumeValueOrDie(); @@ -2283,6 +2289,23 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildInfeedThunk( return MakeUnique<InfeedThunk>(slices, inst); } +std::unique_ptr<Thunk> IrEmitterUnnested::BuildOutfeedThunk( + const HloInstruction* inst) { + CHECK_EQ(HloOpcode::kOutfeed, inst->opcode()); + + ShapeTree<BufferAllocation::Slice> slices(inst->operand(0)->shape()); + slices.ForEachMutableElement( + [&](const ShapeIndex& index, BufferAllocation::Slice* slice) { + auto status_or_slice = + ir_emitter_context_->buffer_assignment().GetUniqueSlice( + inst->operand(0), index); + if (status_or_slice.ok()) { + *slice = status_or_slice.ConsumeValueOrDie(); + } + }); + return MakeUnique<OutfeedThunk>(std::move(slices), inst); +} + namespace { double GetScalarConstantAsDouble(const Literal& literal) { switch (literal.shape().element_type()) { |