aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
diff options
context:
space:
mode:
authorGravatar Benjamin Kramer <kramerb@google.com>2018-07-10 14:25:10 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-07-10 14:28:32 -0700
commit0a805f8d9fdf2e16e0866586bdfb9a6151395a85 (patch)
treea78305f5d5fb75d3109c60eb56db04f3bffc9aae /tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
parent9b75e00697ab437d7a8db10584fc2d5c13ccf966 (diff)
[XLA:GPU] Implement outfeed
The infeed and outfeed managers are really similar but not quite the same; I'm open to ideas on how to factor them better. This has a much cleaner design for OutfeedManager than we have for InfeedManager; I'll look into cleaning up InfeedManager in a follow-up. PiperOrigin-RevId: 204012304
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc')
-rw-r--r--tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc25
1 files changed, 24 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 1002963dc2..59edba30e6 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -48,6 +48,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/ir_emitter_context.h"
#include "tensorflow/compiler/xla/service/gpu/kernel_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/memset_thunk.h"
+#include "tensorflow/compiler/xla/service/gpu/outfeed_thunk.h"
#include "tensorflow/compiler/xla/service/gpu/parallel_loop_emitter.h"
#include "tensorflow/compiler/xla/service/gpu/partition_assignment.h"
#include "tensorflow/compiler/xla/service/gpu/sequential_thunk.h"
@@ -2028,6 +2029,11 @@ Status IrEmitterUnnested::HandleInfeed(HloInstruction* infeed) {
return Status::OK();
}
+Status IrEmitterUnnested::HandleOutfeed(HloInstruction* outfeed) {  // Handles an outfeed instruction by emitting one thunk for it.
+ thunk_sequence_->emplace_back(BuildOutfeedThunk(outfeed));  // Append the thunk returned by BuildOutfeedThunk to the thunk sequence.
+ return Status::OK();  // Always reports success; no error path is visible in this handler.
+}
+
// Figures out how to access the buffers for all subshapes of hlo's operands and
// for hlo itself (i.e. all the buffers produced by HLO).
//
@@ -2275,7 +2281,7 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildInfeedThunk(
ShapeTree<BufferAllocation::Slice> slices(inst->shape());
slices.ForEachMutableElement(
- [this, inst](const ShapeIndex& index, BufferAllocation::Slice* slice) {
+ [&](const ShapeIndex& index, BufferAllocation::Slice* slice) {
*slice = ir_emitter_context_->buffer_assignment()
.GetUniqueSlice(inst, index)
.ConsumeValueOrDie();
@@ -2283,6 +2289,23 @@ std::unique_ptr<Thunk> IrEmitterUnnested::BuildInfeedThunk(
return MakeUnique<InfeedThunk>(slices, inst);
}
+std::unique_ptr<Thunk> IrEmitterUnnested::BuildOutfeedThunk(
+ const HloInstruction* inst) {  // Builds an OutfeedThunk for `inst`, which must be an outfeed instruction.
+ CHECK_EQ(HloOpcode::kOutfeed, inst->opcode());  // Crash early if called with any other opcode.
+
+ ShapeTree<BufferAllocation::Slice> slices(inst->operand(0)->shape());  // Slices are keyed by the shape of operand 0 (presumably the data being outfed -- confirm).
+ slices.ForEachMutableElement(
+ [&](const ShapeIndex& index, BufferAllocation::Slice* slice) {
+ auto status_or_slice =
+ ir_emitter_context_->buffer_assignment().GetUniqueSlice(
+ inst->operand(0), index);
+ if (status_or_slice.ok()) {  // A failed lookup is not fatal here: *slice is simply not assigned.
+ *slice = status_or_slice.ConsumeValueOrDie();
+ }  // NOTE(review): presumably some shape indices have no unique slice -- confirm this is intended.
+ });
+ return MakeUnique<OutfeedThunk>(std::move(slices), inst);  // The thunk receives the slice tree (moved) and the instruction pointer.
+}
+
namespace {
double GetScalarConstantAsDouble(const Literal& literal) {
switch (literal.shape().element_type()) {