path: root/tensorflow
author    A. Unique TensorFlower <gardener@tensorflow.org>    2018-06-05 16:23:23 -0700
committer TensorFlower Gardener <gardener@tensorflow.org>    2018-06-05 16:27:53 -0700
commit a57f0de68685fb537eb390fa87f04dbafecb28ef (patch)
tree   1796b4af3d5e28164f7da755d0a47b96dbd01894 /tensorflow
parent 135a25971bfbac86b0aed2cf0433608966015c22 (diff)
[XLA] Make CrossReplicaSum support general cross replica reduce. Also change the interface to be able to describe the common AllReduce semantic.
PiperOrigin-RevId: 199376926
Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/compiler/xla/client/xla_client/xla_builder.cc             24
-rw-r--r--  tensorflow/compiler/xla/client/xla_client/xla_builder.h              23
-rw-r--r--  tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc  15
-rw-r--r--  tensorflow/compiler/xla/service/bfloat16_normalization_test.cc       15
-rw-r--r--  tensorflow/compiler/xla/service/buffer_assignment.cc                  1
-rw-r--r--  tensorflow/compiler/xla/service/call_graph.cc                          1
-rw-r--r--  tensorflow/compiler/xla/service/hlo_element_type_converter.cc         1
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.cc                    24
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.h                     22
-rw-r--r--  tensorflow/compiler/xla/service/hlo_parser.cc                          5
-rw-r--r--  tensorflow/compiler/xla/service/hlo_parser_test.cc                    18
-rw-r--r--  tensorflow/compiler/xla/tests/cross_replica_sum_test.cc               27
12 files changed, 159 insertions, 17 deletions
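
For context, a minimal client-side sketch of the call pattern the new overload enables. This is not part of the patch; the names are illustrative, and the defaults mirror the declaration added to xla_builder.h below.

  XlaBuilder builder("allreduce_example");
  XlaOp data = builder.Parameter(/*parameter_number=*/0,
                                 ShapeUtil::MakeShape(F32, {8}), "data");

  // Scalar reduction computation; a sum here, but any commutative scalar
  // computation (e.g. min or max) fits the new interface.
  auto sub = builder.CreateSubBuilder("sum");
  sub->Add(sub->Parameter(/*parameter_number=*/0,
                          ShapeUtil::MakeShape(F32, {}), "x"),
           sub->Parameter(/*parameter_number=*/1,
                          ShapeUtil::MakeShape(F32, {}), "y"));
  XlaComputation sum = sub->Build().ValueOrDie();

  // Defaults: one group spanning all replicas, no cross-model channel. The
  // grouped and cross-model modes still return Unimplemented in this change.
  XlaOp crs = builder.CrossReplicaSum(data, sum, /*replica_group_ids=*/{},
                                      /*channel_id=*/tensorflow::gtl::nullopt);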
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
index ae506317c2..5e17cc4dfb 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
@@ -1613,13 +1613,35 @@ XlaOp XlaBuilder::BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
XlaOp XlaBuilder::CrossReplicaSum(const XlaOp& operand) {
return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
- HloInstructionProto instr;
+ TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand));
+ const Shape& scalar_shape = ShapeUtil::MakeShape(shape.element_type(), {});
+ auto b = CreateSubBuilder("sum");
+ b->Add(b->Parameter(/*parameter_number=*/0, scalar_shape, "x"),
+ b->Parameter(/*parameter_number=*/1, scalar_shape, "y"));
+ TF_ASSIGN_OR_RETURN(auto computation, b->Build());
+ return CrossReplicaSum(operand, computation, /*replica_group_ids=*/{},
+ /*channel_id=*/tensorflow::gtl::nullopt);
+ });
+}
+
+XlaOp XlaBuilder::CrossReplicaSum(
+ const XlaOp& operand, const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<int64> replica_group_ids,
+ const tensorflow::gtl::optional<ChannelHandle>& channel_id) {
+ return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ if (!replica_group_ids.empty() || channel_id.has_value()) {
+ return Unimplemented(
+ "replica_group_ids and channel_id and is not supported in AllReduce");
+ }
+ HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
TF_ASSIGN_OR_RETURN(
*instr.mutable_shape(),
ShapeInference::InferCrossReplicaSumShape({&operand_shape}));
+ AddCalledComputation(computation, &instr);
+
return AddInstruction(std::move(instr), HloOpcode::kCrossReplicaSum,
{operand});
});
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.h b/tensorflow/compiler/xla/client/xla_client/xla_builder.h
index 2b3013a91c..532cae0148 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.h
@@ -532,6 +532,29 @@ class XlaBuilder {
// supply one input to the sum and all replicas receive the resulting sum.
XlaOp CrossReplicaSum(const XlaOp& operand);
+ // Enqueues an operation that performs an AllReduce of the operand across
+ // cores. Here AllReduce means reducing the input operand across cores and
+ // then broadcasting the reduction result to those cores. The reduction
+ // function is defined by `computation`, which should be a commutative
+ // computation on scalars, e.g., add, min, or max. The way the AllReduce is
+ // applied is configured by:
+ //
+ // - `replica_group_ids`: maps replica ids to subgroup ids. If empty, all
+ // replicas belong to one group. AllReduce is applied within each subgroup.
+ // For example, with 4 replicas, replica_group_ids={0,1,0,1} means that
+ // replicas 0 and 2 are in subgroup 0, and replicas 1 and 3 are in subgroup 1.
+ //
+ // - `channel_id`: if AllReduce nodes from different models have the same
+ // channel_id, they will be 'AllReduce'd together. If empty, the AllReduce is
+ // not applied across models.
+ //
+ // TODO(b/79737069): Rename this to AllReduce when it's ready to use.
+ XlaOp CrossReplicaSum(
+ const XlaOp& operand, const XlaComputation& computation,
+ tensorflow::gtl::ArraySlice<int64> replica_group_ids = {},
+ const tensorflow::gtl::optional<ChannelHandle>& channel_id =
+ tensorflow::gtl::nullopt);
+
// Enqueues an operation that scatters the `source` array to the selected
// indices of each window.
XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select,
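
Because the reduction is now an arbitrary scalar computation, the same entry point can express reductions other than a sum. A sketch of a cross-replica min, assuming the builder exposes a scalar Min op and reusing `builder` and `data` from the sketch above:

  auto min_builder = builder.CreateSubBuilder("min");
  min_builder->Min(
      min_builder->Parameter(/*parameter_number=*/0,
                             ShapeUtil::MakeShape(F32, {}), "x"),
      min_builder->Parameter(/*parameter_number=*/1,
                             ShapeUtil::MakeShape(F32, {}), "y"));
  XlaOp cross_min =
      builder.CrossReplicaSum(data, min_builder->Build().ValueOrDie());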
diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc
index 28e71c2054..7fd1e733e9 100644
--- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc
@@ -211,6 +211,17 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldTuple) {
TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) {
auto builder = HloComputation::Builder(TestName());
+
+ auto module = CreateNewModule();
+ HloComputation::Builder sum_builder("add");
+ auto x = sum_builder.AddInstruction(HloInstruction::CreateParameter(
+ /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {}), "x"));
+ auto y = sum_builder.AddInstruction(HloInstruction::CreateParameter(
+ /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {}), "y"));
+ sum_builder.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, x, y));
+ HloComputation* sum = module->AddEmbeddedComputation(sum_builder.Build());
+
Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});
@@ -223,7 +234,8 @@ TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) {
HloInstruction* crs =
builder.AddInstruction(HloInstruction::CreateCrossReplicaSum(
- ShapeUtil::MakeTupleShape({f32_shape, f32_shape}), {convert_a, b}));
+ ShapeUtil::MakeTupleShape({f32_shape, f32_shape}), {convert_a, b},
+ sum));
HloInstruction* gte_a = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(f32_shape, crs, 0));
HloInstruction* gte_b = builder.AddInstruction(
@@ -233,7 +245,6 @@ TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) {
HloInstruction* tuple = builder.AddInstruction(
HloInstruction::CreateTuple({gte_a, convert_gte_b}));
- auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_TRUE(FoldConversions(module.get()));
diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
index 1afaefd9df..9926661dd3 100644
--- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
@@ -228,6 +228,17 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionReduce) {
}
TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) {
+ auto module = CreateNewModule();
+ HloComputation::Builder sum_builder("sum");
+ auto x = sum_builder.AddInstruction(HloInstruction::CreateParameter(
+ /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {}), "x"));
+ auto y = sum_builder.AddInstruction(HloInstruction::CreateParameter(
+ /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {}), "y"));
+ sum_builder.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, x, y));
+ HloComputation* reduction =
+ module->AddEmbeddedComputation(sum_builder.Build());
+
auto builder = HloComputation::Builder(TestName());
Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});
@@ -239,11 +250,11 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) {
HloInstruction* crs =
builder.AddInstruction(HloInstruction::CreateCrossReplicaSum(
- ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b}));
+ ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b},
+ reduction));
HloInstruction* gte = builder.AddInstruction(
HloInstruction::CreateGetTupleElement(bf16_shape, crs, 1));
- auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_TRUE(Normalize(module.get()));
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index c0b8bf9039..682c386579 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -135,6 +135,7 @@ Status GatherComputationsByAllocationType(
worklist.push_back(std::make_pair(subcomputation,
false)); // Not thread local.
break;
+ case HloOpcode::kCrossReplicaSum:
case HloOpcode::kMap:
case HloOpcode::kReduce:
case HloOpcode::kReduceWindow:
diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc
index a8053d15e1..a23427f00c 100644
--- a/tensorflow/compiler/xla/service/call_graph.cc
+++ b/tensorflow/compiler/xla/service/call_graph.cc
@@ -57,6 +57,7 @@ CallContext GetInstructionCallContext(HloOpcode opcode) {
case HloOpcode::kConditional:
case HloOpcode::kWhile:
return CallContext::kSequential;
+ case HloOpcode::kCrossReplicaSum:
case HloOpcode::kMap:
case HloOpcode::kReduce:
case HloOpcode::kReduceWindow:
diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc
index abec29df43..4ed1508d70 100644
--- a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc
+++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc
@@ -141,6 +141,7 @@ StatusOr<bool> HloElementTypeConverter::Run(HloModule* module) {
// These are ops with embedded computations where it suffices to convert
// the embedded computations instead of converting the ops themselves.
if (opcode == HloOpcode::kWhile || opcode == HloOpcode::kCall ||
+ opcode == HloOpcode::kCrossReplicaSum ||
opcode == HloOpcode::kFusion || opcode == HloOpcode::kMap ||
opcode == HloOpcode::kReduce || opcode == HloOpcode::kReduceWindow ||
opcode == HloOpcode::kSelectAndScatter ||
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 1c276b9305..06775d6a9a 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -423,8 +423,20 @@ HloInstruction::CreateReducePrecision(const Shape& shape,
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateCrossReplicaSum(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
- return CreateNary(shape, HloOpcode::kCrossReplicaSum, operands);
+ const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ HloComputation* reduce_computation,
+ tensorflow::gtl::ArraySlice<int64> replica_group_ids,
+ const tensorflow::gtl::optional<int64>& channel_id) {
+ // TODO(b/79737069): Remove the CHECK when supported.
+ CHECK(replica_group_ids.empty());
+ CHECK(!channel_id.has_value());
+ auto instruction =
+ WrapUnique(new HloInstruction(HloOpcode::kCrossReplicaSum, shape));
+ for (auto operand : operands) {
+ instruction->AppendOperand(operand);
+ }
+ instruction->called_computations_.push_back(reduce_computation);
+ return instruction;
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateInfeed(
@@ -1374,7 +1386,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
clone = CreateFft(shape, new_operands[0], fft_type_, fft_length_);
break;
case HloOpcode::kCrossReplicaSum:
- clone = CreateCrossReplicaSum(shape, new_operands);
+ clone = CreateCrossReplicaSum(shape, new_operands, to_apply());
break;
case HloOpcode::kGetTupleElement:
CHECK_EQ(new_operands.size(), 1);
@@ -1762,7 +1774,6 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kConvert:
case HloOpcode::kCopy:
case HloOpcode::kCos:
- case HloOpcode::kCrossReplicaSum:
case HloOpcode::kDivide:
case HloOpcode::kDynamicSlice:
case HloOpcode::kDynamicUpdateSlice:
@@ -1887,6 +1898,7 @@ bool HloInstruction::IdenticalSlowPath(
slice_limits_ == other.slice_limits_ &&
slice_strides_ == other.slice_strides_;
case HloOpcode::kCall:
+ case HloOpcode::kCrossReplicaSum:
case HloOpcode::kMap:
return eq_computations(to_apply(), other.to_apply());
case HloOpcode::kCustomCall:
@@ -2034,6 +2046,7 @@ HloComputation* HloInstruction::to_apply() const {
case HloOpcode::kMap:
case HloOpcode::kReduceWindow:
case HloOpcode::kReduce:
+ case HloOpcode::kCrossReplicaSum:
CHECK_EQ(called_computations_.size(), 1);
return called_computations_[0];
default:
@@ -2356,7 +2369,8 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
PrintName(false_computation()->name(), options)));
} else if (opcode() == HloOpcode::kCall || opcode() == HloOpcode::kMap ||
opcode() == HloOpcode::kReduceWindow ||
- opcode() == HloOpcode::kReduce) {
+ opcode() == HloOpcode::kReduce ||
+ opcode() == HloOpcode::kCrossReplicaSum) {
extra.push_back(
StrCat("to_apply=", PrintName(to_apply()->name(), options)));
} else if (!called_computations().empty()) {
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 905ea5310d..ef55c6668f 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -426,10 +426,26 @@ class HloInstruction {
const Shape& shape, HloInstruction* operand, const int exponent_bits,
const int mantissa_bits);
- // Creates a cross replica sum op.
+ // Creates a cross replica reduction op.
+ //
+ // `reduction_computation`: the reduction function.
+ //
+ // `replica_group_ids`: maps replica ids to subgroup ids. If empty, all
+ // replicas belong to one group. AllReduce is applied within each subgroup.
+ // For example, with 4 replicas, replica_group_ids={0,1,0,1} means that
+ // replicas 0 and 2 are in subgroup 0, and replicas 1 and 3 are in subgroup 1.
+ //
+ // `channel_id`: if AllReduce nodes from different models have the same
+ // channel_id, they will be 'AllReduce'd together. If empty, the AllReduce is
+ // not applied across models.
+ //
+ // TODO(b/79737069): Rename this to AllReduce.
static std::unique_ptr<HloInstruction> CreateCrossReplicaSum(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> operands);
+ const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ HloComputation* reduce_computation,
+ tensorflow::gtl::ArraySlice<int64> replica_group_ids = {},
+ const tensorflow::gtl::optional<int64>& channel_id =
+ tensorflow::gtl::nullopt);
// Creates a conversion instruction, where operand is the data to convert and
// shape is the target shape for the conversion.
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index ec20606d2f..3eadedfe1f 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -587,11 +587,14 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kCrossReplicaSum: {
+ optional<HloComputation*> to_apply;
+ attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
+ &to_apply};
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
instruction = builder->AddInstruction(
- HloInstruction::CreateCrossReplicaSum(shape, operands));
+ HloInstruction::CreateCrossReplicaSum(shape, operands, *to_apply));
break;
}
case HloOpcode::kReshape: {
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index 84a981675f..08068dc504 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -900,6 +900,24 @@ ENTRY Gather {
)"
},
+// cross-replica-sum
+{
+"CrossReplicaSum",
+R"(HloModule CRS
+
+add {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT add = f32[] add(lhs, rhs)
+}
+
+ENTRY CRS {
+ input = f32[8]{0} parameter(0)
+ ROOT crs = f32[8]{0} cross-replica-sum(input), to_apply=add
+}
+
+)"
+},
});
// clang-format on
}
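
Since `to_apply` is registered above as a required attribute, HLO text in the old form, e.g. a hypothetical

  ROOT crs = f32[8]{0} cross-replica-sum(input)

without `to_apply`, is now rejected by the parser; this is why every cross-replica-sum in the execution tests below gains a `, to_apply=add` suffix.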
diff --git a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc
index c960b3c15f..b151187c4b 100644
--- a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc
+++ b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc
@@ -32,9 +32,16 @@ class TrivialCrossReplicaSumTest : public HloTestBase {};
XLA_TEST_F(TrivialCrossReplicaSumTest, OneOperand) {
const char* module_str = R"(
HloModule test
+
+ add {
+ x = f32[] parameter(0)
+ y = f32[] parameter(1)
+ add = f32[] add(x, y)
+ }
+
ENTRY test_computation {
p = f32[3] parameter(0)
- ROOT crs = f32[3] cross-replica-sum(p)
+ ROOT crs = f32[3] cross-replica-sum(p), to_apply=add
})";
auto module =
ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie();
@@ -45,10 +52,17 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, OneOperand) {
XLA_TEST_F(TrivialCrossReplicaSumTest, MultipleOperands) {
const char* module_str = R"(
HloModule test
+
+ add {
+ x = f32[] parameter(0)
+ y = f32[] parameter(1)
+ add = f32[] add(x, y)
+ }
+
ENTRY test_computation {
p0 = f32[3] parameter(0)
p1 = f32[2] parameter(1)
- ROOT crs = (f32[3], f32[2]) cross-replica-sum(p0, p1)
+ ROOT crs = (f32[3], f32[2]) cross-replica-sum(p0, p1), to_apply=add
})";
auto module =
ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie();
@@ -65,10 +79,17 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, MultipleOperands) {
XLA_TEST_F(TrivialCrossReplicaSumTest, ConstantOperand) {
const char* module_str = R"(
HloModule test
+
+ add {
+ x = f32[] parameter(0)
+ y = f32[] parameter(1)
+ add = f32[] add(x, y)
+ }
+
ENTRY test_computation {
p0 = f32[3] parameter(0)
p1 = f32[2] constant({10, 20})
- ROOT crs = (f32[3], f32[2]) cross-replica-sum(p0, p1)
+ ROOT crs = (f32[3], f32[2]) cross-replica-sum(p0, p1), to_apply=add
})";
auto module =
ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie();