aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_instruction.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instruction.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc97
1 files changed, 17 insertions, 80 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 674d3e3836..5107ac782d 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -371,50 +371,20 @@ HloInstruction::CreateCrossReplicaSum(const Shape& shape,
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSend(
HloInstruction* operand, int64 channel_id) {
- // Send instruction produces a tuple of {aliased operand, U32 context}.
- Shape output_shape = ShapeUtil::MakeTupleShape(
- {operand->shape(), ShapeUtil::MakeShape(U32, {})});
auto instruction =
- WrapUnique(new HloInstruction(HloOpcode::kSend, output_shape));
+ WrapUnique(new HloInstruction(HloOpcode::kSend, ShapeUtil::MakeNil()));
instruction->AppendOperand(operand);
instruction->channel_id_ = channel_id;
return instruction;
}
-/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSendDone(
- HloInstruction* operand) {
- CHECK(operand->opcode() == HloOpcode::kSend)
- << "SendDone must take the context operand from Send";
- auto instruction = WrapUnique(
- new HloInstruction(HloOpcode::kSendDone, ShapeUtil::MakeNil()));
- instruction->AppendOperand(operand);
- instruction->channel_id_ = operand->channel_id();
- return instruction;
-}
-
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecv(
const Shape& shape, int64 channel_id) {
- // Recv instruction produces a tuple of {receive buffer, U32 context}.
- Shape output_shape =
- ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})});
- auto instruction =
- WrapUnique(new HloInstruction(HloOpcode::kRecv, output_shape));
+ auto instruction = WrapUnique(new HloInstruction(HloOpcode::kRecv, shape));
instruction->channel_id_ = channel_id;
return instruction;
}
-/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecvDone(
- HloInstruction* operand) {
- CHECK(operand->opcode() == HloOpcode::kRecv)
- << "RecvDone must take the context operand from Recv";
- Shape output_shape = ShapeUtil::GetTupleElementShape(operand->shape(), 0);
- auto instruction =
- WrapUnique(new HloInstruction(HloOpcode::kRecvDone, output_shape));
- instruction->AppendOperand(operand);
- instruction->channel_id_ = operand->channel_id();
- return instruction;
-}
-
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReverse(
const Shape& shape, HloInstruction* operand,
tensorflow::gtl::ArraySlice<int64> dimensions) {
@@ -938,9 +908,7 @@ RandomDistribution HloInstruction::random_distribution() const {
bool HloInstruction::HasSideEffect() const {
switch (opcode_) {
case HloOpcode::kSend:
- case HloOpcode::kSendDone:
case HloOpcode::kRecv:
- case HloOpcode::kRecvDone:
case HloOpcode::kInfeed:
case HloOpcode::kOutfeed:
case HloOpcode::kTrace:
@@ -1196,9 +1164,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
new_operands[4], epsilon(), feature_index());
break;
case HloOpcode::kRecv:
- case HloOpcode::kRecvDone:
case HloOpcode::kSend:
- case HloOpcode::kSendDone:
case HloOpcode::kTrace:
LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode_);
}
@@ -1591,10 +1557,8 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kInfeed:
case HloOpcode::kOutfeed:
case HloOpcode::kSort:
- case HloOpcode::kRecv:
- case HloOpcode::kRecvDone:
case HloOpcode::kSend:
- case HloOpcode::kSendDone:
+ case HloOpcode::kRecv:
return false;
}
}
@@ -1886,13 +1850,12 @@ std::vector<string> HloInstruction::ExtraAttributesToString() const {
extra.push_back(StrCat("dimensions={", Join(dimensions(), ","), "}"));
}
if (window_ != nullptr) {
- extra.push_back(StrCat("window={", window_util::ToString(*window_), "}"));
+ extra.push_back(window_util::ToString(*window_));
}
if (padding_config_ != nullptr) {
- extra.push_back(
- StrCat("padding=", xla::PaddingConfigToString(*padding_config_)));
+ extra.push_back(StrCat("padding=", padding_config_->ShortDebugString()));
}
- if (opcode() == HloOpcode::kSlice) {
+ if (!slice_starts_.empty() && !slice_limits_.empty()) {
std::vector<string> bounds;
bounds.reserve(slice_starts_.size());
const bool omit_stride =
@@ -1905,16 +1868,6 @@ std::vector<string> HloInstruction::ExtraAttributesToString() const {
}
extra.push_back(StrCat("slice={", Join(bounds, ", "), "}"));
}
- if (opcode() == HloOpcode::kDynamicSlice) {
- extra.push_back(
- StrCat("dynamic_slice_sizes={", Join(dynamic_slice_sizes(), ","), "}"));
- }
- if (opcode() == HloOpcode::kBatchNormTraining ||
- opcode() == HloOpcode::kBatchNormInference ||
- opcode() == HloOpcode::kBatchNormGrad) {
- extra.push_back(StrCat("epsilon=", epsilon()));
- extra.push_back(StrCat("feature_index=", feature_index()));
- }
if (convolution_dimension_numbers_ != nullptr) {
extra.push_back(ConvolutionDimensionNumbersToString());
@@ -1938,8 +1891,7 @@ std::vector<string> HloInstruction::ExtraAttributesToString() const {
})));
}
- if (opcode() == HloOpcode::kSend || opcode() == HloOpcode::kRecv ||
- opcode() == HloOpcode::kSendDone || opcode() == HloOpcode::kRecvDone) {
+ if (opcode() == HloOpcode::kSend || opcode() == HloOpcode::kRecv) {
extra.push_back(StrCat("channel_id=", channel_id_));
}
@@ -2119,10 +2071,8 @@ bool HloInstruction::IsFusable() const {
case HloOpcode::kOutfeed:
case HloOpcode::kParameter:
case HloOpcode::kTrace:
- case HloOpcode::kRecv:
- case HloOpcode::kRecvDone:
case HloOpcode::kSend:
- case HloOpcode::kSendDone:
+ case HloOpcode::kRecv:
return false;
// Only fuse Rng if it is used once, otherwise the random numbers generated
// will be different in each fusion. If it is the root (user count = 0)
@@ -2329,14 +2279,10 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
return visitor->HandleCall(this);
case HloOpcode::kCustomCall:
return visitor->HandleCustomCall(this);
- case HloOpcode::kRecv:
- return visitor->HandleRecv(this);
- case HloOpcode::kRecvDone:
- return visitor->HandleRecvDone(this);
case HloOpcode::kSend:
return visitor->HandleSend(this);
- case HloOpcode::kSendDone:
- return visitor->HandleSendDone(this);
+ case HloOpcode::kRecv:
+ return visitor->HandleRecv(this);
// These opcodes are not handled here.
case HloOpcode::kTrace:
@@ -2895,21 +2841,6 @@ StatusOr<HloInstruction::FusionKind> StringToFusionKind(
return InvalidArgument("Unknown fusion kind: %s", kind_name.c_str());
}
-string PaddingConfigToString(const PaddingConfig& padding) {
- bool has_interior_padding =
- std::any_of(padding.dimensions().begin(), padding.dimensions().end(),
- [](const PaddingConfig::PaddingConfigDimension& dim) {
- return dim.interior_padding() != 0;
- });
- return Join(
- padding.dimensions(), "x",
- [&](string* out, const PaddingConfig::PaddingConfigDimension& dim) {
- StrAppend(
- out, dim.edge_padding_low(), "_", dim.edge_padding_high(),
- has_interior_padding ? StrCat("_", dim.interior_padding()) : "");
- });
-}
-
std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) {
return os << ToString(kind);
}
@@ -2925,7 +2856,13 @@ string HloInstruction::ConvolutionDimensionNumbersToString() const {
const auto append_dims = [&](const std::vector<string>& dims,
const Shape& shape) {
CHECK_EQ(dims.size(), ShapeUtil::Rank(shape));
- StrAppend(&result, Join(dims, ""));
+ for (int64 logical = 0; logical < dims.size(); ++logical) {
+ int64 physical = logical;
+ if (!shape.layout().minor_to_major().empty()) {
+ physical = LayoutUtil::Major(shape.layout(), logical);
+ }
+ result += dims[physical];
+ }
};
// lhs_dims[i] is the symbol of the logical dimension i for the lhs