aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_instruction.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instruction.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc24
1 files changed, 19 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 1c276b9305..06775d6a9a 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -423,8 +423,20 @@ HloInstruction::CreateReducePrecision(const Shape& shape,
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateCrossReplicaSum(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
- return CreateNary(shape, HloOpcode::kCrossReplicaSum, operands);
+ const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ HloComputation* reduce_computation,
+ tensorflow::gtl::ArraySlice<int64> replica_group_ids,
+ const tensorflow::gtl::optional<int64>& channel_id) {
+ // TODO(b/79737069): Remove the CHECK when supported.
+ CHECK(replica_group_ids.empty());
+ CHECK(!channel_id.has_value());
+ auto instruction =
+ WrapUnique(new HloInstruction(HloOpcode::kCrossReplicaSum, shape));
+ for (auto operand : operands) {
+ instruction->AppendOperand(operand);
+ }
+ instruction->called_computations_.push_back(reduce_computation);
+ return instruction;
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateInfeed(
@@ -1374,7 +1386,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
clone = CreateFft(shape, new_operands[0], fft_type_, fft_length_);
break;
case HloOpcode::kCrossReplicaSum:
- clone = CreateCrossReplicaSum(shape, new_operands);
+ clone = CreateCrossReplicaSum(shape, new_operands, to_apply());
break;
case HloOpcode::kGetTupleElement:
CHECK_EQ(new_operands.size(), 1);
@@ -1762,7 +1774,6 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kConvert:
case HloOpcode::kCopy:
case HloOpcode::kCos:
- case HloOpcode::kCrossReplicaSum:
case HloOpcode::kDivide:
case HloOpcode::kDynamicSlice:
case HloOpcode::kDynamicUpdateSlice:
@@ -1887,6 +1898,7 @@ bool HloInstruction::IdenticalSlowPath(
slice_limits_ == other.slice_limits_ &&
slice_strides_ == other.slice_strides_;
case HloOpcode::kCall:
+ case HloOpcode::kCrossReplicaSum:
case HloOpcode::kMap:
return eq_computations(to_apply(), other.to_apply());
case HloOpcode::kCustomCall:
@@ -2034,6 +2046,7 @@ HloComputation* HloInstruction::to_apply() const {
case HloOpcode::kMap:
case HloOpcode::kReduceWindow:
case HloOpcode::kReduce:
+ case HloOpcode::kCrossReplicaSum:
CHECK_EQ(called_computations_.size(), 1);
return called_computations_[0];
default:
@@ -2356,7 +2369,8 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
PrintName(false_computation()->name(), options)));
} else if (opcode() == HloOpcode::kCall || opcode() == HloOpcode::kMap ||
opcode() == HloOpcode::kReduceWindow ||
- opcode() == HloOpcode::kReduce) {
+ opcode() == HloOpcode::kReduce ||
+ opcode() == HloOpcode::kCrossReplicaSum) {
extra.push_back(
StrCat("to_apply=", PrintName(to_apply()->name(), options)));
} else if (!called_computations().empty()) {