aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r-- tensorflow/compiler/xla/service/bfloat16_propagation_test.cc | 48
-rw-r--r-- tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc | 21
-rw-r--r-- tensorflow/compiler/xla/service/hlo_dataflow_analysis.h | 1
3 files changed, 69 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc
index e2ca689c06..560910cc5f 100644
--- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc
@@ -771,8 +771,14 @@ TEST_F(BFloat16PropagationTest, TupleDomain) {
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_TRUE(PropagatePrecision(module.get()));
-
EXPECT_EQ(computation->root_instruction(), root);
+
+ // Verify that BF16 was propagated through the domain to both tuple elements.
+ EXPECT_EQ(ShapeUtil::GetTupleElementShape(domain->shape(), 0).element_type(),
+ BF16);
+ EXPECT_EQ(ShapeUtil::GetTupleElementShape(domain->shape(), 1).element_type(),
+ BF16);
+
EXPECT_TRUE(OutputsBF16(a_trans));
EXPECT_TRUE(OutputsBF16(b_trans));
EXPECT_TRUE(OutputsBF16(a_gte));
@@ -781,4 +787,44 @@ TEST_F(BFloat16PropagationTest, TupleDomain) {
EXPECT_FALSE(OutputsBF16(b));
}
+// Tests that BF16 is not propagated through a domain whose operand cannot be
+// changed to BF16. Here the domain's operand is the parameter tuple, which is
+// expected to keep its precision, so the domain and the get-tuple-elements
+// must not output BF16 either; only the transposes feeding the dot should.
+TEST_F(BFloat16PropagationTest, TupleDomainNoPropagation) {
+ auto builder = HloComputation::Builder(TestName());
+ Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
+ Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});
+
+ HloInstruction* param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, tuple_shape, "param"));
+ HloInstruction* domain = builder.AddInstruction(
+ HloInstruction::CreateDomain(param->shape(), param, nullptr, nullptr));
+ HloInstruction* a_gte = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(shape, domain, 0));
+ HloInstruction* b_gte = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(shape, domain, 1));
+ HloInstruction* a_trans = builder.AddInstruction(
+ HloInstruction::CreateTranspose(shape, a_gte, {0, 1}));
+ HloInstruction* b_trans = builder.AddInstruction(
+ HloInstruction::CreateTranspose(shape, b_gte, {0, 1}));
+ HloInstruction* dot = builder.AddInstruction(
+ HloInstruction::CreateBinary(shape, HloOpcode::kDot, a_trans, b_trans));
+ HloInstruction* root = builder.AddInstruction(
+ HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot));
+
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ EXPECT_TRUE(PropagatePrecision(module.get()));
+
+ EXPECT_EQ(computation->root_instruction(), root);
+ EXPECT_TRUE(OutputsBF16(a_trans));
+ EXPECT_TRUE(OutputsBF16(b_trans));
+ EXPECT_FALSE(OutputsBF16(a_gte));
+ EXPECT_FALSE(OutputsBF16(b_gte));
+ EXPECT_FALSE(OutputsBF16(domain));
+ EXPECT_FALSE(OutputsBF16(param));
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
index f529c0dad7..8a4a9b5986 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
@@ -466,6 +466,24 @@ bool HloDataflowAnalysis::UpdateCopyValueSet(HloInstruction* copy) {
return changed;
}
+bool HloDataflowAnalysis::UpdateDomainValueSet(HloInstruction* domain) {
+ // A domain forwards operand 0 unchanged. Because that operand may be a tuple,
+ // every shape index is synced with the operand's value set at the same index;
+ // unlike copies, this includes the top-level index. Returns true on any change.
+ CHECK_EQ(domain->opcode(), HloOpcode::kDomain);
+ bool changed = false;
+ for (auto& pair : GetInstructionValueSet(domain)) {
+ const ShapeIndex& index = pair.first;
+ HloValueSet& value_set = pair.second;
+ HloValueSet& operand_value_set = GetValueSet(domain->operand(0), index);
+ if (value_set != operand_value_set) {
+ value_set = operand_value_set;
+ changed = true;
+ }
+ }
+ return changed;
+}
+
bool HloDataflowAnalysis::UpdateGetTupleElementValueSet(HloInstruction* gte) {
CHECK_EQ(gte->opcode(), HloOpcode::kGetTupleElement);
bool changed = false;
@@ -626,6 +644,8 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet(
return UpdateBitcastValueSet(instruction);
case HloOpcode::kSlice:
return UpdateSliceValueSet(instruction);
+ case HloOpcode::kDomain:
+ return UpdateDomainValueSet(instruction);
case HloOpcode::kCopy:
return UpdateCopyValueSet(instruction);
case HloOpcode::kGetTupleElement:
@@ -804,6 +824,7 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
case HloOpcode::kCall:
case HloOpcode::kConditional:
case HloOpcode::kGetTupleElement:
+ case HloOpcode::kDomain:
// These instructions define no values. The values in their output
// flow from their operands or from cross computation dataflow.
break;
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
index 3d2d5baa77..9fea218af0 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
@@ -185,6 +185,7 @@ class HloDataflowAnalysis {
bool UpdateCallValueSet(HloInstruction* call);
bool UpdateConditionalValueSet(HloInstruction* conditional);
bool UpdateCopyValueSet(HloInstruction* copy);
+ bool UpdateDomainValueSet(HloInstruction* domain);
bool UpdateGetTupleElementValueSet(HloInstruction* gte);
bool UpdateParameterValueSet(HloInstruction* parameter);
bool UpdateRecvDoneValueSet(HloInstruction* recv_done);