aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
diff options
context:
space:
mode:
authorGravatar Mark Heffernan <meheff@google.com>2017-11-02 22:12:33 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-02 22:16:19 -0700
commit7bb2d57b0b051d1cf8dd74d3276bf5a452774172 (patch)
treed5b07beacebcc425454978eb87ffecfe728d4281 /tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
parent8a7f5c47dcb71deb71df4a72f3cf829904c5a28e (diff)
Rewrite CopyInsertion to use module-scoped HloAliasAnalysis. The net effect (number of copies inserted) is roughly similar to the existing implementation, but the new implementation is much more general. The new implementation can handle entry argument buffer reuse with minimal modification, for example.
Some unnecessary copies are still added due to deficiencies in buffer assignment (b/62548313), but these can be removed when buffer assignment also uses HloAliasAnalysis. Also address a few issues uncovered with this cl: (1) For inplace dynamic slice in llvm backends, truncate do not wrap the slice. This matches the behavior of the non-inplace variant. (2) Disable SelectBetweenPredTuples test on GPU. The test introduces top-level buffer ambiguity which is not tolerated by the gpu backend. (3) When deserializing HLO form a proto, do not uniquify instruction names in fused computations. (4) In dataflow analysis, don't deallocate deleted HloValues during propagation. (5) In dataflow analysis, fix issue with live_out_of_computation property. PiperOrigin-RevId: 174423881
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc64
1 files changed, 45 insertions, 19 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
index 92261bce62..2286cfe488 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
@@ -75,11 +75,41 @@ HloValue* HloDataflowAnalysis::NewHloValue(HloInstruction* instruction,
std::forward_as_tuple(value_id, instruction, index, is_phi));
CHECK(emplaced.second);
+ VLOG(4) << "NewHloValue = " << emplaced.first->second.ToShortString();
+
return &emplaced.first->second;
}
-void HloDataflowAnalysis::DeleteHloValue(HloValue::Id value_id) {
- values_.erase(value_id);
+void HloDataflowAnalysis::MarkValueForDeletion(HloValue::Id value_id) {
+ HloValue& value = values_.at(value_id);
+ VLOG(4) << "MarkValueForDeletion(" << value.ToShortString() << ")";
+
+ value_ids_to_delete_.push_back(value_id);
+}
+
+void HloDataflowAnalysis::DeleteMarkedValues() {
+ // Verify that no marked-for-deletion values are in any of the value sets.
+ tensorflow::gtl::FlatSet<HloValue::Id> id_set(value_ids_to_delete_.begin(),
+ value_ids_to_delete_.end());
+ for (const auto& pair : value_sets_) {
+ const HloInstruction* instruction = pair.first;
+ const InstructionValueSet& instruction_value_set = pair.second;
+ for (const auto& index_value_set : instruction_value_set) {
+ const HloValueSet& value_set = index_value_set.second;
+ for (const HloValue* value : value_set.values()) {
+ DCHECK(!ContainsKey(id_set, value->id()))
+ << "Value " << value->ToShortString()
+ << " marked for deletion, but still exists in value set for "
+ "instruction "
+ << instruction->name();
+ }
+ }
+ }
+
+ for (HloValue::Id value_id : value_ids_to_delete_) {
+ values_.erase(value_id);
+ }
+ value_ids_to_delete_.clear();
}
string HloDataflowAnalysis::ToString() const {
@@ -121,6 +151,7 @@ bool HloDataflowAnalysis::Phi(
HloInstruction* instruction,
tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs) {
CHECK(ssa_form_);
+ VLOG(4) << "Phi(" << instruction->name() << ")";
for (const InstructionValueSet* input : inputs) {
DCHECK(ShapeUtil::Compatible(instruction->shape(), input->shape()));
@@ -183,7 +214,7 @@ bool HloDataflowAnalysis::Phi(
} else if (current_value != &new_value) {
if (current_value_defined_here) {
// Remove the existing phi.
- DeleteHloValue(current_value->id());
+ MarkValueForDeletion(current_value->id());
}
value_set.Clear();
value_set.AddValue(&new_value);
@@ -193,7 +224,8 @@ bool HloDataflowAnalysis::Phi(
// Multiple distinct values reach this point. A phi value is
// necessary.
CHECK_GT(input_value_ids.size(), 1);
- if (current_value == nullptr || !current_value->is_phi()) {
+ if (current_value == nullptr ||
+ !(current_value->is_phi() && current_value_defined_here)) {
value_set.Clear();
value_set.AddValue(NewHloValue(instruction, index, /*is_phi=*/true));
changed = true;
@@ -436,11 +468,13 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet(
}
}
-void HloDataflowAnalysis::UpdateInstructionsAndPropagate(
- tensorflow::gtl::ArraySlice<HloInstruction*> instructions) {
+void HloDataflowAnalysis::Propagate() {
std::queue<HloInstruction*> worklist;
- for (HloInstruction* instruction : instructions) {
- worklist.push(instruction);
+
+ for (HloComputation* computation : module_->computations()) {
+ for (HloInstruction* instruction : computation->instructions()) {
+ worklist.push(instruction);
+ }
}
while (!worklist.empty()) {
@@ -597,18 +631,10 @@ StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
new HloDataflowAnalysis(module, ssa_form, bitcast_defines_value));
TF_RETURN_IF_ERROR(dataflow_analysis->InitializeInstructionValueSets());
+ dataflow_analysis->Propagate();
- // Construct list of all instructions to initialize the worklist to propagate
- // the data flow. For efficiency sort the instruction in post order so
- // producers appear before consumers.
- std::vector<HloInstruction*> all_instructions;
- for (const HloComputation* computation : module->MakeComputationPostOrder()) {
- for (HloInstruction* instruction :
- computation->MakeInstructionPostOrder()) {
- all_instructions.push_back(instruction);
- }
- }
- dataflow_analysis->UpdateInstructionsAndPropagate(all_instructions);
+ // Delete all values marked for deletion.
+ dataflow_analysis->DeleteMarkedValues();
// Add in positions to all values.
for (const HloComputation* computation : module->computations()) {