author     HyoukJoong Lee <hyouklee@google.com>             2018-06-12 13:57:28 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-06-12 14:00:55 -0700
commit     34b071f6b6a14bd4c8d5c30156c1670496b85f04 (patch)
tree       b413892245024a28e63adf88078ad264f1cadd16 /tensorflow
parent     688a09dc6b70a81cae12a7e263515964311f8d86 (diff)
Support subgroup CrossReplicaSum
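
CrossReplicaSum now takes an optional list of replica group IDs that assigns
each replica to a subgroup; the sum is computed independently within each
subgroup. For example, replica_group_ids={0,0,1,1} on four replicas sums
within replicas {0,1} and within replicas {2,3} separately.
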
PiperOrigin-RevId: 200275384
Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/compiler/xla/client/xla_client/xla_builder.cc            | 11
-rw-r--r--  tensorflow/compiler/xla/client/xla_client/xla_builder.h             |  9
-rw-r--r--  tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc |  2
-rw-r--r--  tensorflow/compiler/xla/service/bfloat16_normalization_test.cc      |  4
-rw-r--r--  tensorflow/compiler/xla/service/hlo.proto                           |  1
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.cc                  | 28
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.h                   | 23
-rw-r--r--  tensorflow/compiler/xla/service/hlo_parser.cc                       | 18
-rw-r--r--  tensorflow/compiler/xla/service/hlo_parser_test.cc                  | 18
-rw-r--r--  tensorflow/docs_src/performance/xla/operation_semantics.md          |  9
10 files changed, 108 insertions, 15 deletions
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
index 5e17cc4dfb..ae8fbdb2dc 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
@@ -1611,7 +1611,9 @@ XlaOp XlaBuilder::BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
});
}
-XlaOp XlaBuilder::CrossReplicaSum(const XlaOp& operand) {
+XlaOp XlaBuilder::CrossReplicaSum(
+ const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> replica_group_ids) {
return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand));
const Shape& scalar_shape = ShapeUtil::MakeShape(shape.element_type(), {});
@@ -1619,7 +1621,7 @@ XlaOp XlaBuilder::CrossReplicaSum(const XlaOp& operand) {
b->Add(b->Parameter(/*parameter_number=*/0, scalar_shape, "x"),
b->Parameter(/*parameter_number=*/1, scalar_shape, "y"));
TF_ASSIGN_OR_RETURN(auto computation, b->Build());
- return CrossReplicaSum(operand, computation, /*replica_group_ids=*/{},
+ return CrossReplicaSum(operand, computation, replica_group_ids,
/*channel_id=*/tensorflow::gtl::nullopt);
});
}
@@ -1629,7 +1631,7 @@ XlaOp XlaBuilder::CrossReplicaSum(
tensorflow::gtl::ArraySlice<int64> replica_group_ids,
const tensorflow::gtl::optional<ChannelHandle>& channel_id) {
return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
- if (!replica_group_ids.empty() || channel_id.has_value()) {
+ if (channel_id.has_value()) {
return Unimplemented(
"replica_group_ids and channel_id and is not supported in AllReduce");
}
@@ -1639,6 +1641,9 @@ XlaOp XlaBuilder::CrossReplicaSum(
TF_ASSIGN_OR_RETURN(
*instr.mutable_shape(),
ShapeInference::InferCrossReplicaSumShape({&operand_shape}));
+ for (int64 replica_group_id : replica_group_ids) {
+ instr.add_replica_group_ids(replica_group_id);
+ }
AddCalledComputation(computation, &instr);
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.h b/tensorflow/compiler/xla/client/xla_client/xla_builder.h
index 532cae0148..0329e42ed1 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.h
@@ -528,9 +528,12 @@ class XlaBuilder {
tensorflow::gtl::ArraySlice<int64> window_strides,
tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding);
- // Returns the sum of the operand value across all replicas. All replicas
- // supply one input to the sum and all replicas receive the resulting sum.
- XlaOp CrossReplicaSum(const XlaOp& operand);
+ // Returns the sum of the operand value within each subgroup of replicas. All
+ // replicas supply one input to the sum and all replicas receive the resulting
+ // sum for each subgroup.
+ XlaOp CrossReplicaSum(
+ const XlaOp& operand,
+ tensorflow::gtl::ArraySlice<int64> replica_group_ids = {});
// Enqueues an operation that does an AllReduce of the operand across cores.
// Here AllReduce means doing a reduction on the input operand across cores and then
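
A minimal usage sketch of the new overload (the builder name and parameter
shape below are illustrative, not part of the change):

    XlaBuilder b("subgroup_crs");
    XlaOp x = b.Parameter(/*parameter_number=*/0,
                          ShapeUtil::MakeShape(F32, {1024}), "x");
    // Four replicas, two subgroups: {0, 1} and {2, 3}. Each replica receives
    // the sum of the inputs from its own subgroup only.
    b.CrossReplicaSum(x, /*replica_group_ids=*/{0, 0, 1, 1});
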
diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc
index 7fd1e733e9..f7b4c1405d 100644
--- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc
@@ -235,7 +235,7 @@ TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) {
HloInstruction* crs =
builder.AddInstruction(HloInstruction::CreateCrossReplicaSum(
ShapeUtil::MakeTupleShape({f32_shape, f32_shape}), {convert_a, b},
- sum));
+ sum, /*replica_group_ids=*/{}, /*barrier=*/""));
HloInstruction* gte_a = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(f32_shape, crs, 0));
HloInstruction* gte_b = builder.AddInstruction(
diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
index 9926661dd3..830f26422b 100644
--- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
@@ -250,8 +250,8 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) {
HloInstruction* crs =
builder.AddInstruction(HloInstruction::CreateCrossReplicaSum(
- ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b},
- reduction));
+ ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b}, reduction,
+ /*replica_group_ids=*/{}, /*barrier=*/""));
HloInstruction* gte = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(bf16_shape, crs, 1));
diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto
index 1f7c1cffd3..e201359d3d 100644
--- a/tensorflow/compiler/xla/service/hlo.proto
+++ b/tensorflow/compiler/xla/service/hlo.proto
@@ -145,6 +145,7 @@ message HloInstructionProto {
repeated int64 operand_ids = 36;
repeated int64 control_predecessor_ids = 37;
repeated int64 called_computation_ids = 38;
+ repeated int64 replica_group_ids = 44;
xla.OpSharding sharding = 40;
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 28b6d6aefd..a9e73d3a77 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -298,6 +298,10 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
instruction->channel_name_ = proto.channel_name();
instruction->cost_estimate_ns_ = proto.cost_estimate_ns();
+ for (int64 replica_group_id : proto.replica_group_ids()) {
+ instruction->replica_group_ids_.push_back(replica_group_id);
+ }
+
return std::move(instruction);
}
@@ -528,9 +532,9 @@ HloInstruction::CreateCrossReplicaSum(
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
HloComputation* reduce_computation,
tensorflow::gtl::ArraySlice<int64> replica_group_ids,
+ tensorflow::StringPiece barrier,
const tensorflow::gtl::optional<int64>& channel_id) {
// TODO(b/79737069): Remove the CHECK when supported.
- CHECK(replica_group_ids.empty());
CHECK(!channel_id.has_value());
auto instruction =
WrapUnique(new HloInstruction(HloOpcode::kCrossReplicaSum, shape));
@@ -538,6 +542,9 @@ HloInstruction::CreateCrossReplicaSum(
instruction->AppendOperand(operand);
}
instruction->called_computations_.push_back(reduce_computation);
+ instruction->replica_group_ids_.assign(replica_group_ids.begin(),
+ replica_group_ids.end());
+ instruction->cross_replica_sum_barrier_ = std::string(barrier);
return instruction;
}
@@ -1138,7 +1145,9 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
*dot_dimension_numbers_);
break;
case HloOpcode::kCrossReplicaSum:
- clone = CreateCrossReplicaSum(shape, new_operands, to_apply());
+ clone =
+ CreateCrossReplicaSum(shape, new_operands, to_apply(),
+ replica_group_ids_, cross_replica_sum_barrier_);
break;
case HloOpcode::kGetTupleElement:
CHECK_EQ(new_operands.size(), 1);
@@ -1507,7 +1516,9 @@ bool HloInstruction::IdenticalSlowPath(
other.padding_config());
case HloOpcode::kCall:
case HloOpcode::kCrossReplicaSum:
- return eq_computations(to_apply(), other.to_apply());
+ return replica_group_ids() == other.replica_group_ids() &&
+ cross_replica_sum_barrier() == other.cross_replica_sum_barrier() &&
+ eq_computations(to_apply(), other.to_apply());
case HloOpcode::kCustomCall:
if ((window_ == nullptr) != (other.window_ == nullptr) ||
(window_ != nullptr &&
@@ -2086,6 +2097,14 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
"\", entry=", operand_side_metadata_->ToString(),
", exit=", user_side_metadata_->ToString(), "}"));
}
+ if (!replica_group_ids().empty()) {
+ extra.push_back(
+ StrCat("replica_group_ids={", Join(replica_group_ids(), ","), "}"));
+ }
+ if (!cross_replica_sum_barrier().empty()) {
+ extra.push_back(StrCat("barrier=\"", cross_replica_sum_barrier(), "\""));
+ }
+
// By contract, we print the custom call target even if
// options.print_subcomputation_mode() == kOff, because the call target is not
// an HloComputation.
@@ -2173,6 +2192,9 @@ HloInstructionProto HloInstruction::ToProto() const {
proto.set_channel_name(channel_name_);
proto.set_cost_estimate_ns(cost_estimate_ns_);
+ for (int64 replica_group_id : replica_group_ids_) {
+ proto.add_replica_group_ids(replica_group_id);
+ }
return proto;
}
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 7d1ea129df..fcd175e66f 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -443,7 +443,8 @@ class HloInstruction {
static std::unique_ptr<HloInstruction> CreateCrossReplicaSum(
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
HloComputation* reduce_computation,
- tensorflow::gtl::ArraySlice<int64> replica_group_ids = {},
+ tensorflow::gtl::ArraySlice<int64> replica_group_ids,
+ tensorflow::StringPiece barrier,
const tensorflow::gtl::optional<int64>& channel_id =
tensorflow::gtl::nullopt);
@@ -1447,6 +1448,20 @@ class HloInstruction {
void set_fusion_kind(FusionKind kind);
// Old methods kept for smooth subclassing transition END.
+ // Returns the group ID of each replica for the CrossReplicaSum op.
+ const std::vector<int64>& replica_group_ids() const {
+ return replica_group_ids_;
+ }
+
+ // Returns the barrier config used by each backend's CrossReplicaSum
+ // implementation.
+ const string& cross_replica_sum_barrier() const {
+ return cross_replica_sum_barrier_;
+ }
+ void set_cross_replica_sum_barrier(string barrier) {
+ cross_replica_sum_barrier_ = std::move(barrier);
+ }
+
protected:
enum class UseKind { kNoUse, kReuse, kUsePermutingElements, kUse };
// Helper class for computing OperandElementUse for kFusion.
@@ -1650,6 +1665,12 @@ class HloInstruction {
// HLO. See the documentation on backend_config().
string backend_config_;
+ // The group id of each replica for CrossReplicaSum.
+ std::vector<int64> replica_group_ids_;
+
+ // The string representation of the barrier config used for CrossReplicaSum.
+ string cross_replica_sum_barrier_;
+
// String identifier for instruction.
string name_;
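
A minimal construction sketch against the updated factory signature (the
operand, shape, and add_computation below are illustrative; the empty barrier
matches the updated test call sites):

    // add_computation: HloComputation* performing scalar f32 addition.
    // operand: HloInstruction* producing an f32[8] value.
    Shape shape = ShapeUtil::MakeShape(F32, {8});
    auto crs = HloInstruction::CreateCrossReplicaSum(
        shape, {operand}, add_computation,
        /*replica_group_ids=*/{0, 0, 1, 1},
        /*barrier=*/"");
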
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index 4aa4406292..fef475380c 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -588,13 +588,27 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
}
case HloOpcode::kCrossReplicaSum: {
optional<HloComputation*> to_apply;
+ optional<std::vector<int64>> replica_group_ids;
+ optional<string> barrier;
attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
&to_apply};
+ attrs["replica_group_ids"] = {
+ /*required=*/false, AttrTy::kBracedInt64List, &replica_group_ids};
+ attrs["barrier"] = {/*required=*/false, AttrTy::kString, &barrier};
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
- instruction = builder->AddInstruction(
- HloInstruction::CreateCrossReplicaSum(shape, operands, *to_apply));
+
+ instruction =
+ builder->AddInstruction(HloInstruction::CreateCrossReplicaSum(
+ shape, operands, *to_apply,
+ replica_group_ids ? *replica_group_ids : std::vector<int64>(),
+ barrier ? *barrier : ""));
break;
}
case HloOpcode::kReshape: {
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index 1c5a47c875..f834d34d57 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -918,6 +918,24 @@ ENTRY CRS {
)"
},
+// cross-replica-sum with subgroups
+{
+"CrossReplicaSumWithSubgroups",
+R"(HloModule CRS_Subgroups
+
+add {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT add = f32[] add(lhs, rhs)
+}
+
+ENTRY CrossReplicaSumWithSubgroups {
+ input = f32[128,32]{0,1} parameter(0)
+ ROOT cross-replica-sum = f32[128,32]{0,1} cross-replica-sum(input), to_apply=add, replica_group_ids={0,0,1,1}, barrier="abc"
+}
+
+)"
+}
});
// clang-format on
}
diff --git a/tensorflow/docs_src/performance/xla/operation_semantics.md b/tensorflow/docs_src/performance/xla/operation_semantics.md
index 5887c3d88b..f7e116bf0f 100644
--- a/tensorflow/docs_src/performance/xla/operation_semantics.md
+++ b/tensorflow/docs_src/performance/xla/operation_semantics.md
@@ -581,12 +581,21 @@ Computes a sum across replicas.
Arguments | Type | Semantics
--------- | ------- | -----------------------------
`operand` | `XlaOp` | Array to sum across replicas.
+`replica_group_ids` | `int64` vector | Group ID for each replica.
The output shape is the same as the input shape. For example, if there are two
replicas and the operand has the value `(1.0, 2.5)` and `(3.0, 5.25)`
respectively on the two replicas, then the output value from this op will be
`(4.0, 7.75)` on both replicas.
+`replica_group_ids` identifies the group ID of each replica. It must either be
+empty (all replicas belong to a single group) or contain one element per
+replica. For example, with eight replicas and
+`replica_group_ids` = {0, 1, 2, 3, 0, 1, 2, 3}, there are four subgroups of
+replica IDs: {0, 4}, {1, 5}, {2, 6}, and {3, 7}. All subgroups *must* have the
+same size; for example, `replica_group_ids` = {0, 1, 2, 0} for four replicas
+is invalid.
+
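+For example, with four replicas holding the values `(1.0)`, `(2.0)`, `(3.0)`,
+and `(4.0)` and `replica_group_ids` = {0, 0, 1, 1}, replicas 0 and 1 each
+receive `3.0`, while replicas 2 and 3 each receive `7.0`.
+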
Computing the result of CrossReplicaSum requires having one input from each
replica, so if one replica executes a CrossReplicaSum node more times than
another, then the former replica will wait forever. Since the replicas are all