diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instruction.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.cc | 238 |
1 files changed, 126 insertions, 112 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 5aaeec802f..8b9bdd2f46 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -22,7 +22,7 @@ limitations under the License. #include <utility> #include "tensorflow/compiler/xla/layout_util.h" -#include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/literal.h" #include "tensorflow/compiler/xla/protobuf_util.h" #include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h" @@ -112,29 +112,30 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( break; } case HloOpcode::kSend: - TF_RET_CHECK(proto.operand_ids_size() == 1) - << "Send instruction should have 1 operand but sees " + TF_RET_CHECK(proto.operand_ids_size() == 2) + << "Send instruction should have 2 operand but sees " << proto.operand_ids_size(); - instruction = CreateSend(operands(0), proto.channel_id()); + instruction = CreateSend(operands(0), operands(1), proto.channel_id(), + proto.is_host_transfer()); break; case HloOpcode::kSendDone: TF_RET_CHECK(proto.operand_ids_size() == 1) << "SendDone instruction should have 1 operand but sees " << proto.operand_ids_size(); - instruction = CreateSendDone(operands(0)); + instruction = CreateSendDone(operands(0), proto.is_host_transfer()); break; case HloOpcode::kRecv: - TF_RET_CHECK(proto.operand_ids_size() == 0) - << "Recv instruction should have 0 operand but sees " + TF_RET_CHECK(proto.operand_ids_size() == 1) + << "Recv instruction should have 1 operand but sees " << proto.operand_ids_size(); - instruction = - CreateRecv(proto.shape().tuple_shapes(0), proto.channel_id()); + instruction = CreateRecv(proto.shape().tuple_shapes(0), operands(0), + proto.channel_id(), proto.is_host_transfer()); break; case HloOpcode::kRecvDone: TF_RET_CHECK(proto.operand_ids_size() == 1) << "RecvDone instruction should have 1 operand but sees " << proto.operand_ids_size(); - instruction = CreateRecvDone(operands(0)); + instruction = CreateRecvDone(operands(0), proto.is_host_transfer()); break; case HloOpcode::kReverse: TF_RET_CHECK(proto.operand_ids_size() == 1) @@ -163,6 +164,20 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( proto.dimensions().end()), computations(0)); break; + case HloOpcode::kSort: { + TF_RET_CHECK(proto.operand_ids_size() == 1 || + proto.operand_ids_size() == 2) + << "Sort instruction should have 1 or 2 operands but has " + << proto.operand_ids_size(); + TF_RET_CHECK(proto.dimensions().size() == 1) + << "Sort instruction should have 1 dimension"; + HloInstruction* keys = operands(0); + HloInstruction* values = + proto.operand_ids_size() == 2 ? operands(1) : nullptr; + instruction = + CreateSort(proto.shape(), proto.dimensions(0), keys, values); + break; + } case HloOpcode::kTranspose: TF_RET_CHECK(proto.operand_ids_size() == 1) << "Transpose instruction should have 1 operand but sees " @@ -271,7 +286,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( // converted to take tokens. instruction = CreateInfeed(data_shape, proto.infeed_config()); } else { - CHECK_EQ(proto.operand_ids_size(), 2); + CHECK_EQ(proto.operand_ids_size(), 1); instruction = CreateInfeed(data_shape, operands(0), proto.infeed_config()); } @@ -372,6 +387,23 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( slice_sizes); break; } + case HloOpcode::kGather: { + TF_RET_CHECK(proto.operand_ids_size() == 2) + << "Gather instruction should have 2 operands but sees " + << proto.operand_ids_size(); + TF_RET_CHECK(proto.has_gather_dimension_numbers()) + << "Gather instruction should have GatherDimensionNumbers set."; + std::unique_ptr<GatherDimensionNumbers> gather_dimension_numbers = + MakeUnique<GatherDimensionNumbers>(proto.gather_dimension_numbers()); + std::vector<int64> gather_window_bounds; + for (int64 bound : proto.gather_window_bounds()) { + gather_window_bounds.push_back(bound); + } + instruction = + CreateGather(proto.shape(), operands(0), operands(1), + *gather_dimension_numbers, gather_window_bounds); + break; + } default: { instruction = WrapUnique(new HloInstruction(opcode, proto.shape())); for (const int64 operand_id : proto.operand_ids()) { @@ -413,13 +445,6 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( instruction->set_sharding(sharding); } - if (proto.has_gather_dimension_numbers()) { - instruction->gather_dimension_numbers_ = - MakeUnique<GatherDimensionNumbers>(proto.gather_dimension_numbers()); - } - for (int64 bound : proto.gather_window_bounds()) { - instruction->gather_window_bounds_.push_back(bound); - } return std::move(instruction); } @@ -438,6 +463,11 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( return MakeUnique<HloConstantInstruction>(std::move(literal)); } +/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateIota( + const Shape& shape) { + return WrapUnique(new HloInstruction(HloOpcode::kIota, shape)); +} + /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateGetTupleElement(const Shape& shape, HloInstruction* operand, int64 index) { @@ -489,7 +519,6 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, case HloOpcode::kReal: case HloOpcode::kSign: case HloOpcode::kSin: - case HloOpcode::kSort: case HloOpcode::kTanh: break; default: @@ -542,8 +571,9 @@ HloInstruction::CreateGetTupleElement(const Shape& shape, // Only certain opcodes are supported with CreateTernary: opcodes of ternary // instructions with no auxiliary fields. switch (opcode) { - case (HloOpcode::kClamp): - case (HloOpcode::kSelect): + case HloOpcode::kClamp: + case HloOpcode::kSelect: + case HloOpcode::kTupleSelect: break; default: LOG(FATAL) << "Invalid ternary instruction opcode " @@ -651,29 +681,33 @@ HloInstruction::CreateCrossReplicaSum( } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSend( - HloInstruction* operand, int64 channel_id) { - return MakeUnique<HloSendInstruction>(operand, channel_id); + HloInstruction* operand, HloInstruction* token, int64 channel_id, + bool is_host_transfer) { + return MakeUnique<HloSendInstruction>(operand, token, channel_id, + is_host_transfer); } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSendDone( - HloInstruction* operand) { + HloInstruction* operand, bool is_host_transfer) { auto send_operand = DynCast<HloSendInstruction>(operand); CHECK(send_operand != nullptr) << "SendDone must take the context operand from Send"; - return MakeUnique<HloSendDoneInstruction>(send_operand); + return MakeUnique<HloSendDoneInstruction>(send_operand, is_host_transfer); } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecv( - const Shape& shape, int64 channel_id) { - return MakeUnique<HloRecvInstruction>(shape, channel_id); + const Shape& shape, HloInstruction* token, int64 channel_id, + bool is_host_transfer) { + return MakeUnique<HloRecvInstruction>(shape, token, channel_id, + is_host_transfer); } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateRecvDone( - HloInstruction* operand) { + HloInstruction* operand, bool is_host_transfer) { auto recv_operand = DynCast<HloRecvInstruction>(operand); CHECK(recv_operand != nullptr) << "RecvDone must take the context operand from Recv"; - return MakeUnique<HloRecvDoneInstruction>(recv_operand); + return MakeUnique<HloRecvDoneInstruction>(recv_operand, is_host_transfer); } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateReverse( @@ -684,6 +718,7 @@ HloInstruction::CreateCrossReplicaSum( /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateAfterAll( tensorflow::gtl::ArraySlice<HloInstruction*> operands) { + CHECK(!operands.empty()); auto instruction = WrapUnique( new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape())); for (auto operand : operands) { @@ -692,6 +727,11 @@ HloInstruction::CreateCrossReplicaSum( return instruction; } +/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateToken() { + return WrapUnique( + new HloInstruction(HloOpcode::kAfterAll, ShapeUtil::MakeTokenShape())); +} + /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateWhile( const Shape& shape, HloComputation* condition, HloComputation* body, HloInstruction* init) { @@ -908,6 +948,12 @@ HloInstruction::CreateBroadcastSequence( return MakeUnique<HloTransposeInstruction>(shape, operand, dimensions); } +/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateSort( + const Shape& shape, int64 dimension, HloInstruction* keys, + HloInstruction* values) { + return MakeUnique<HloSortInstruction>(shape, dimension, keys, values); +} + /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateFusion( const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root) { return MakeUnique<HloFusionInstruction>(shape, fusion_kind, fused_root); @@ -952,6 +998,8 @@ bool HloInstruction::HasSideEffectNoRecurse() const { case HloOpcode::kTrace: case HloOpcode::kHostCompute: return true; + case HloOpcode::kCrossReplicaSum: + return all_reduce_id().has_value(); default: return false; } @@ -1010,34 +1058,8 @@ bool HloInstruction::HasSideEffect() const { const Shape& shape, HloInstruction* operand, HloInstruction* gather_indices, const GatherDimensionNumbers& gather_dim_numbers, tensorflow::gtl::ArraySlice<int64> window_bounds) { - std::unique_ptr<HloInstruction> instruction = - WrapUnique(new HloInstruction(HloOpcode::kGather, shape)); - instruction->AppendOperand(operand); - instruction->AppendOperand(gather_indices); - instruction->gather_dimension_numbers_ = - MakeUnique<GatherDimensionNumbers>(gather_dim_numbers); - c_copy(window_bounds, std::back_inserter(instruction->gather_window_bounds_)); - return instruction; -} - -/* static */ GatherDimensionNumbers HloInstruction::MakeGatherDimNumbers( - tensorflow::gtl::ArraySlice<int64> output_window_dims, - tensorflow::gtl::ArraySlice<int64> elided_window_dims, - tensorflow::gtl::ArraySlice<int64> gather_dims_to_operand_dims, - int64 index_vector_dim) { - GatherDimensionNumbers gather_dim_numbers; - for (int64 output_window_dim : output_window_dims) { - gather_dim_numbers.add_output_window_dims(output_window_dim); - } - for (int64 elided_window_dim : elided_window_dims) { - gather_dim_numbers.add_elided_window_dims(elided_window_dim); - } - for (int64 gather_dim_to_input_dim : gather_dims_to_operand_dims) { - gather_dim_numbers.add_gather_dims_to_operand_dims(gather_dim_to_input_dim); - } - - gather_dim_numbers.set_index_vector_dim(index_vector_dim); - return gather_dim_numbers; + return MakeUnique<HloGatherInstruction>(shape, operand, gather_indices, + gather_dim_numbers, window_bounds); } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDomain( @@ -1100,6 +1122,9 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( case HloOpcode::kHostCompute: case HloOpcode::kPad: case HloOpcode::kDynamicSlice: + case HloOpcode::kSort: + case HloOpcode::kGather: + case HloOpcode::kIota: clone = CloneWithNewOperandsImpl(shape, new_operands, context); break; // Unary ops. @@ -1122,7 +1147,6 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( case HloOpcode::kReal: case HloOpcode::kSign: case HloOpcode::kSin: - case HloOpcode::kSort: case HloOpcode::kTanh: CHECK_EQ(new_operands.size(), 1); clone = CreateUnary(shape, opcode_, new_operands[0]); @@ -1156,6 +1180,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( // Ternary ops. case HloOpcode::kClamp: case HloOpcode::kSelect: + case HloOpcode::kTupleSelect: CHECK_EQ(new_operands.size(), 3); clone = CreateTernary(shape, opcode_, new_operands[0], new_operands[1], new_operands[2]); @@ -1201,11 +1226,6 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( true_computation(), new_operands[2], false_computation()); break; - case HloOpcode::kGather: - CHECK_EQ(new_operands.size(), 2); - clone = CreateGather(shape, new_operands[0], new_operands[1], - *gather_dimension_numbers_, gather_window_bounds_); - break; case HloOpcode::kDomain: CHECK_EQ(new_operands.size(), 1); clone = @@ -1213,7 +1233,11 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( user_side_metadata_->Clone()); break; case HloOpcode::kAfterAll: - clone = CreateAfterAll(new_operands); + if (new_operands.empty()) { + clone = CreateToken(); + } else { + clone = CreateAfterAll(new_operands); + } break; } SetupDerivedInstruction(clone.get()); @@ -1495,11 +1519,10 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kSubtract: case HloOpcode::kTanh: case HloOpcode::kTuple: + case HloOpcode::kTupleSelect: return true; - // These opcodes have complex or special behavior so just return false. - case HloOpcode::kDomain: - case HloOpcode::kWhile: + // This opcode has complex or special behavior so just return false. case HloOpcode::kAfterAll: return false; @@ -1508,11 +1531,6 @@ bool HloInstruction::IdenticalSlowPath( return protobuf_util::ProtobufEquals(dot_dimension_numbers(), other.dot_dimension_numbers()); - case HloOpcode::kGather: - return protobuf_util::ProtobufEquals(gather_dimension_numbers(), - other.gather_dimension_numbers()) && - gather_window_bounds() == other.gather_window_bounds(); - // Remaining instructions with special values. case HloOpcode::kCall: return eq_computations(to_apply(), other.to_apply()); @@ -1520,9 +1538,17 @@ bool HloInstruction::IdenticalSlowPath( return eq_computations(true_computation(), other.true_computation()) && eq_computations(false_computation(), other.false_computation()); - // These opcodes are not yet supported. - case HloOpcode::kSort: + case HloOpcode::kWhile: { + if (eq_computations(while_body(), other.while_body()) && + eq_computations(while_condition(), other.while_condition())) { + return true; + } return false; + } + + case HloOpcode::kDomain: + return operand_side_metadata().Matches(other.operand_side_metadata()) && + user_side_metadata().Matches(other.user_side_metadata()); // Ops migrated to subclasses should never come to this line. // TODO(b/80131774): Remove this switch when migration is complete. @@ -1537,11 +1563,13 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kReverse: case HloOpcode::kConcatenate: case HloOpcode::kReduce: + case HloOpcode::kSort: case HloOpcode::kTranspose: case HloOpcode::kBroadcast: case HloOpcode::kMap: case HloOpcode::kSlice: case HloOpcode::kConstant: + case HloOpcode::kIota: case HloOpcode::kTrace: case HloOpcode::kFusion: case HloOpcode::kRng: @@ -1558,9 +1586,11 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kHostCompute: case HloOpcode::kPad: case HloOpcode::kDynamicSlice: + case HloOpcode::kGather: LOG(FATAL) << "Base class impl called for opcode with subclass: " << opcode(); } + return false; } void HloInstruction::RemoveUser(HloInstruction* user) { @@ -1610,8 +1640,8 @@ Status HloInstruction::ReplaceOperandWith(int64 operand_num, TF_RET_CHECK(ShapeUtil::CompatibleIgnoringFpPrecision(old_operand->shape(), new_operand->shape())) - << old_operand->shape().ShortDebugString() << " is not compatible with " - << new_operand->shape().ShortDebugString(); + << old_operand->shape() << " is not compatible with " + << new_operand->shape(); operands_[operand_num] = new_operand; VLOG(3) << "Replacing operand " << operand_num << " of " << name() << " with " @@ -1820,7 +1850,6 @@ bool HloInstruction::IsElementwiseImpl( // Ternary elementwise operations. case HloOpcode::kSelect: - return !ShapeUtil::IsTuple(shape_); case HloOpcode::kClamp: return true; @@ -1832,6 +1861,10 @@ bool HloInstruction::IsElementwiseImpl( } } +bool HloInstruction::IsCrossModuleAllReduce() const { + return opcode() == HloOpcode::kCrossReplicaSum && all_reduce_id(); +} + string HloInstruction::ToStringWithCanonicalNameMap( const HloPrintOptions& options, CanonicalNameMap* canonical_name_map) const { @@ -1924,11 +1957,6 @@ std::vector<string> HloInstruction::ExtraAttributesToString( if (dot_dimension_numbers_ != nullptr) { extra.push_back(DotDimensionNumbersToString()); } - if (gather_dimension_numbers_ != nullptr) { - extra.push_back(GatherDimensionNumbersToString()); - extra.push_back( - StrCat("window_bounds={", Join(gather_window_bounds(), ","), "}")); - } if (options.print_subcomputation_mode() == HloPrintOptions::PrintSubcomputationMode::kNameOnly) { @@ -2015,8 +2043,8 @@ std::vector<string> HloInstruction::ExtraAttributesToString( } if (operand_side_metadata_ != nullptr && user_side_metadata_ != nullptr) { extra.push_back(StrCat("domain={kind=\"", operand_side_metadata_->Kind(), - "\", entry=", operand_side_metadata_->ToString(), - ", exit=", user_side_metadata_->ToString(), "}")); + "\", entry=", user_side_metadata_->ToString(), + ", exit=", operand_side_metadata_->ToString(), "}")); } return extra; @@ -2058,14 +2086,6 @@ HloInstructionProto HloInstruction::ToProto() const { if (dot_dimension_numbers_ != nullptr) { *proto.mutable_dot_dimension_numbers() = *dot_dimension_numbers_; } - if (gather_dimension_numbers_ != nullptr) { - *proto.mutable_gather_dimension_numbers() = *gather_dimension_numbers_; - } - if (opcode() == HloOpcode::kGather) { - for (int64 bound : gather_window_bounds()) { - proto.add_gather_window_bounds(bound); - } - } if (has_sharding()) { *proto.mutable_sharding() = sharding().ToProto(); @@ -2191,6 +2211,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) { return visitor->HandleRemainder(this); case HloOpcode::kSelect: return visitor->HandleSelect(this); + case HloOpcode::kTupleSelect: + return visitor->HandleTupleSelect(this); case HloOpcode::kConvolution: return visitor->HandleConvolution(this); case HloOpcode::kFft: @@ -2293,6 +2315,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) { return visitor->HandleDomain(this); case HloOpcode::kAfterAll: return visitor->HandleAfterAll(this); + case HloOpcode::kIota: + return visitor->HandleIota(this); // These opcodes are not handled here. case HloOpcode::kTrace: @@ -2824,26 +2848,6 @@ std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) { return os << ToString(kind); } -string HloInstruction::GatherDimensionNumbersToString() const { - CHECK_NE(gather_dimension_numbers_.get(), nullptr); - string output_window_dims = - StrCat("output_window_dims={", - Join(gather_dimension_numbers_->output_window_dims(), ","), "}"); - string elided_window_dims = - StrCat("elided_window_dims={", - Join(gather_dimension_numbers_->elided_window_dims(), ","), "}"); - string gather_dims_to_operand_dims = StrCat( - "gather_dims_to_operand_dims={", - Join(gather_dimension_numbers_->gather_dims_to_operand_dims(), ","), "}"); - string index_vector_dim = StrCat( - "index_vector_dim=", gather_dimension_numbers_->index_vector_dim()); - - return Join<std::initializer_list<string>>( - {output_window_dims, elided_window_dims, gather_dims_to_operand_dims, - index_vector_dim}, - ", "); -} - bool HloInstruction::CouldBeBitcast() const { switch (opcode_) { case HloOpcode::kTranspose: @@ -3157,4 +3161,14 @@ int64 HloInstruction::slice_sizes(int64 dimension) const { const std::vector<int64>& HloInstruction::dynamic_slice_sizes() const { return Cast<HloDynamicSliceInstruction>(this)->dynamic_slice_sizes(); } + +const GatherDimensionNumbers& HloInstruction::gather_dimension_numbers() const { + return Cast<HloGatherInstruction>(this)->gather_dimension_numbers(); +} + +tensorflow::gtl::ArraySlice<int64> HloInstruction::gather_window_bounds() + const { + return Cast<HloGatherInstruction>(this)->gather_window_bounds(); +} + } // namespace xla |