author     A. Unique TensorFlower <gardener@tensorflow.org>   2018-06-05 16:23:23 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>    2018-06-05 16:27:53 -0700
commit     a57f0de68685fb537eb390fa87f04dbafecb28ef (patch)
tree       1796b4af3d5e28164f7da755d0a47b96dbd01894 /tensorflow
parent     135a25971bfbac86b0aed2cf0433608966015c22 (diff)
[XLA] Make CrossReplicaSum support general cross-replica reduction. Also change the interface to be able to describe the common AllReduce semantics.
PiperOrigin-RevId: 199376926
Diffstat (limited to 'tensorflow')
12 files changed, 159 insertions, 17 deletions
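
The user-visible effect is easiest to see in HLO text form: every cross-replica-sum now names an explicit reduction computation via a to_apply attribute. The snippet below is adapted from the hlo_parser_test.cc case added in this change; the module and computation names are the test's own:

    add {
      lhs = f32[] parameter(0)
      rhs = f32[] parameter(1)
      ROOT add = f32[] add(lhs, rhs)
    }

    ENTRY CRS {
      input = f32[8]{0} parameter(0)
      ROOT crs = f32[8]{0} cross-replica-sum(input), to_apply=add
    }
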
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
index ae506317c2..5e17cc4dfb 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
@@ -1613,13 +1613,35 @@ XlaOp XlaBuilder::BatchNormGrad(const XlaOp& operand, const XlaOp& scale,
 XlaOp XlaBuilder::CrossReplicaSum(const XlaOp& operand) {
   return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
-    HloInstructionProto instr;
+    TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand));
+    const Shape& scalar_shape = ShapeUtil::MakeShape(shape.element_type(), {});
+    auto b = CreateSubBuilder("sum");
+    b->Add(b->Parameter(/*parameter_number=*/0, scalar_shape, "x"),
+           b->Parameter(/*parameter_number=*/1, scalar_shape, "y"));
+    TF_ASSIGN_OR_RETURN(auto computation, b->Build());
+    return CrossReplicaSum(operand, computation, /*replica_group_ids=*/{},
+                           /*channel_id=*/tensorflow::gtl::nullopt);
+  });
+}
+
+XlaOp XlaBuilder::CrossReplicaSum(
+    const XlaOp& operand, const XlaComputation& computation,
+    tensorflow::gtl::ArraySlice<int64> replica_group_ids,
+    const tensorflow::gtl::optional<ChannelHandle>& channel_id) {
+  return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+    if (!replica_group_ids.empty() || channel_id.has_value()) {
+      return Unimplemented(
+          "replica_group_ids and channel_id are not supported in AllReduce");
+    }
+    HloInstructionProto instr;
     TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
     TF_ASSIGN_OR_RETURN(
         *instr.mutable_shape(),
         ShapeInference::InferCrossReplicaSumShape({&operand_shape}));
+    AddCalledComputation(computation, &instr);
+
     return AddInstruction(std::move(instr), HloOpcode::kCrossReplicaSum,
                           {operand});
   });
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.h b/tensorflow/compiler/xla/client/xla_client/xla_builder.h
index 2b3013a91c..532cae0148 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.h
@@ -532,6 +532,29 @@ class XlaBuilder {
   // supply one input to the sum and all replicas receive the resulting sum.
   XlaOp CrossReplicaSum(const XlaOp& operand);
 
+  // Enqueues an operation that performs an AllReduce of the operand across
+  // cores. Here AllReduce means doing a reduction on the input operand across
+  // cores and then broadcasting the reduction result to those cores. The
+  // reduction function is defined by `computation`, which should be a
+  // commutative computation on scalars, e.g., add, min, or max. How AllReduce
+  // is applied is configured by:
+  //
+  // - `replica_group_ids`: maps replica ids to subgroup ids. If empty, all
+  // replicas belong to one group. AllReduce will be applied within subgroups.
+  // For example, with 4 replicas, replica_group_ids={0,1,0,1} means replicas
+  // 0 and 2 are in subgroup 0, and replicas 1 and 3 are in subgroup 1.
+  //
+  // - `channel_id`: for AllReduce nodes from different models: if they have
+  // the same channel_id, they will be 'AllReduce'd together. If empty,
+  // AllReduce will not be applied across models.
+  //
+  // TODO(b/79737069): Rename this to AllReduce when it's ready to use.
+  XlaOp CrossReplicaSum(
+      const XlaOp& operand, const XlaComputation& computation,
+      tensorflow::gtl::ArraySlice<int64> replica_group_ids = {},
+      const tensorflow::gtl::optional<ChannelHandle>& channel_id =
+          tensorflow::gtl::nullopt);
+
   // Enqueues an operation that scatters the `source` array to the selected
   // indices of each window.
   XlaOp SelectAndScatter(const XlaOp& operand, const XlaComputation& select,
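
For client code, the new overload is driven by a scalar reduction built with a sub-builder, mirroring the rewritten default path in xla_builder.cc above. A minimal sketch, assuming an existing XlaBuilder b and XlaOp operand (these names are illustrative, not from the commit):

    // Build a commutative scalar reduction (add here; min or max also work).
    auto sub = b.CreateSubBuilder("sum");
    const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
    sub->Add(sub->Parameter(/*parameter_number=*/0, scalar_shape, "x"),
             sub->Parameter(/*parameter_number=*/1, scalar_shape, "y"));
    XlaComputation sum = sub->Build().ValueOrDie();

    // All-reduce across all replicas. At this revision a non-empty
    // replica_group_ids or a set channel_id returns Unimplemented.
    XlaOp result = b.CrossReplicaSum(operand, sum,
                                     /*replica_group_ids=*/{},
                                     /*channel_id=*/tensorflow::gtl::nullopt);
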
diff --git a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc
index 28e71c2054..7fd1e733e9 100644
--- a/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_conversion_folding_test.cc
@@ -211,6 +211,17 @@ TEST_F(BFloat16ConversionFoldingTest, DoNotFoldTuple) {
 
 TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) {
   auto builder = HloComputation::Builder(TestName());
+
+  auto module = CreateNewModule();
+  HloComputation::Builder sum_builder("add");
+  auto x = sum_builder.AddInstruction(HloInstruction::CreateParameter(
+      /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {}), "x"));
+  auto y = sum_builder.AddInstruction(HloInstruction::CreateParameter(
+      /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {}), "y"));
+  sum_builder.AddInstruction(HloInstruction::CreateBinary(
+      ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, x, y));
+  HloComputation* sum = module->AddEmbeddedComputation(sum_builder.Build());
+
   Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
   Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});
@@ -223,7 +234,8 @@ TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) {
   HloInstruction* crs =
       builder.AddInstruction(HloInstruction::CreateCrossReplicaSum(
-          ShapeUtil::MakeTupleShape({f32_shape, f32_shape}), {convert_a, b}));
+          ShapeUtil::MakeTupleShape({f32_shape, f32_shape}), {convert_a, b},
+          sum));
   HloInstruction* gte_a = builder.AddInstruction(
       HloInstruction::CreateGetTupleElement(f32_shape, crs, 0));
   HloInstruction* gte_b = builder.AddInstruction(
@@ -233,7 +245,6 @@ TEST_F(BFloat16ConversionFoldingTest, FoldCrossReplicaSumTupleOutput) {
   HloInstruction* tuple = builder.AddInstruction(
       HloInstruction::CreateTuple({gte_a, convert_gte_b}));
 
-  auto module = CreateNewModule();
   auto computation = module->AddEntryComputation(builder.Build());
 
   EXPECT_TRUE(FoldConversions(module.get()));
diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
index 1afaefd9df..9926661dd3 100644
--- a/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_normalization_test.cc
@@ -228,6 +228,17 @@ TEST_F(BFloat16NormalizationTest, ResolveUnsupportedMixedPrecisionReduce) {
 }
 
 TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) {
+  auto module = CreateNewModule();
+  HloComputation::Builder sum_builder("sum");
+  auto x = sum_builder.AddInstruction(HloInstruction::CreateParameter(
+      /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {}), "x"));
+  auto y = sum_builder.AddInstruction(HloInstruction::CreateParameter(
+      /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {}), "y"));
+  sum_builder.AddInstruction(HloInstruction::CreateBinary(
+      ShapeUtil::MakeShape(F32, {}), HloOpcode::kAdd, x, y));
+  HloComputation* reduction =
+      module->AddEmbeddedComputation(sum_builder.Build());
+
   auto builder = HloComputation::Builder(TestName());
 
   Shape f32_shape = ShapeUtil::MakeShape(F32, {2, 4});
   Shape bf16_shape = ShapeUtil::MakeShape(BF16, {2, 4});
@@ -239,11 +250,11 @@ TEST_F(BFloat16NormalizationTest, ResolveMixedPrecisionTupleCrossReplicaSum) {
   HloInstruction* crs =
       builder.AddInstruction(HloInstruction::CreateCrossReplicaSum(
-          ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b}));
+          ShapeUtil::MakeTupleShape({f32_shape, bf16_shape}), {a, b},
+          reduction));
   HloInstruction* gte = builder.AddInstruction(
       HloInstruction::CreateGetTupleElement(bf16_shape, crs, 1));
 
-  auto module = CreateNewModule();
   auto computation = module->AddEntryComputation(builder.Build());
 
   EXPECT_TRUE(Normalize(module.get()));
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index c0b8bf9039..682c386579 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -135,6 +135,7 @@ Status GatherComputationsByAllocationType(
             worklist.push_back(std::make_pair(subcomputation, false));  // Not thread local.
             break;
+          case HloOpcode::kCrossReplicaSum:
           case HloOpcode::kMap:
           case HloOpcode::kReduce:
           case HloOpcode::kReduceWindow:
diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc
index a8053d15e1..a23427f00c 100644
--- a/tensorflow/compiler/xla/service/call_graph.cc
+++ b/tensorflow/compiler/xla/service/call_graph.cc
@@ -57,6 +57,7 @@ CallContext GetInstructionCallContext(HloOpcode opcode) {
     case HloOpcode::kConditional:
     case HloOpcode::kWhile:
       return CallContext::kSequential;
+    case HloOpcode::kCrossReplicaSum:
     case HloOpcode::kMap:
     case HloOpcode::kReduce:
     case HloOpcode::kReduceWindow:
diff --git a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc
index abec29df43..4ed1508d70 100644
--- a/tensorflow/compiler/xla/service/hlo_element_type_converter.cc
+++ b/tensorflow/compiler/xla/service/hlo_element_type_converter.cc
@@ -141,6 +141,7 @@ StatusOr<bool> HloElementTypeConverter::Run(HloModule* module) {
     // These are ops with embedded computations where it suffices to convert
     // the embedded computations instead of converting the ops themselves.
     if (opcode == HloOpcode::kWhile || opcode == HloOpcode::kCall ||
+        opcode == HloOpcode::kCrossReplicaSum ||
         opcode == HloOpcode::kFusion || opcode == HloOpcode::kMap ||
         opcode == HloOpcode::kReduce || opcode == HloOpcode::kReduceWindow ||
         opcode == HloOpcode::kSelectAndScatter ||
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 1c276b9305..06775d6a9a 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -423,8 +423,20 @@ HloInstruction::CreateReducePrecision(const Shape& shape,
 
 /* static */ std::unique_ptr<HloInstruction>
 HloInstruction::CreateCrossReplicaSum(
-    const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
-  return CreateNary(shape, HloOpcode::kCrossReplicaSum, operands);
+    const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+    HloComputation* reduce_computation,
+    tensorflow::gtl::ArraySlice<int64> replica_group_ids,
+    const tensorflow::gtl::optional<int64>& channel_id) {
+  // TODO(b/79737069): Remove the CHECK when supported.
+  CHECK(replica_group_ids.empty());
+  CHECK(!channel_id.has_value());
+  auto instruction =
+      WrapUnique(new HloInstruction(HloOpcode::kCrossReplicaSum, shape));
+  for (auto operand : operands) {
+    instruction->AppendOperand(operand);
+  }
+  instruction->called_computations_.push_back(reduce_computation);
+  return instruction;
 }
 
 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateInfeed(
@@ -1374,7 +1386,7 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
       clone = CreateFft(shape, new_operands[0], fft_type_, fft_length_);
       break;
     case HloOpcode::kCrossReplicaSum:
-      clone = CreateCrossReplicaSum(shape, new_operands);
+      clone = CreateCrossReplicaSum(shape, new_operands, to_apply());
       break;
     case HloOpcode::kGetTupleElement:
       CHECK_EQ(new_operands.size(), 1);
@@ -1762,7 +1774,6 @@ bool HloInstruction::IdenticalSlowPath(
     case HloOpcode::kConvert:
     case HloOpcode::kCopy:
     case HloOpcode::kCos:
-    case HloOpcode::kCrossReplicaSum:
     case HloOpcode::kDivide:
     case HloOpcode::kDynamicSlice:
     case HloOpcode::kDynamicUpdateSlice:
@@ -1887,6 +1898,7 @@ bool HloInstruction::IdenticalSlowPath(
            slice_limits_ == other.slice_limits_ &&
            slice_strides_ == other.slice_strides_;
     case HloOpcode::kCall:
+    case HloOpcode::kCrossReplicaSum:
     case HloOpcode::kMap:
       return eq_computations(to_apply(), other.to_apply());
     case HloOpcode::kCustomCall:
@@ -2034,6 +2046,7 @@ HloComputation* HloInstruction::to_apply() const {
     case HloOpcode::kMap:
     case HloOpcode::kReduceWindow:
     case HloOpcode::kReduce:
+    case HloOpcode::kCrossReplicaSum:
       CHECK_EQ(called_computations_.size(), 1);
       return called_computations_[0];
     default:
@@ -2356,7 +2369,8 @@ std::vector<string> HloInstruction::ExtraAttributesToString(
         PrintName(false_computation()->name(), options)));
   } else if (opcode() == HloOpcode::kCall || opcode() == HloOpcode::kMap ||
              opcode() == HloOpcode::kReduceWindow ||
-             opcode() == HloOpcode::kReduce) {
+             opcode() == HloOpcode::kReduce ||
+             opcode() == HloOpcode::kCrossReplicaSum) {
     extra.push_back(
         StrCat("to_apply=", PrintName(to_apply()->name(), options)));
   } else if (!called_computations().empty()) {
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 905ea5310d..ef55c6668f 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -426,10 +426,26 @@ class HloInstruction {
       const Shape& shape, HloInstruction* operand, const int exponent_bits,
       const int mantissa_bits);
 
-  // Creates a cross replica sum op.
+  // Creates a cross replica reduction op.
+  //
+  // `reduce_computation`: the reduction function.
+  //
+  // `replica_group_ids`: maps replica ids to subgroup ids. If empty, all
+  // replicas belong to one group. AllReduce will be applied within subgroups.
+  // For example, with 4 replicas, replica_group_ids={0,1,0,1} means replicas
+  // 0 and 2 are in subgroup 0, and replicas 1 and 3 are in subgroup 1.
+  //
+  // `channel_id`: for AllReduce nodes from different models: if they have the
+  // same channel_id, they will be 'AllReduce'd together. If empty, AllReduce
+  // will not be applied across models.
+  //
+  // TODO(b/79737069): Rename this to AllReduce.
   static std::unique_ptr<HloInstruction> CreateCrossReplicaSum(
-      const Shape& shape,
-      tensorflow::gtl::ArraySlice<HloInstruction*> operands);
+      const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+      HloComputation* reduce_computation,
+      tensorflow::gtl::ArraySlice<int64> replica_group_ids = {},
+      const tensorflow::gtl::optional<int64>& channel_id =
+          tensorflow::gtl::nullopt);
 
   // Creates a conversion instruction, where operand is the data to convert and
   // shape is the target shape for the conversion.
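
At the HLO level, passes and tests must now supply the embedded reduction when building the op. A minimal sketch against the new CreateCrossReplicaSum signature, adapted from the bfloat16 test changes above (assumes an existing HloModule* module and HloInstruction* operand):

    // Build a scalar add computation and embed it in the module.
    const Shape scalar_shape = ShapeUtil::MakeShape(F32, {});
    HloComputation::Builder sum_builder("add");
    auto x = sum_builder.AddInstruction(HloInstruction::CreateParameter(
        /*parameter_number=*/0, scalar_shape, "x"));
    auto y = sum_builder.AddInstruction(HloInstruction::CreateParameter(
        /*parameter_number=*/1, scalar_shape, "y"));
    sum_builder.AddInstruction(HloInstruction::CreateBinary(
        scalar_shape, HloOpcode::kAdd, x, y));
    HloComputation* sum = module->AddEmbeddedComputation(sum_builder.Build());

    // Create the op; replica_group_ids and channel_id default to unset and
    // are CHECK-enforced to stay that way at this revision.
    auto crs = HloInstruction::CreateCrossReplicaSum(
        ShapeUtil::MakeShape(F32, {8}), {operand}, sum);
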
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index ec20606d2f..3eadedfe1f 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -587,11 +587,14 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
       break;
     }
     case HloOpcode::kCrossReplicaSum: {
+      optional<HloComputation*> to_apply;
+      attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
+                           &to_apply};
       if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
         return false;
       }
       instruction = builder->AddInstruction(
-          HloInstruction::CreateCrossReplicaSum(shape, operands));
+          HloInstruction::CreateCrossReplicaSum(shape, operands, *to_apply));
       break;
     }
     case HloOpcode::kReshape: {
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index 84a981675f..08068dc504 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -900,6 +900,24 @@ ENTRY Gather {
 
 )"
 },
+// cross-replica-sum
+{
+"CrossReplicaSum",
+R"(HloModule CRS
+
+add {
+  lhs = f32[] parameter(0)
+  rhs = f32[] parameter(1)
+  ROOT add = f32[] add(lhs, rhs)
+}
+
+ENTRY CRS {
+  input = f32[8]{0} parameter(0)
+  ROOT crs = f32[8]{0} cross-replica-sum(input), to_apply=add
+}
+
+)"
+},
 });
 // clang-format on
 }
diff --git a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc
index c960b3c15f..b151187c4b 100644
--- a/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc
+++ b/tensorflow/compiler/xla/tests/cross_replica_sum_test.cc
@@ -32,9 +32,16 @@ class TrivialCrossReplicaSumTest : public HloTestBase {};
 XLA_TEST_F(TrivialCrossReplicaSumTest, OneOperand) {
   const char* module_str = R"(
   HloModule test
+
+  add {
+    x = f32[] parameter(0)
+    y = f32[] parameter(1)
+    add = f32[] add(x, y)
+  }
+
   ENTRY test_computation {
     p = f32[3] parameter(0)
-    ROOT crs = f32[3] cross-replica-sum(p)
+    ROOT crs = f32[3] cross-replica-sum(p), to_apply=add
   })";
   auto module =
       ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie();
@@ -45,10 +52,17 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, OneOperand) {
 XLA_TEST_F(TrivialCrossReplicaSumTest, MultipleOperands) {
   const char* module_str = R"(
   HloModule test
+
+  add {
+    x = f32[] parameter(0)
+    y = f32[] parameter(1)
+    add = f32[] add(x, y)
+  }
+
   ENTRY test_computation {
     p0 = f32[3] parameter(0)
     p1 = f32[2] parameter(1)
-    ROOT crs = (f32[3], f32[2]) cross-replica-sum(p0, p1)
+    ROOT crs = (f32[3], f32[2]) cross-replica-sum(p0, p1), to_apply=add
   })";
   auto module =
       ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie();
@@ -65,10 +79,17 @@ XLA_TEST_F(TrivialCrossReplicaSumTest, MultipleOperands) {
 XLA_TEST_F(TrivialCrossReplicaSumTest, ConstantOperand) {
   const char* module_str = R"(
   HloModule test
+
+  add {
+    x = f32[] parameter(0)
+    y = f32[] parameter(1)
+    add = f32[] add(x, y)
+  }
+
   ENTRY test_computation {
     p0 = f32[3] parameter(0)
     p1 = f32[2] constant({10, 20})
-    ROOT crs = (f32[3], f32[2]) cross-replica-sum(p0, p1)
+    ROOT crs = (f32[3], f32[2]) cross-replica-sum(p0, p1), to_apply=add
   })";
   auto module =
       ParseHloString(module_str, GetModuleConfigForTest()).ValueOrDie();