aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-28 02:11:38 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-28 02:14:14 -0700
commit99dc8c88c465490975eb6933383a7195a5cae9a9 (patch)
tree17d1f5def0f7a10b751e361b10be9ee1f0de707b
parentea5fa07713655b93733c9aafacc8ddc904484217 (diff)
[XLA] Handle domain instructions in dataflow analysis.
Without domain propagation in dataflow analysis we end up in inconsistent domain instructions with BF16 as output and F32 as input. In case of tuple shapes these are not fixed by bfloat16_normalization, and later on they cause asserts once the domain instructions are removed. PiperOrigin-RevId: 202442786
-rw-r--r--tensorflow/compiler/xla/service/bfloat16_propagation_test.cc48
-rw-r--r--tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc21
-rw-r--r--tensorflow/compiler/xla/service/hlo_dataflow_analysis.h1
3 files changed, 69 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc
index e2ca689c06..560910cc5f 100644
--- a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc
@@ -771,8 +771,14 @@ TEST_F(BFloat16PropagationTest, TupleDomain) {
auto computation = module->AddEntryComputation(builder.Build());
EXPECT_TRUE(PropagatePrecision(module.get()));
-
EXPECT_EQ(computation->root_instruction(), root);
+
+ // test BF16 propagated through domain
+ EXPECT_EQ(ShapeUtil::GetTupleElementShape(domain->shape(), 0).element_type(),
+ BF16);
+ EXPECT_EQ(ShapeUtil::GetTupleElementShape(domain->shape(), 1).element_type(),
+ BF16);
+
EXPECT_TRUE(OutputsBF16(a_trans));
EXPECT_TRUE(OutputsBF16(b_trans));
EXPECT_TRUE(OutputsBF16(a_gte));
@@ -781,4 +787,44 @@ TEST_F(BFloat16PropagationTest, TupleDomain) {
EXPECT_FALSE(OutputsBF16(b));
}
+// Tests that bf16 is not propagated through a domain in case its input cannot
+// be propagated. In the case below the input of the domain is the parameter
+// tuple which cannot be propagated, so the domain instruction is not propagated
+// either.
+TEST_F(BFloat16PropagationTest, TupleDomainNoPropagation) {
+ auto builder = HloComputation::Builder(TestName());
+ Shape shape = ShapeUtil::MakeShape(F32, {4, 4});
+ Shape tuple_shape = ShapeUtil::MakeTupleShape({shape, shape});
+
+ HloInstruction* param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, tuple_shape, "param"));
+ HloInstruction* domain = builder.AddInstruction(
+ HloInstruction::CreateDomain(param->shape(), param, nullptr, nullptr));
+ HloInstruction* a_gte = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(shape, domain, 0));
+ HloInstruction* b_gte = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(shape, domain, 1));
+ HloInstruction* a_trans = builder.AddInstruction(
+ HloInstruction::CreateTranspose(shape, a_gte, {0, 1}));
+ HloInstruction* b_trans = builder.AddInstruction(
+ HloInstruction::CreateTranspose(shape, b_gte, {0, 1}));
+ HloInstruction* dot = builder.AddInstruction(
+ HloInstruction::CreateBinary(shape, HloOpcode::kDot, a_trans, b_trans));
+ HloInstruction* root = builder.AddInstruction(
+ HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot));
+
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ EXPECT_TRUE(PropagatePrecision(module.get()));
+
+ EXPECT_EQ(computation->root_instruction(), root);
+ EXPECT_TRUE(OutputsBF16(a_trans));
+ EXPECT_TRUE(OutputsBF16(b_trans));
+ EXPECT_FALSE(OutputsBF16(a_gte));
+ EXPECT_FALSE(OutputsBF16(b_gte));
+ EXPECT_FALSE(OutputsBF16(domain));
+ EXPECT_FALSE(OutputsBF16(param));
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
index f529c0dad7..8a4a9b5986 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
@@ -466,6 +466,24 @@ bool HloDataflowAnalysis::UpdateCopyValueSet(HloInstruction* copy) {
return changed;
}
+bool HloDataflowAnalysis::UpdateDomainValueSet(HloInstruction* domain) {
+ // Domain instructions just forward their operand. Given that domains can have
+ // a tuple operand, we iterate through its indexes, like for copies.
+ // Unlike copies though we also propagate the top-level value.
+ CHECK_EQ(domain->opcode(), HloOpcode::kDomain);
+ bool changed = false;
+ for (auto& pair : GetInstructionValueSet(domain)) {
+ const ShapeIndex& index = pair.first;
+ HloValueSet& value_set = pair.second;
+ HloValueSet& operand_value_set = GetValueSet(domain->operand(0), index);
+ if (value_set != operand_value_set) {
+ value_set = operand_value_set;
+ changed = true;
+ }
+ }
+ return changed;
+}
+
bool HloDataflowAnalysis::UpdateGetTupleElementValueSet(HloInstruction* gte) {
CHECK_EQ(gte->opcode(), HloOpcode::kGetTupleElement);
bool changed = false;
@@ -626,6 +644,8 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet(
return UpdateBitcastValueSet(instruction);
case HloOpcode::kSlice:
return UpdateSliceValueSet(instruction);
+ case HloOpcode::kDomain:
+ return UpdateDomainValueSet(instruction);
case HloOpcode::kCopy:
return UpdateCopyValueSet(instruction);
case HloOpcode::kGetTupleElement:
@@ -804,6 +824,7 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
case HloOpcode::kCall:
case HloOpcode::kConditional:
case HloOpcode::kGetTupleElement:
+ case HloOpcode::kDomain:
// These instructions define no values. The values in their output
// flow from their operands or from cross computation dataflow.
break;
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
index 3d2d5baa77..9fea218af0 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
@@ -185,6 +185,7 @@ class HloDataflowAnalysis {
bool UpdateCallValueSet(HloInstruction* call);
bool UpdateConditionalValueSet(HloInstruction* conditional);
bool UpdateCopyValueSet(HloInstruction* copy);
+ bool UpdateDomainValueSet(HloInstruction* domain);
bool UpdateGetTupleElementValueSet(HloInstruction* gte);
bool UpdateParameterValueSet(HloInstruction* parameter);
bool UpdateRecvDoneValueSet(HloInstruction* recv_done);