Diffstat (limited to 'tensorflow/compiler/xla/service/layout_assignment.cc')
-rw-r--r--  tensorflow/compiler/xla/service/layout_assignment.cc | 105
1 file changed, 84 insertions(+), 21 deletions(-)
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index 36fdfa868d..9705687b00 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -30,10 +30,12 @@ limitations under the License.
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
+#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
@@ -59,7 +61,6 @@ namespace xla {
// anonymous namespace, instead of three or four spread all over this file.
namespace {
-
} // namespace
std::ostream& operator<<(std::ostream& out,
@@ -113,14 +114,18 @@ LayoutConstraints::LayoutConstraints(
HloComputation* computation)
: points_to_analysis_(points_to_analysis), computation_(computation) {
// Gather all array-shaped logical buffers into unconstrained_buffer_ids.
- for (LogicalBuffer::Id id = 0; id < points_to_analysis_.num_logical_buffers();
- id++) {
- auto& buffer = points_to_analysis_.logical_buffer(id);
- // The points to analysis is computed per module, restrict constraints to
- // array buffers in this computation.
- if (buffer.IsArray() && buffer.instruction()->parent() == computation) {
- unconstrained_buffer_ids_.insert(buffer.id());
- }
+ for (HloInstruction* inst : computation_->instructions()) {
+ points_to_analysis_.GetPointsToSet(inst).ForEachElement(
+ [&](const ShapeIndex&, const PointsToSet::BufferList& buffers) {
+ for (const LogicalBuffer* buffer : buffers) {
+ // The points to analysis is computed per module, restrict
+ // constraints to array buffers in this computation.
+ if (buffer->IsArray() &&
+ buffer->instruction()->parent() == computation) {
+ unconstrained_buffer_ids_.insert(buffer->id());
+ }
+ }
+ });
}
}
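
Note: the rewritten constructor visits each instruction's points-to set rather than scanning every logical buffer in the module, but the filter is unchanged: only array-shaped buffers defined in this computation are recorded as unconstrained. A minimal standalone sketch of that collection step, with toy stand-in types rather than XLA's TuplePointsToAnalysis API:

    #include <iostream>
    #include <set>
    #include <vector>

    // Toy stand-ins; the real types are LogicalBuffer, HloInstruction, etc.
    struct ToyComputation;
    struct ToyBuffer {
      int id;
      bool is_array;
      const ToyComputation* parent;  // computation that defines the buffer
    };
    struct ToyInstruction {
      std::vector<const ToyBuffer*> pointed_to;  // flattened points-to set
    };
    struct ToyComputation {
      std::vector<ToyInstruction> instructions;
    };

    // Same filtering idea as the new constructor: visit the buffers reachable
    // from each instruction and keep array buffers defined in this computation.
    std::set<int> CollectUnconstrainedIds(const ToyComputation& computation) {
      std::set<int> ids;
      for (const ToyInstruction& inst : computation.instructions) {
        for (const ToyBuffer* buffer : inst.pointed_to) {
          if (buffer->is_array && buffer->parent == &computation) {
            ids.insert(buffer->id);
          }
        }
      }
      return ids;
    }

    int main() {
      ToyComputation comp;
      ToyBuffer a{0, true, &comp}, b{1, false, &comp}, c{2, true, nullptr};
      comp.instructions.push_back({{&a, &b, &c}});
      std::cout << CollectUnconstrainedIds(comp).size() << "\n";  // prints 1
    }
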
@@ -392,6 +397,43 @@ string LayoutConstraints::ToString() const {
return output;
}
+namespace {
+
+bool IsHostSendRecv(const HloInstruction* instruction) {
+ const HloSendRecvInstruction* send_recv_instr =
+ DynCast<HloSendRecvInstruction>(instruction);
+ return send_recv_instr != nullptr && send_recv_instr->is_host_transfer();
+}
+
+} // namespace
+
+Status LayoutAssignment::BuildHostChannelConstraints(
+ HloComputation* computation) {
+ for (auto* instruction : computation->instructions()) {
+ const HloSendRecvInstruction* send_recv_instr =
+ DynCast<HloSendRecvInstruction>(instruction);
+ if (send_recv_instr == nullptr || !send_recv_instr->is_host_transfer()) {
+ continue;
+ }
+
+ // For host transfers the Send and Recv instruction carry the layout.
+ if (instruction->opcode() == HloOpcode::kSend ||
+ instruction->opcode() == HloOpcode::kRecv) {
+ const Shape& data_shape =
+ ShapeUtil::GetTupleElementShape(send_recv_instr->shape(), 0);
+ TF_RET_CHECK(ShapeUtil::IsArray(data_shape));
+ TF_RET_CHECK(LayoutUtil::HasLayout(data_shape));
+ const Layout* prev_layout = host_channel_constraints_.ConstrainChannel(
+ send_recv_instr->channel_id(), data_shape.layout());
+ TF_RET_CHECK(prev_layout == nullptr)
+ << "Cannot constrain host transfer layout as it was set to "
+ << LayoutUtil::HumanString(*prev_layout) << ": "
+ << send_recv_instr->ToString();
+ }
+ }
+ return Status::OK();
+}
+
Status LayoutAssignment::AddMandatoryConstraints(
const ComputationLayout* computation_layout,
ChannelLayoutConstraints* channel_constraints, HloComputation* computation,
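
Note: BuildHostChannelConstraints depends on the ConstrainChannel() contract visible throughout this change: the call returns nullptr when the channel accepts the proposed layout (first registration, or a matching one) and returns the previously registered layout on a conflict, which is what the TF_RET_CHECK guards against. A minimal standalone sketch of a constraint table with that behavior, using toy types rather than XLA's ChannelLayoutConstraints:

    #include <cstdint>
    #include <iostream>
    #include <map>
    #include <string>

    // Toy layout: in XLA this would be a xla::Layout (minor-to-major order etc.).
    using ToyLayout = std::string;

    // Mimics the contract used by BuildHostChannelConstraints: the first caller
    // fixes the layout for a channel; later callers get nullptr if their layout
    // matches and a pointer to the existing layout if it does not.
    class ToyChannelConstraints {
     public:
      const ToyLayout* ConstrainChannel(int64_t channel_id, const ToyLayout& layout) {
        auto it = constraints_.find(channel_id);
        if (it == constraints_.end()) {
          constraints_.emplace(channel_id, layout);
          return nullptr;  // newly constrained, no conflict
        }
        return it->second == layout ? nullptr : &it->second;  // conflict -> previous layout
      }

     private:
      std::map<int64_t, ToyLayout> constraints_;
    };

    int main() {
      ToyChannelConstraints constraints;
      std::cout << (constraints.ConstrainChannel(7, "{1,0}") == nullptr) << "\n";  // 1
      std::cout << (constraints.ConstrainChannel(7, "{1,0}") == nullptr) << "\n";  // 1 (same layout)
      std::cout << (constraints.ConstrainChannel(7, "{0,1}") == nullptr) << "\n";  // 0 (conflict)
    }

Running the toy prints 1, 1, 0: the first two proposals are accepted, the third conflicts with the layout already registered for channel 7.
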
@@ -399,6 +441,11 @@ Status LayoutAssignment::AddMandatoryConstraints(
VLOG(3) << "Adding mandatory layout constraints to computation "
<< computation->name();
+ auto get_channel_constraints = [&](const HloInstruction* instruction) {
+ return IsHostSendRecv(instruction) ? &host_channel_constraints_
+ : channel_constraints;
+ };
+
// Constrain layouts of instructions which define values with pre-existing
// layouts.
for (auto* instruction : computation->instructions()) {
@@ -435,18 +482,21 @@ Status LayoutAssignment::AddMandatoryConstraints(
if (instruction->opcode() == HloOpcode::kSend ||
instruction->opcode() == HloOpcode::kRecv) {
- CHECK(channel_constraints)
+ CHECK(get_channel_constraints(instruction))
<< "Multi-module layout assignment requires ChannelLayoutConstraints";
int64 channel_id = instruction->channel_id();
- if (!channel_constraints->IsChannelConstrained(channel_id)) {
+ if (!get_channel_constraints(instruction)
+ ->IsChannelConstrained(channel_id)) {
continue;
}
if (instruction->opcode() == HloOpcode::kSend) {
// TODO(b/68493863): Change to use SetOperandLayout().
const Shape send_buffer_shape = instruction->operand(0)->shape();
TF_RET_CHECK(ShapeUtil::IsArray(send_buffer_shape));
- Shape new_buffer_shape = channel_constraints->LayoutShapeForChannel(
- send_buffer_shape, instruction->channel_id());
+ Shape new_buffer_shape =
+ get_channel_constraints(instruction)
+ ->LayoutShapeForChannel(send_buffer_shape,
+ instruction->channel_id());
TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
new_buffer_shape, instruction->operand(0)));
} else {
@@ -457,8 +507,9 @@ Status LayoutAssignment::AddMandatoryConstraints(
const LogicalBuffer* buffer,
constraints->points_to_analysis().GetBufferDefinedAt(instruction,
{0}));
- Shape new_shape = channel_constraints->LayoutShapeForChannel(
- recv_buffer_shape, instruction->channel_id());
+ Shape new_shape = get_channel_constraints(instruction)
+ ->LayoutShapeForChannel(
+ recv_buffer_shape, instruction->channel_id());
TF_RETURN_IF_ERROR(
constraints->SetBufferLayout(new_shape.layout(), *buffer));
}
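
Note: in the kSend/kRecv handling above, the only new behavior is the get_channel_constraints indirection: host-transfer channels consult the pass-local host_channel_constraints_ table, all other channels the caller-supplied cross-module table. LayoutShapeForChannel, as used here, takes the existing buffer shape and substitutes the layout already registered for the channel. A minimal self-contained sketch of that shape rewrite under toy types (not XLA's Shape/Layout API):

    #include <cstdint>
    #include <iostream>
    #include <map>
    #include <string>

    // Toy stand-ins: a "layout" is just a string, a shape carries one layout.
    using ToyLayout = std::string;
    struct ToyShape {
      std::string dims;
      ToyLayout layout;
    };

    // Toy analogue of LayoutShapeForChannel(): return the same shape, but with
    // the layout previously constrained for the channel.
    ToyShape LayoutShapeForChannelToy(const std::map<int64_t, ToyLayout>& constrained,
                                      ToyShape shape, int64_t channel_id) {
      shape.layout = constrained.at(channel_id);  // caller checks IsChannelConstrained first
      return shape;
    }

    int main() {
      std::map<int64_t, ToyLayout> constrained{{42, "{0,1}"}};
      ToyShape send_buffer{"f32[4,8]", "{1,0}"};
      ToyShape fixed = LayoutShapeForChannelToy(constrained, send_buffer, 42);
      std::cout << fixed.dims << " with layout " << fixed.layout << "\n";  // f32[4,8] with layout {0,1}
    }
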
@@ -1535,6 +1586,10 @@ Status LayoutAssignment::RunOnComputation(
ChannelLayoutConstraints* channel_constraints) {
VLOG(2) << "LayoutAssignment::RunOnComputation(" << computation->name()
<< ")";
+
+ // Must be run before clearing layouts.
+ TF_RETURN_IF_ERROR(BuildHostChannelConstraints(computation));
+
TF_RETURN_IF_ERROR(ClearComputationLayouts(computation));
if (computation_layout != nullptr) {
auto it = computation_layouts_.find(computation);
@@ -1624,13 +1679,20 @@ Status LayoutAssignment::RunOnComputation(
Status LayoutAssignment::ConstrainChannelLayouts(
HloComputation* computation,
ChannelLayoutConstraints* channel_constraints) {
+ auto get_channel_constraints = [&](const HloInstruction* instruction) {
+ return IsHostSendRecv(instruction) ? &host_channel_constraints_
+ : channel_constraints;
+ };
// We go through the kRecvDone before. These must either impose their layout,
- // of find a matching one already existing (ConstrainChannel() returns
+ // or find a matching one already existing (ConstrainChannel() returns
// nullptr).
for (HloInstruction* instruction : computation->instructions()) {
if (instruction->opcode() == HloOpcode::kRecvDone) {
- const Layout* layout = channel_constraints->ConstrainChannel(
- instruction->channel_id(), instruction->shape().layout());
+ const Layout* layout =
+ get_channel_constraints(instruction)
+ ->ConstrainChannel(
+ instruction->channel_id(),
+ ShapeUtil::GetSubshape(instruction->shape(), {0}).layout());
TF_RET_CHECK(layout == nullptr)
<< instruction->ToString()
<< " cannot constrain layout as it was set to "
@@ -1643,11 +1705,12 @@ Status LayoutAssignment::ConstrainChannelLayouts(
for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) {
if (instruction->opcode() == HloOpcode::kSend) {
HloInstruction* operand = instruction->mutable_operand(0);
- const Layout* layout = channel_constraints->ConstrainChannel(
- instruction->channel_id(), operand->shape().layout());
+ const Layout* layout = get_channel_constraints(instruction)
+ ->ConstrainChannel(instruction->channel_id(),
+ operand->shape().layout());
if (layout != nullptr) {
// We found an already constrained layout which does not match the one
- // the kSend wants to impose. Eitehr add a new kCopy, or use the
+ // the kSend wants to impose. Either add a new kCopy, or use the
// existing one to marshal the correct shape.
Shape shape = operand->shape();
*shape.mutable_layout() = *layout;
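
Note: as the comment above says, a conflicting layout on a kSend is resolved by marshalling the operand into the already-constrained layout through a kCopy (new or reused); the excerpt ends with the target shape being built from the operand's shape plus that layout. A minimal standalone sketch of this step under toy types (not the actual continuation of the function):

    #include <iostream>
    #include <string>

    // Toy stand-ins for Shape/Layout.
    struct ToyShape {
      std::string dims;
      std::string layout;
    };

    // Mirrors the step shown above: keep the operand's dimensions but force the
    // layout that the channel was already constrained to; the real pass would
    // then insert (or reuse) a kCopy producing this shape in front of the kSend.
    ToyShape MarshalToConstrainedLayout(const ToyShape& operand_shape,
                                        const std::string& constrained_layout) {
      ToyShape target = operand_shape;     // Shape shape = operand->shape();
      target.layout = constrained_layout;  // *shape.mutable_layout() = *layout;
      return target;
    }

    int main() {
      ToyShape operand{"f32[16,16]", "{1,0}"};
      ToyShape target = MarshalToConstrainedLayout(operand, "{0,1}");
      std::cout << target.dims << " " << target.layout << "\n";  // f32[16,16] {0,1}
    }
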