diff options
3 files changed, 69 insertions, 1 deletion
diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc index e2ca689c06..560910cc5f 100644 --- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc +++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc @@ -771,8 +771,14 @@ TEST_F(BFloat16PropagationTest, TupleDomain) { auto computation = module->AddEntryComputation(builder.Build()); EXPECT_TRUE(PropagatePrecision(module.get())); - EXPECT_EQ(computation->root_instruction(), root); + + // test BF16 propagated through domain + EXPECT_EQ(ShapeUtil::GetTupleElementShape(domain->shape(), 0).element_type(), + BF16); + EXPECT_EQ(ShapeUtil::GetTupleElementShape(domain->shape(), 1).element_type(), + BF16); + EXPECT_TRUE(OutputsBF16(a_trans)); EXPECT_TRUE(OutputsBF16(b_trans)); EXPECT_TRUE(OutputsBF16(a_gte)); @@ -781,4 +787,44 @@ TEST_F(BFloat16PropagationTest, TupleDomain) { EXPECT_FALSE(OutputsBF16(b)); } +// Tests that bf16 is not propagated through a domain in case its input cannot +// be propagated. In the case below the input of the domain is the parameter +// tuple which cannot be propagated, so the domain instruction is not propagated +// either. 
+TEST_F(BFloat16PropagationTest, TupleDomainNoPropagation) { + auto builder = HloComputation::Builder(TestName()); + Shape shape = ShapeUtil::MakeShape(F32, {4, 4}); + Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape}); + + HloInstruction* param = builder.AddInstruction( + HloInstruction::CreateParameter(0, tuple_shape, "param")); + HloInstruction* domain = builder.AddInstruction( + HloInstruction::CreateDomain(param->shape(), param, nullptr, nullptr)); + HloInstruction* a_gte = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, domain, 0)); + HloInstruction* b_gte = builder.AddInstruction( + HloInstruction::CreateGetTupleElement(shape, domain, 1)); + HloInstruction* a_trans = builder.AddInstruction( + HloInstruction::CreateTranspose(shape, a_gte, {0, 1})); + HloInstruction* b_trans = builder.AddInstruction( + HloInstruction::CreateTranspose(shape, b_gte, {0, 1})); + HloInstruction* dot = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kDot, a_trans, b_trans)); + HloInstruction* root = builder.AddInstruction( + HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot)); + + auto module = CreateNewModule(); + auto computation = module->AddEntryComputation(builder.Build()); + + EXPECT_TRUE(PropagatePrecision(module.get())); + + EXPECT_EQ(computation->root_instruction(), root); + EXPECT_TRUE(OutputsBF16(a_trans)); + EXPECT_TRUE(OutputsBF16(b_trans)); + EXPECT_FALSE(OutputsBF16(a_gte)); + EXPECT_FALSE(OutputsBF16(b_gte)); + EXPECT_FALSE(OutputsBF16(domain)); + EXPECT_FALSE(OutputsBF16(param)); +} + } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc index f529c0dad7..8a4a9b5986 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc @@ -466,6 +466,24 @@ bool HloDataflowAnalysis::UpdateCopyValueSet(HloInstruction* copy) { return changed; } 
+bool HloDataflowAnalysis::UpdateDomainValueSet(HloInstruction* domain) { + // Domain instructions just forward their operand. Given that domains can have + // a tuple operand, we iterate through its indexes, like for copies. + // Unlike copies though we also propagate the top-level value. + CHECK_EQ(domain->opcode(), HloOpcode::kDomain); + bool changed = false; + for (auto& pair : GetInstructionValueSet(domain)) { + const ShapeIndex& index = pair.first; + HloValueSet& value_set = pair.second; + HloValueSet& operand_value_set = GetValueSet(domain->operand(0), index); + if (value_set != operand_value_set) { + value_set = operand_value_set; + changed = true; + } + } + return changed; +} + bool HloDataflowAnalysis::UpdateGetTupleElementValueSet(HloInstruction* gte) { CHECK_EQ(gte->opcode(), HloOpcode::kGetTupleElement); bool changed = false; @@ -626,6 +644,8 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet( return UpdateBitcastValueSet(instruction); case HloOpcode::kSlice: return UpdateSliceValueSet(instruction); + case HloOpcode::kDomain: + return UpdateDomainValueSet(instruction); case HloOpcode::kCopy: return UpdateCopyValueSet(instruction); case HloOpcode::kGetTupleElement: @@ -804,6 +824,7 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() { case HloOpcode::kCall: case HloOpcode::kConditional: case HloOpcode::kGetTupleElement: + case HloOpcode::kDomain: // These instructions define no values. The values in their output // flow from their operands or from cross computation dataflow. 
break; diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h index 3d2d5baa77..9fea218af0 100644 --- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h @@ -185,6 +185,7 @@ class HloDataflowAnalysis { bool UpdateCallValueSet(HloInstruction* call); bool UpdateConditionalValueSet(HloInstruction* conditional); bool UpdateCopyValueSet(HloInstruction* copy); + bool UpdateDomainValueSet(HloInstruction* domain); bool UpdateGetTupleElementValueSet(HloInstruction* gte); bool UpdateParameterValueSet(HloInstruction* parameter); bool UpdateRecvDoneValueSet(HloInstruction* recv_done);