diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instruction.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.cc | 24 |
1 files changed, 19 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 1c276b9305..06775d6a9a 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -423,8 +423,20 @@ HloInstruction::CreateReducePrecision(const Shape& shape,
 
 /* static */ std::unique_ptr<HloInstruction>
 HloInstruction::CreateCrossReplicaSum(
-    const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
-  return CreateNary(shape, HloOpcode::kCrossReplicaSum, operands);
+    const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+    HloComputation* reduce_computation,
+    tensorflow::gtl::ArraySlice<int64> replica_group_ids,
+    const tensorflow::gtl::optional<int64>& channel_id) {
+  // TODO(b/79737069): Remove the CHECK when supported.
+  CHECK(replica_group_ids.empty());
+  CHECK(!channel_id.has_value());
+  auto instruction =
+      WrapUnique(new HloInstruction(HloOpcode::kCrossReplicaSum, shape));
+  for (auto operand : operands) {
+    instruction->AppendOperand(operand);
+  }
+  instruction->called_computations_.push_back(reduce_computation);
+  return instruction;
 }
 
 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateInfeed(
@@ -1374,7 +1386,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
       clone = CreateFft(shape, new_operands[0], fft_type_, fft_length_);
       break;
     case HloOpcode::kCrossReplicaSum:
-      clone = CreateCrossReplicaSum(shape, new_operands);
+      clone = CreateCrossReplicaSum(shape, new_operands, to_apply());
       break;
     case HloOpcode::kGetTupleElement:
       CHECK_EQ(new_operands.size(), 1);
@@ -1762,7 +1774,6 @@ bool HloInstruction::IdenticalSlowPath(
     case HloOpcode::kConvert:
     case HloOpcode::kCopy:
     case HloOpcode::kCos:
-    case HloOpcode::kCrossReplicaSum:
     case HloOpcode::kDivide:
     case HloOpcode::kDynamicSlice:
     case HloOpcode::kDynamicUpdateSlice:
@@ -1887,6 +1898,7 @@ bool HloInstruction::IdenticalSlowPath(
              slice_limits_ == other.slice_limits_ &&
              slice_strides_ == other.slice_strides_;
     case HloOpcode::kCall:
+    case HloOpcode::kCrossReplicaSum:
     case HloOpcode::kMap:
       return eq_computations(to_apply(), other.to_apply());
     case HloOpcode::kCustomCall:
@@ -2034,6 +2046,7 @@ HloComputation* HloInstruction::to_apply() const {
     case HloOpcode::kMap:
     case HloOpcode::kReduceWindow:
     case HloOpcode::kReduce:
+    case HloOpcode::kCrossReplicaSum:
       CHECK_EQ(called_computations_.size(), 1);
       return called_computations_[0];
     default:
@@ -2356,7 +2369,8 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
                         PrintName(false_computation()->name(), options)));
   } else if (opcode() == HloOpcode::kCall || opcode() == HloOpcode::kMap ||
              opcode() == HloOpcode::kReduceWindow ||
-             opcode() == HloOpcode::kReduce) {
+             opcode() == HloOpcode::kReduce ||
+             opcode() == HloOpcode::kCrossReplicaSum) {
     extra.push_back(
         StrCat("to_apply=", PrintName(to_apply()->name(), options)));
   } else if (!called_computations().empty()) {