aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-06-28 02:11:38 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-28 02:14:14 -0700
commit99dc8c88c465490975eb6933383a7195a5cae9a9 (patch)
tree17d1f5def0f7a10b751e361b10be9ee1f0de707b /tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
parentea5fa07713655b93733c9aafacc8ddc904484217 (diff)
[XLA] Handle domain instructions in dataflow analysis.
Without domain propagation in dataflow analysis we end up in inconsistent domain instructions with BF16 as output and F32 as input. In case of tuple shapes these are not fixed by bfloat16_normalization, and later on they cause asserts once the domain instructions are removed. PiperOrigin-RevId: 202442786
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc21
1 files changed, 21 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
index f529c0dad7..8a4a9b5986 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
@@ -466,6 +466,24 @@ bool HloDataflowAnalysis::UpdateCopyValueSet(HloInstruction* copy) {
return changed;
}
+bool HloDataflowAnalysis::UpdateDomainValueSet(HloInstruction* domain) {
+ // Domain instructions just forward their operand. Given that domains can have
+ // a tuple operand, we iterate through its indexes, like for copies.
+ // Unlike copies though we also propagate the top-level value.
+ CHECK_EQ(domain->opcode(), HloOpcode::kDomain);
+ bool changed = false;
+ for (auto& pair : GetInstructionValueSet(domain)) {
+ const ShapeIndex& index = pair.first;
+ HloValueSet& value_set = pair.second;
+ HloValueSet& operand_value_set = GetValueSet(domain->operand(0), index);
+ if (value_set != operand_value_set) {
+ value_set = operand_value_set;
+ changed = true;
+ }
+ }
+ return changed;
+}
+
bool HloDataflowAnalysis::UpdateGetTupleElementValueSet(HloInstruction* gte) {
CHECK_EQ(gte->opcode(), HloOpcode::kGetTupleElement);
bool changed = false;
@@ -626,6 +644,8 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet(
return UpdateBitcastValueSet(instruction);
case HloOpcode::kSlice:
return UpdateSliceValueSet(instruction);
+ case HloOpcode::kDomain:
+ return UpdateDomainValueSet(instruction);
case HloOpcode::kCopy:
return UpdateCopyValueSet(instruction);
case HloOpcode::kGetTupleElement:
@@ -804,6 +824,7 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
case HloOpcode::kCall:
case HloOpcode::kConditional:
case HloOpcode::kGetTupleElement:
+ case HloOpcode::kDomain:
// These instructions define no values. The values in their output
// flow from their operands or from cross computation dataflow.
break;