diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instruction.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.cc | 97 |
1 files changed, 17 insertions, 80 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 674d3e3836..5107ac782d 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -371,50 +371,20 @@ HloInstruction::CreateCrossReplicaSum(const Shape& shape, /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSend( HloInstruction* operand, int64 channel_id) { - // Send instruction produces a tuple of {aliased operand, U32 context}. - Shape output_shape = ShapeUtil::MakeTupleShape( - {operand->shape(), ShapeUtil::MakeShape(U32, {})}); auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kSend, output_shape)); + WrapUnique(new HloInstruction(HloOpcode::kSend, ShapeUtil::MakeNil())); instruction->AppendOperand(operand); instruction->channel_id_ = channel_id; return instruction; } -/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSendDone( - HloInstruction* operand) { - CHECK(operand->opcode() == HloOpcode::kSend) - << "SendDone must take the context operand from Send"; - auto instruction = WrapUnique( - new HloInstruction(HloOpcode::kSendDone, ShapeUtil::MakeNil())); - instruction->AppendOperand(operand); - instruction->channel_id_ = operand->channel_id(); - return instruction; -} - /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecv( const Shape& shape, int64 channel_id) { - // Recv instruction produces a tuple of {receive buffer, U32 context}. - Shape output_shape = - ShapeUtil::MakeTupleShape({shape, ShapeUtil::MakeShape(U32, {})}); - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kRecv, output_shape)); + auto instruction = WrapUnique(new HloInstruction(HloOpcode::kRecv, shape)); instruction->channel_id_ = channel_id; return instruction; } -/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecvDone( - HloInstruction* operand) { - CHECK(operand->opcode() == HloOpcode::kRecv) - << "RecvDone must take the context operand from Recv"; - Shape output_shape = ShapeUtil::GetTupleElementShape(operand->shape(), 0); - auto instruction = - WrapUnique(new HloInstruction(HloOpcode::kRecvDone, output_shape)); - instruction->AppendOperand(operand); - instruction->channel_id_ = operand->channel_id(); - return instruction; -} - /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReverse( const Shape& shape, HloInstruction* operand, tensorflow::gtl::ArraySlice<int64> dimensions) { @@ -938,9 +908,7 @@ RandomDistribution HloInstruction::random_distribution() const { bool HloInstruction::HasSideEffect() const { switch (opcode_) { case HloOpcode::kSend: - case HloOpcode::kSendDone: case HloOpcode::kRecv: - case HloOpcode::kRecvDone: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kTrace: @@ -1196,9 +1164,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( new_operands[4], epsilon(), feature_index()); break; case HloOpcode::kRecv: - case HloOpcode::kRecvDone: case HloOpcode::kSend: - case HloOpcode::kSendDone: case HloOpcode::kTrace: LOG(FATAL) << "Not yet implemented, clone: " << HloOpcodeString(opcode_); } @@ -1591,10 +1557,8 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kSort: - case HloOpcode::kRecv: - case HloOpcode::kRecvDone: case HloOpcode::kSend: - case HloOpcode::kSendDone: + case HloOpcode::kRecv: return false; } } @@ -1886,13 +1850,12 @@ std::vector<string> HloInstruction::ExtraAttributesToString() const { extra.push_back(StrCat("dimensions={", Join(dimensions(), ","), "}")); } if (window_ != nullptr) { - extra.push_back(StrCat("window={", window_util::ToString(*window_), "}")); + extra.push_back(window_util::ToString(*window_)); } if (padding_config_ != nullptr) { - extra.push_back( - StrCat("padding=", xla::PaddingConfigToString(*padding_config_))); + extra.push_back(StrCat("padding=", padding_config_->ShortDebugString())); } - if (opcode() == HloOpcode::kSlice) { + if (!slice_starts_.empty() && !slice_limits_.empty()) { std::vector<string> bounds; bounds.reserve(slice_starts_.size()); const bool omit_stride = @@ -1905,16 +1868,6 @@ std::vector<string> HloInstruction::ExtraAttributesToString() const { } extra.push_back(StrCat("slice={", Join(bounds, ", "), "}")); } - if (opcode() == HloOpcode::kDynamicSlice) { - extra.push_back( - StrCat("dynamic_slice_sizes={", Join(dynamic_slice_sizes(), ","), "}")); - } - if (opcode() == HloOpcode::kBatchNormTraining || - opcode() == HloOpcode::kBatchNormInference || - opcode() == HloOpcode::kBatchNormGrad) { - extra.push_back(StrCat("epsilon=", epsilon())); - extra.push_back(StrCat("feature_index=", feature_index())); - } if (convolution_dimension_numbers_ != nullptr) { extra.push_back(ConvolutionDimensionNumbersToString()); @@ -1938,8 +1891,7 @@ std::vector<string> HloInstruction::ExtraAttributesToString() const { }))); } - if (opcode() == HloOpcode::kSend || opcode() == HloOpcode::kRecv || - opcode() == HloOpcode::kSendDone || opcode() == HloOpcode::kRecvDone) { + if (opcode() == HloOpcode::kSend || opcode() == HloOpcode::kRecv) { extra.push_back(StrCat("channel_id=", channel_id_)); } @@ -2119,10 +2071,8 @@ bool HloInstruction::IsFusable() const { case HloOpcode::kOutfeed: case HloOpcode::kParameter: case HloOpcode::kTrace: - case HloOpcode::kRecv: - case HloOpcode::kRecvDone: case HloOpcode::kSend: - case HloOpcode::kSendDone: + case HloOpcode::kRecv: return false; // Only fuse Rng if it is used once, otherwise the random numbers generated // will be different in each fusion. If it is the root (user count = 0) @@ -2329,14 +2279,10 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) { return visitor->HandleCall(this); case HloOpcode::kCustomCall: return visitor->HandleCustomCall(this); - case HloOpcode::kRecv: - return visitor->HandleRecv(this); - case HloOpcode::kRecvDone: - return visitor->HandleRecvDone(this); case HloOpcode::kSend: return visitor->HandleSend(this); - case HloOpcode::kSendDone: - return visitor->HandleSendDone(this); + case HloOpcode::kRecv: + return visitor->HandleRecv(this); // These opcodes are not handled here. case HloOpcode::kTrace: @@ -2895,21 +2841,6 @@ StatusOr<HloInstruction::FusionKind> StringToFusionKind( return InvalidArgument("Unknown fusion kind: %s", kind_name.c_str()); } -string PaddingConfigToString(const PaddingConfig& padding) { - bool has_interior_padding = - std::any_of(padding.dimensions().begin(), padding.dimensions().end(), - [](const PaddingConfig::PaddingConfigDimension& dim) { - return dim.interior_padding() != 0; - }); - return Join( - padding.dimensions(), "x", - [&](string* out, const PaddingConfig::PaddingConfigDimension& dim) { - StrAppend( - out, dim.edge_padding_low(), "_", dim.edge_padding_high(), - has_interior_padding ? StrCat("_", dim.interior_padding()) : ""); - }); -} - std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) { return os << ToString(kind); } @@ -2925,7 +2856,13 @@ string HloInstruction::ConvolutionDimensionNumbersToString() const { const auto append_dims = [&](const std::vector<string>& dims, const Shape& shape) { CHECK_EQ(dims.size(), ShapeUtil::Rank(shape)); - StrAppend(&result, Join(dims, "")); + for (int64 logical = 0; logical < dims.size(); ++logical) { + int64 physical = logical; + if (!shape.layout().minor_to_major().empty()) { + physical = LayoutUtil::Major(shape.layout(), logical); + } + result += dims[physical]; + } }; // lhs_dims[i] is the symbol of the logical dimension i for the lhs |