aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
diff options
context:
space:
mode:
author Mark Heffernan <meheff@google.com> 2017-08-09 15:01:40 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2017-08-09 15:04:53 -0700
commit 56633fe0cecba03929738df0a0788216f57cf8e9 (patch)
tree 7c0be8c004e12172e4e33f98eb1307a395454247 /tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
parent 83accbb3745b4019c21f59c3e6f9ab92250261ba (diff)
Make HloDataFlowAnalysis updatable after transforming the HLO graph.
Updating is possible if operands/uses or computation roots change in the graph. Updating is not possible if instructions are deleted or if new instructions are added. Specific changes: * Add verification methods for asserting invariants and checking the analysis after updating. * Always add phi values at while instructions. Previously these were added only if the phi had different inputs. The advantage of using phis unconditionally is that the set of values is fixed for a module. Updates due to changing operands/uses in the graph do not create new values. * Store values in a vector rather than a map. With unconditional phi values, the number of HloValues is fixed, so the values can be held in a vector with stable references to elements. PiperOrigin-RevId: 164778750
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc')
-rw-r--r-- tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc | 778
1 files changed, 487 insertions, 291 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
index 92548dfaf0..ea8b239e10 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
@@ -67,30 +67,6 @@ HloValue& HloDataflowAnalysis::GetValueDefinedAt(
return GetUniqueValueAt(instruction, index);
}
-HloValue* HloDataflowAnalysis::NewHloValue(HloInstruction* instruction,
- const ShapeIndex& index,
- bool is_phi) {
- const int64 value_id = next_value_id_++;
- auto emplaced = values_.emplace(
- std::piecewise_construct, std::forward_as_tuple(value_id),
- std::forward_as_tuple(value_id, instruction, index, is_phi));
- CHECK(emplaced.second);
-
- // Clear the vector of values as it is now stale. It will be lazily
- // reconstructed if needed when HloDataflowAnalysis::values() is called.
- values_vector_.clear();
-
- return &emplaced.first->second;
-}
-
-void HloDataflowAnalysis::DeleteHloValue(HloValue::Id value_id) {
- values_.erase(value_id);
-
- // Clear the vector of values as it is now stale. It will be lazily
- // reconstructed if needed when HloDataflowAnalysis::values() is called.
- values_vector_.clear();
-}
-
string HloDataflowAnalysis::ToString() const {
string out = StrCat("HloDataflowAnalysis, module ", module_->name(), "\n");
StrAppend(&out, " Instruction value sets:\n");
@@ -123,8 +99,18 @@ string HloDataflowAnalysis::ToString() const {
}
}
StrAppend(&out, " HloValues:\n");
- for (const auto& pair : values_) {
- StrAppend(&out, pair.second.ToString(/*indent=*/4));
+ for (const HloValue& value : values()) {
+ StrAppend(&out, value.ToString(/*indent=*/4));
+ }
+ StrAppend(&out, " Phi resolutions:\n");
+ for (const HloValue& value : values()) {
+ if (value.is_phi()) {
+ const HloValue* resolved_value = ResolvePhi(value);
+ StrAppend(&out, " ", value.ToShortString(), " => ",
+ resolved_value == nullptr ? "UNKNOWN"
+ : resolved_value->ToShortString(),
+ "\n");
+ }
}
return out;
}
@@ -147,253 +133,343 @@ HloValueSet& HloDataflowAnalysis::GetValueSet(const HloInstruction* instruction,
return *GetInstructionValueSet(instruction).mutable_element(index);
}
-const std::vector<const HloValue*>& HloDataflowAnalysis::values() const {
- if (values_vector_.empty()) {
- // Lazily construct vector of values.
- values_vector_.reserve(values_.size());
- for (auto& pair : values_) {
- values_vector_.push_back(&pair.second);
+const HloValueSet& HloDataflowAnalysis::GetValueSet(
+ const HloPosition& position) const {
+ return GetValueSet(position.instruction, position.index);
+}
+
+HloValueSet& HloDataflowAnalysis::GetValueSet(const HloPosition& position) {
+ return GetValueSet(position.instruction, position.index);
+}
+
+void HloDataflowAnalysis::UpdateAfterChangingOperand(
+ HloInstruction* instruction, HloInstruction* old_operand,
+ HloInstruction* new_operand) {
+ CHECK(std::find(instruction->operands().begin(),
+ instruction->operands().end(),
+ new_operand) != instruction->operands().end());
+ VLOG(1) << "UpdateAfterChangingOperand(" << instruction->name() << ", "
+ << old_operand->name() << " => " << new_operand->name() << ")";
+
+ std::vector<HloInstruction*> to_update = {instruction};
+
+ // If the instruction calls any computations then add the parameters of called
+ // computation to capture any changes to the dataflow into the subcomputation
+ // introduced by the new operand.
+ for (HloComputation* computation : instruction->called_computations()) {
+ to_update.insert(to_update.end(),
+ computation->parameter_instructions().begin(),
+ computation->parameter_instructions().end());
+ }
+
+ UpdateInstructionsAndPropagate(to_update);
+
+ // The uses of the values in the old and new operand may have changed. Uses of
+ // other HloValues are updated in UpdateInstructionsAndPropagate.
+ for (auto& pair : GetInstructionValueSet(old_operand)) {
+ for (const HloValue* value : pair.second.values()) {
+ GetValue(value->id()).RecomputeUses();
}
- std::sort(values_vector_.begin(), values_vector_.end(),
- HloValue::IdLessThan);
- } else {
- CHECK_EQ(values_vector_.size(), values_.size());
- for (const HloValue* value : values_vector_) {
- DCHECK(ContainsKey(values_, value->id()));
- DCHECK(&GetValue(value->id()) == value);
+ }
+ for (auto& pair : GetInstructionValueSet(new_operand)) {
+ for (const HloValue* value : pair.second.values()) {
+ GetValue(value->id()).RecomputeUses();
}
}
- return values_vector_;
+
+ TF_DCHECK_OK(VerifyAgainstReference());
}
-/* static */
-InstructionValueSet HloDataflowAnalysis::Phi(
- HloInstruction* instruction,
- tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs,
- bool skip_top_level) {
- CHECK(ssa_form_);
+void HloDataflowAnalysis::UpdateAfterChangingRoot(HloInstruction* old_root,
+ HloInstruction* new_root) {
+ VLOG(1) << "UpdateAfterChangingRoot(" << old_root->name() << " => "
+ << new_root->name() << ")";
- for (const InstructionValueSet* input : inputs) {
- CHECK(ShapeUtil::Compatible(instruction->shape(), input->shape()));
- }
- InstructionValueSet new_value_set(instruction->shape());
- new_value_set.ForEachMutableElement(
- [this, instruction, &inputs, skip_top_level](const ShapeIndex& index,
- HloValueSet* value_set) {
- // If we're skipping the top level, just copy over the existing
- // HloValueSet.
- if (skip_top_level && index.empty()) {
- *value_set = GetInstructionValueSet(instruction).element(index);
- return;
- }
+ CHECK_EQ(new_root, new_root->parent()->root_instruction());
+ CHECK_EQ(new_root->parent(), old_root->parent());
- // Identify the existing phi value at this index if it exists.
- const HloValue* existing_phi_value = nullptr;
- if (ValueIsDefinedAt(instruction, index) &&
- GetUniqueValueAt(instruction, index).is_phi()) {
- existing_phi_value = &GetUniqueValueAt(instruction, index);
- }
+ std::vector<HloInstruction*> to_update = {old_root, new_root};
- // Construct a vector of unique value IDs of the inputs.
- std::vector<HloValue::Id> input_value_ids;
- for (const InstructionValueSet* input : inputs) {
- for (const HloValue* value : input->element(index).values()) {
- input_value_ids.push_back(value->id());
- }
- }
- std::sort(input_value_ids.begin(), input_value_ids.end());
- input_value_ids.erase(
- std::unique(input_value_ids.begin(), input_value_ids.end()),
- input_value_ids.end());
-
- // Remove the existing phi value (if it exists). The phi can be its own
- // input, for example, in while body parameters where the body passes
- // through the parameter value.
- if (existing_phi_value != nullptr) {
- auto it = std::find(input_value_ids.begin(), input_value_ids.end(),
- existing_phi_value->id());
- if (it != input_value_ids.end()) {
- input_value_ids.erase(it);
- }
- }
+ const CallGraphNode& call_graph_node =
+ call_graph_->GetNode(new_root->parent());
+ for (const CallSite& callsite : call_graph_node.caller_callsites()) {
+ if (callsite.instruction()->opcode() == HloOpcode::kCall) {
+ to_update.push_back(callsite.instruction());
+ } else if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
+ // Add the while itself, and the body and condition parameters.
+ to_update.push_back(callsite.instruction());
+ to_update.push_back(
+ callsite.instruction()->while_body()->parameter_instruction(0));
+ to_update.push_back(
+ callsite.instruction()->while_condition()->parameter_instruction(0));
+ }
+ }
- if (input_value_ids.size() <= 1) {
- if (input_value_ids.size() == 1) {
- *value_set = HloValueSet({&GetValue(input_value_ids[0])});
- }
- if (existing_phi_value) {
- // The merge point does not have multiple distinct inputs (which are
- // not the phi value itself). Therefore there is no need to insert a
- // phi value because there is a single reaching definition (or no
- // reaching definition).
- DeleteHloValue(existing_phi_value->id());
- }
- } else if (input_value_ids.size() > 1) {
- // Multiple distinct values reach this point. A phi value is
- // necessary.
- if (existing_phi_value) {
- // A phi value already exists so reuse it in the new
- // InstructionValueSet.
- *value_set = HloValueSet({existing_phi_value});
- } else {
- // Create a new phi value.
- *value_set =
- HloValueSet({NewHloValue(instruction, index, /*is_phi=*/true)});
- }
- }
- });
- return new_value_set;
-}
-
-void HloDataflowAnalysis::UpdatePositionsOfValuesAt(
- HloInstruction* instruction, const InstructionValueSet& new_value_set,
- const InstructionValueSet* prev_value_set) {
- if (prev_value_set != nullptr) {
- // Remove positions from the old value set.
- prev_value_set->ForEachElement(
- [this, instruction](const ShapeIndex& index,
- const HloValueSet& value_set) {
- for (const HloValue* value : value_set.values()) {
- // HloValues in the previous value set may have been deleted.
- if (!ContainsKey(values_, value->id())) {
- continue;
- }
- // Don't remove the defining position of the value.
- if (instruction == value->defining_instruction()) {
- CHECK_EQ(index, value->defining_index());
- } else {
- GetValue(value->id()).RemovePosition(instruction, index);
- }
- }
- });
+ UpdateInstructionsAndPropagate(to_update);
+
+ TF_DCHECK_OK(VerifyAgainstReference());
+}
+
+const HloValue* HloDataflowAnalysis::ResolvePhi(const HloValue& phi) const {
+ CHECK(phi.is_phi());
+
+ tensorflow::gtl::FlatSet<const HloValue*> visited;
+ std::queue<const HloValue*> worklist;
+ auto add_to_worklist = [&worklist, &visited](const HloValue* v) {
+ if (visited.insert(v).second) {
+ // 'v' was not previously in visited.
+ worklist.push(v);
+ }
+ };
+ add_to_worklist(&phi);
+
+ const HloValue* resolved_value = nullptr;
+ while (!worklist.empty()) {
+ const HloValue* value = worklist.front();
+ worklist.pop();
+
+ if (!value->is_phi()) {
+ if (resolved_value == nullptr) {
+ resolved_value = value;
+ } else if (resolved_value != value) {
+ return nullptr;
+ }
+ } else {
+ for (const HloValue* input : phi_inputs_.at(value)) {
+ add_to_worklist(input);
+ }
+ }
}
- // Add positions in the new value set.
- new_value_set.ForEachElement(
- [this, instruction](const ShapeIndex& index,
- const HloValueSet& value_set) {
- for (const HloValue* value : value_set.values()) {
- if (instruction == value->defining_instruction()) {
- CHECK_EQ(index, value->defining_index());
- } else {
- GetValue(value->id()).AddPosition(instruction, index);
- }
+ return resolved_value;
+}
+
+void HloDataflowAnalysis::UpdatePhiInputs(
+ const HloInstruction* instruction,
+ tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs) {
+ CHECK(ssa_form_);
+ for (auto& pair : GetInstructionValueSet(instruction)) {
+ const ShapeIndex& index = pair.first;
+ const HloValue& phi_value = GetUniqueValueAt(instruction, index);
+ auto& phi_inputs = phi_inputs_.at(&phi_value);
+ phi_inputs.clear();
+ for (const InstructionValueSet* input : inputs) {
+ for (const HloValue* value : input->element(index).values()) {
+ // The number of phi inputs is typically 2, and virtually always very
+ // small.
+ if (std::find(phi_inputs.begin(), phi_inputs.end(), value) ==
+ phi_inputs.end()) {
+ phi_inputs.push_back(value);
}
- });
+ }
+ }
+ }
}
-InstructionValueSet HloDataflowAnalysis::RecomputeBitcastValueSet(
- HloInstruction* bitcast) {
+bool HloDataflowAnalysis::UpdateBitcastValueSet(HloInstruction* bitcast) {
CHECK_EQ(bitcast->opcode(), HloOpcode::kBitcast);
- if (bitcast_defines_value_) {
- return GetInstructionValueSet(bitcast);
- } else {
- return GetInstructionValueSet(bitcast->operand(0));
+ const InstructionValueSet& operand_set =
+ GetInstructionValueSet(bitcast->operand(0));
+ InstructionValueSet& bitcast_set = GetInstructionValueSet(bitcast);
+ if (!bitcast_defines_value_ && operand_set != bitcast_set) {
+ bitcast_set = operand_set;
+ return true;
}
+ return false;
}
-InstructionValueSet HloDataflowAnalysis::RecomputeCopyValueSet(
- HloInstruction* copy) {
+bool HloDataflowAnalysis::UpdateCallValueSet(HloInstruction* call) {
+ CHECK_EQ(call->opcode(), HloOpcode::kCall);
+ InstructionValueSet& value_set = GetInstructionValueSet(call);
+ InstructionValueSet& root_value_set =
+ GetInstructionValueSet(call->to_apply()->root_instruction());
+ if (value_set != root_value_set) {
+ value_set = root_value_set;
+ return true;
+ }
+ return false;
+}
+
+bool HloDataflowAnalysis::UpdateCopyValueSet(HloInstruction* copy) {
CHECK_EQ(copy->opcode(), HloOpcode::kCopy);
- InstructionValueSet new_value_set = GetInstructionValueSet(copy);
- if (ShapeUtil::IsTuple(copy->shape())) {
- for (int i = 0; i < ShapeUtil::TupleElementCount(copy->shape()); ++i) {
- new_value_set.CopySubtreeFrom(GetInstructionValueSet(copy->operand(0)),
- /*source_base_index=*/{i},
- /*target_base_index=*/{i});
+ bool changed = false;
+ for (auto& pair : GetInstructionValueSet(copy)) {
+ const ShapeIndex& index = pair.first;
+ if (index.empty()) {
+ // kCopy shallow copies and thus defines the top-level value so nothing to
+ // update.
+ continue;
+ }
+
+ HloValueSet& value_set = pair.second;
+ HloValueSet& operand_value_set = GetValueSet(copy->operand(0), index);
+ if (value_set != operand_value_set) {
+ value_set = operand_value_set;
+ changed = true;
}
}
- return new_value_set;
+ return changed;
}
-InstructionValueSet HloDataflowAnalysis::RecomputeGetTupleElementValueSet(
- HloInstruction* gte) {
+bool HloDataflowAnalysis::UpdateGetTupleElementValueSet(HloInstruction* gte) {
CHECK_EQ(gte->opcode(), HloOpcode::kGetTupleElement);
- InstructionValueSet new_value_set(gte->shape());
- new_value_set.CopySubtreeFrom(GetInstructionValueSet(gte->operand(0)),
- /*source_base_index=*/{gte->tuple_index()},
- /*target_base_index=*/{});
- return new_value_set;
+ bool changed = false;
+ // The GetTupleElement instruction forwards the values from the specified
+ // tuple element.
+ for (auto& pair : GetInstructionValueSet(gte)) {
+ const ShapeIndex& index = pair.first;
+ HloValueSet& value_set = pair.second;
+
+ // The corresponding ShapeIndex of the operand is simply the GTE ShapeIndex
+ // with the tuple element number prefixed.
+ ShapeIndex operand_index = {gte->tuple_index()};
+ for (int64 i : index) {
+ operand_index.push_back(i);
+ }
+
+ HloValueSet& operand_value_set =
+ GetValueSet(gte->operand(0), operand_index);
+ if (value_set != operand_value_set) {
+ value_set = operand_value_set;
+ changed = true;
+ }
+ }
+ return changed;
}
-InstructionValueSet HloDataflowAnalysis::RecomputeSelectValueSet(
- HloInstruction* select) {
+bool HloDataflowAnalysis::UpdateParameterValueSet(HloInstruction* parameter) {
+ CHECK_EQ(parameter->opcode(), HloOpcode::kParameter);
+ const CallGraphNode& call_graph_node =
+ call_graph_->GetNode(parameter->parent());
+
+ // Subcomputations called in a parallel context (eg, map) do not have dataflow
+ // from the caller operands.
+ if (call_graph_node.context() == CallContext::kParallel ||
+ call_graph_node.caller_callsites().empty()) {
+ return false;
+ }
+ CHECK_EQ(call_graph_node.context(), CallContext::kSequential);
+
+ std::vector<const InstructionValueSet*> inputs;
+ bool called_from_while = false;
+ for (const CallSite& callsite : call_graph_node.caller_callsites()) {
+ if (callsite.instruction()->opcode() == HloOpcode::kCall) {
+ // The operand values of a call instruction are forwarded to the
+ // respective parameter instruction of the subcomputation.
+ inputs.push_back(&GetInstructionValueSet(
+ callsite.instruction()->operand(parameter->parameter_number())));
+ } else if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
+ // In a while instruction, the while operand (ie, the init value) and the
+ // backedge are dataflow inputs to the parameter instruction. This is the
+ // case for parameters of both the body and condition computations.
+ CHECK_EQ(parameter->parameter_number(), 0);
+ inputs.push_back(
+ &GetInstructionValueSet(callsite.instruction()->operand(0)));
+ // If the parameter *is* the root, then don't consider its current state
+ // (InstructionValueSet) as we are recomputing its current
+ // state. Otherwise, the parameter state would never be updated.
+ if (parameter !=
+ callsite.instruction()->while_body()->root_instruction()) {
+ inputs.push_back(&GetInstructionValueSet(
+ callsite.instruction()->while_body()->root_instruction()));
+ }
+ called_from_while = true;
+ } else {
+ LOG(FATAL) << "CallContext::kSequential computations should only be "
+ "called from call or while instructions";
+ }
+ }
+
+ if (ssa_form_ && called_from_while) {
+ UpdatePhiInputs(parameter, inputs);
+ return false;
+ } else {
+ return GetInstructionValueSet(parameter).AssignUnionOf(inputs);
+ }
+}
+
+bool HloDataflowAnalysis::UpdateSelectValueSet(HloInstruction* select) {
CHECK_EQ(select->opcode(), HloOpcode::kSelect);
- std::vector<const InstructionValueSet*> inputs = {
- &GetInstructionValueSet(select->operand(1)),
- &GetInstructionValueSet(select->operand(2))};
// A phi value is not defined at a kSelect instruction because kSelect does
// not create a new value. Rather it forwards a value from its operands. This
// contrasts with kWhile instruction (which does define a phi value) which has
// in-place update semantics.
- InstructionValueSet new_value_set = InstructionValueSet::Union(inputs);
- *new_value_set.mutable_element(/*index=*/{}) =
- GetInstructionValueSet(select).element(/*index=*/{});
- return new_value_set;
+ bool changed = false;
+ for (auto& pair : GetInstructionValueSet(select)) {
+ const ShapeIndex& index = pair.first;
+ if (index.empty()) {
+ // kSelect copies (not forwards) the top-level value.
+ continue;
+ }
+ HloValueSet& value_set = pair.second;
+ changed |=
+ value_set.AssignUnionOf({&GetValueSet(select->operand(1), index),
+ &GetValueSet(select->operand(2), index)});
+ }
+ return changed;
}
-InstructionValueSet HloDataflowAnalysis::RecomputeTupleValueSet(
- HloInstruction* tuple) {
+bool HloDataflowAnalysis::UpdateTupleValueSet(HloInstruction* tuple) {
CHECK_EQ(tuple->opcode(), HloOpcode::kTuple);
- InstructionValueSet new_value_set(tuple->shape());
- *new_value_set.mutable_element(/*index=*/{}) =
- GetInstructionValueSet(tuple).element(/*index=*/{});
+ bool changed = false;
for (int64 i = 0; i < tuple->operands().size(); ++i) {
- new_value_set.CopySubtreeFrom(GetInstructionValueSet(tuple->operand(i)),
- /*source_base_index=*/{},
- /*target_base_index=*/{i});
+ // Copy the value set(s) of each operand into the respective position in the
+ // kTuple instruction's value sets.
+ for (auto& pair : GetInstructionValueSet(tuple->operand(i))) {
+ const ShapeIndex& operand_index = pair.first;
+ HloValueSet& operand_value_set = pair.second;
+
+ ShapeIndex index = {i};
+ for (int64 op_index : operand_index) {
+ index.push_back(op_index);
+ }
+ HloValueSet& value_set = GetValueSet(tuple, index);
+
+ if (value_set != operand_value_set) {
+ value_set = operand_value_set;
+ changed = true;
+ }
+ }
}
- return new_value_set;
+ return changed;
}
-InstructionValueSet HloDataflowAnalysis::RecomputeWhileValueSet(
- HloInstruction* xla_while) {
+bool HloDataflowAnalysis::UpdateWhileValueSet(HloInstruction* xla_while) {
CHECK_EQ(xla_while->opcode(), HloOpcode::kWhile);
std::vector<const InstructionValueSet*> inputs = {
&GetInstructionValueSet(xla_while->while_body()->root_instruction()),
&GetInstructionValueSet(xla_while->operand(0))};
if (ssa_form_) {
- return Phi(xla_while, inputs);
+ UpdatePhiInputs(xla_while, inputs);
+ return false;
} else {
- return InstructionValueSet::Union(inputs);
+ return GetInstructionValueSet(xla_while).AssignUnionOf(inputs);
}
}
-void HloDataflowAnalysis::UpdateInstructionValueSet(
+bool HloDataflowAnalysis::UpdateInstructionValueSet(
HloInstruction* instruction) {
// Recompute from operands.
- InstructionValueSet& value_set = GetInstructionValueSet(instruction);
switch (instruction->opcode()) {
case HloOpcode::kBitcast:
- value_set = RecomputeBitcastValueSet(instruction);
- break;
+ return UpdateBitcastValueSet(instruction);
case HloOpcode::kCopy:
- value_set = RecomputeCopyValueSet(instruction);
- break;
+ return UpdateCopyValueSet(instruction);
case HloOpcode::kGetTupleElement:
- value_set = RecomputeGetTupleElementValueSet(instruction);
- break;
+ return UpdateGetTupleElementValueSet(instruction);
case HloOpcode::kSelect:
- value_set = RecomputeSelectValueSet(instruction);
- break;
+ return UpdateSelectValueSet(instruction);
case HloOpcode::kTuple:
- value_set = RecomputeTupleValueSet(instruction);
- break;
+ return UpdateTupleValueSet(instruction);
case HloOpcode::kParameter:
- value_set = RecomputeParameterValueSet(instruction);
- break;
+ return UpdateParameterValueSet(instruction);
case HloOpcode::kCall:
- // The output of a kCall instruction is exactly the output of the root of
- // the subcomputation.
- value_set =
- GetInstructionValueSet(instruction->to_apply()->root_instruction());
- break;
+ return UpdateCallValueSet(instruction);
case HloOpcode::kWhile:
- value_set = RecomputeWhileValueSet(instruction);
- break;
+ return UpdateWhileValueSet(instruction);
default:
// Instruction does not forward HloValues (it defines all values in its
// output). No update is necessary.
- return;
+ return false;
}
}
@@ -411,11 +487,38 @@ void HloDataflowAnalysis::UpdateInstructionsAndPropagate(
VLOG(3) << "Worklist top: " << instruction->name();
VLOG(3) << ToString();
- // Save old value for recomputing uses and live out.
- InstructionValueSet old_value = GetInstructionValueSet(instruction);
- UpdateInstructionValueSet(instruction);
+ // The updating of the instruction value set below in
+ // UpdateInstructionValueSet does not update HloValue::positions(). To
+ // perform the positions() update remove all positions in 'instruction' from
+ // the HloValues in 'instruction's value set prior to the update, then after
+ // the update add the new positions back in. There is likely a more
+ // efficient way of doing this.
+ for (auto& pair : GetInstructionValueSet(instruction)) {
+ const ShapeIndex& index = pair.first;
+ HloValueSet& value_set = pair.second;
+ for (const HloValue* value : value_set.values()) {
+ if (value->defining_instruction() != instruction) {
+ // Use GetValue for a non-const HloValue reference.
+ GetValue(value->id()).RemovePosition(instruction, index);
+ }
+ }
+ }
+
+ bool changed = UpdateInstructionValueSet(instruction);
+
+ // Add the positions back in.
+ for (auto& pair : GetInstructionValueSet(instruction)) {
+ const ShapeIndex& index = pair.first;
+ HloValueSet& value_set = pair.second;
+ for (const HloValue* value : value_set.values()) {
+ if (value->defining_instruction() != instruction) {
+ // Use GetValue for a non-const HloValue reference.
+ GetValue(value->id()).AddPosition(instruction, index);
+ }
+ }
+ }
- if (GetInstructionValueSet(instruction) == old_value) {
+ if (!changed) {
// No change to the instruction's value set.
VLOG(4) << "No change.";
continue;
@@ -423,7 +526,6 @@ void HloDataflowAnalysis::UpdateInstructionsAndPropagate(
VLOG(4) << "New value set for " << instruction->name() << ": "
<< GetInstructionValueSet(instruction);
- VLOG(4) << "Previously: " << old_value;
// Instruction value was updated. Add users to work list.
for (HloInstruction* user : instruction->users()) {
@@ -458,57 +560,6 @@ void HloDataflowAnalysis::UpdateInstructionsAndPropagate(
}
}
}
-
- // Update uses. First clear all of the old uses at the particular
- // operands. Then add the new uses. There may be overlap between the old
- // uses and new uses.
- UpdatePositionsOfValuesAt(instruction, GetInstructionValueSet(instruction),
- &old_value);
- }
-}
-
-InstructionValueSet HloDataflowAnalysis::RecomputeParameterValueSet(
- HloInstruction* parameter) {
- CHECK_EQ(parameter->opcode(), HloOpcode::kParameter);
- const CallGraphNode& call_graph_node =
- call_graph_->GetNode(parameter->parent());
-
- // Subcomputations called in a parallel context (eg, map) do not have dataflow
- // from the caller operands.
- if (call_graph_node.context() == CallContext::kParallel ||
- call_graph_node.caller_callsites().empty()) {
- return GetInstructionValueSet(parameter);
- }
- CHECK_EQ(call_graph_node.context(), CallContext::kSequential);
-
- std::vector<const InstructionValueSet*> inputs;
- bool called_from_while = false;
- for (const CallSite& callsite : call_graph_node.caller_callsites()) {
- if (callsite.instruction()->opcode() == HloOpcode::kCall) {
- // The operand values of a call instruction are forwarded to the
- // respective parameter instruction of the subcomputation.
- inputs.push_back(&GetInstructionValueSet(
- callsite.instruction()->operand(parameter->parameter_number())));
- } else if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
- // In a while instruction, the while operand (ie, the init value) and the
- // backedge are dataflow inputs to the parameter instruction. This is the
- // case for parameters of both the body and condition computations.
- CHECK_EQ(parameter->parameter_number(), 0);
- inputs.push_back(
- &GetInstructionValueSet(callsite.instruction()->operand(0)));
- inputs.push_back(&GetInstructionValueSet(
- callsite.instruction()->while_body()->root_instruction()));
- called_from_while = true;
- } else {
- LOG(FATAL) << "CallContext::kSequential computations should only be "
- "called from call or while instructions";
- }
- }
-
- if (ssa_form_ && called_from_while) {
- return Phi(parameter, inputs);
- } else {
- return InstructionValueSet::Union(inputs);
}
}
@@ -523,10 +574,26 @@ InstructionValueSet& HloDataflowAnalysis::GetInstructionValueSet(
}
Status HloDataflowAnalysis::InitializeInstructionValueSets() {
+ // Gather the values to create before creating them. This is done because we
+ // want to allocate the vector of values only once so references to elements
+ // are stable.
+ struct ValueToCreate {
+ HloInstruction* instruction;
+ ShapeIndex index;
+ bool is_phi;
+ };
+ std::vector<ValueToCreate> values_to_create;
+
for (const std::unique_ptr<HloComputation>& computation :
module_->computations()) {
const CallGraphNode& call_graph_node =
call_graph_->GetNode(computation.get());
+ bool called_from_while = std::any_of(
+ call_graph_node.caller_callsites().begin(),
+ call_graph_node.caller_callsites().end(), [](const CallSite& cs) {
+ return cs.instruction()->opcode() == HloOpcode::kWhile;
+ });
+
for (const std::unique_ptr<HloInstruction>& instruction :
computation->instructions()) {
// Create an empty shape tree.
@@ -536,21 +603,20 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
// Lambda to set the value set to define all values in the output of the
// instruction.
- auto define_all_values = [this, &instruction]() {
- GetInstructionValueSet(instruction.get())
- .ForEachMutableElement([this, &instruction](
- const ShapeIndex& index,
- HloValueSet* value_set) {
- *value_set = HloValueSet({NewHloValue(instruction.get(), index)});
- });
+ auto define_all_values = [this, &instruction,
+ &values_to_create](bool is_phi = false) {
+ for (auto& pair : GetInstructionValueSet(instruction.get())) {
+ const ShapeIndex& index = pair.first;
+ values_to_create.push_back({instruction.get(), index, is_phi});
+ }
};
// Lambda to set the value set to define only the top-level buffer in the
// output of the instruction. Any other values flow from the operands of
// the instruction (or from cross-computation dataflow).
- auto define_top_level_only = [this, &instruction]() {
- GetValueSet(instruction.get(), /*index=*/{}) =
- HloValueSet({NewHloValue(instruction.get(), /*index=*/{})});
+ auto define_top_level_only = [this, &instruction, &values_to_create]() {
+ values_to_create.push_back(
+ {instruction.get(), /*index=*/{}, /*is_phi=*/false});
};
switch (instruction->opcode()) {
@@ -559,21 +625,18 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
define_all_values();
}
break;
- case HloOpcode::kCall:
case HloOpcode::kWhile:
+ if (ssa_form_) {
+ define_all_values(/*is_phi=*/true);
+ }
+ break;
+ case HloOpcode::kCall:
case HloOpcode::kGetTupleElement:
// These instructions define no values. The values in their output
// flow from their operands or from cross computation dataflow.
break;
case HloOpcode::kParameter:
- if (call_graph_node.caller_callsites().empty() ||
- call_graph_node.context() == CallContext::kParallel) {
- // Parameters of computations called in a parallel context (eg, map
- // and reduce) as well as parameters of dead computations define all
- // values in their output. Otherwise the values of the parameter
- // come from the caller (eg, operands to the kCall instruction).
- define_all_values();
- } else if (call_graph_node.context() == CallContext::kBoth) {
+ if (call_graph_node.context() == CallContext::kBoth) {
// We do not support a subcomputation that is called from both a
// parallel and sequential context. In this case, the parameter
// would both define a value and propagate a value from its
@@ -584,6 +647,18 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
"sequential (eg, kCall) context",
computation->name().c_str());
}
+ if (call_graph_node.caller_callsites().empty() ||
+ call_graph_node.context() == CallContext::kParallel) {
+ // Parameters of computations called in a parallel context (eg, map
+ // and reduce) as well as parameters of dead computations define all
+ // values in their output. Otherwise the values of the parameter
+ // come from the caller (eg, operands to the kCall instruction).
+ define_all_values();
+ } else if (call_graph_node.context() == CallContext::kSequential &&
+ called_from_while && ssa_form_) {
+ // Parameters of while bodies and conditions are phis.
+ define_all_values(/*is_phi=*/true);
+ }
break;
case HloOpcode::kCopy:
case HloOpcode::kSelect:
@@ -596,8 +671,19 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
define_all_values();
break;
}
- UpdatePositionsOfValuesAt(instruction.get(),
- GetInstructionValueSet(instruction.get()));
+ }
+ }
+
+ // Reserve the vector ahead of time so references to elements are stable.
+ values_.reserve(values_to_create.size());
+ for (int64 i = 0; i < values_to_create.size(); ++i) {
+ const ValueToCreate& to_create = values_to_create[i];
+ values_.emplace_back(/*id=*/i, to_create.instruction, to_create.index,
+ to_create.is_phi);
+ const HloValue& value = values_.back();
+ GetValueSet(to_create.instruction, to_create.index).AddValue(&value);
+ if (value.is_phi()) {
+ phi_inputs_[&value] = {};
}
}
return Status::OK();
@@ -769,8 +855,118 @@ StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
}
dataflow_analysis->UpdateInstructionsAndPropagate(all_instructions);
- VLOG(1) << dataflow_analysis->ToString();
+ TF_DCHECK_OK(dataflow_analysis->Verify());
+
+ XLA_VLOG_LINES(1, dataflow_analysis->ToString());
+
return std::move(dataflow_analysis);
}
+Status HloDataflowAnalysis::Verify() const {
+ // Checks that positions() and value sets agree in both directions. First,
+ // every position in a value's positions() must hold that value (by pointer).
+ for (const HloValue& value : values()) {
+ for (const HloPosition& position : value.positions()) {
+ const HloValueSet& value_set = GetValueSet(position);
+ TF_RET_CHECK(std::find(value_set.values().begin(),
+ value_set.values().end(),
+ &value) != value_set.values().end())
+ << "Value set at position " << position << " does not contain value "
+ << value.ToShortString();
+ }
+ }
+
+ // Second, the converse: for each value in each value set, the value set's
+ // position {instruction, index} must appear in the value's positions().
+ for (const auto& computation : module_->computations()) {
+ for (const auto& instruction : computation->instructions()) {
+ for (const auto& pair : GetInstructionValueSet(instruction.get())) {
+ const ShapeIndex& index = pair.first;
+ const HloValueSet& value_set = pair.second;
+ const HloPosition position{instruction.get(), index};
+ for (const HloValue* value : value_set.values()) {
+ TF_RET_CHECK(std::find(value->positions().begin(),
+ value->positions().end(),
+ position) != value->positions().end())
+ << "Value set at position " << position
+ << " unexpectedly contains value " << value->ToShortString();
+ }
+ }
+ }
+ }
+
+ return Status::OK();
+}
+
+Status HloDataflowAnalysis::VerifyAgainstReference() const {
+ TF_RETURN_IF_ERROR(Verify());  // Check this analysis on its own first.
+
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<HloDataflowAnalysis> reference,
+ Run(module_, ssa_form_, bitcast_defines_value_));  // Same module and options.
+ TF_RETURN_IF_ERROR(reference->Verify());
+
+ VLOG(2) << "This analysis:";
+ XLA_VLOG_LINES(2, ToString());
+ VLOG(2) << "Reference:";
+ XLA_VLOG_LINES(2, reference->ToString());
+
+ // Verify value sets in each position are identical (HloValue ==, not address).
+ for (const auto& computation : module_->computations()) {
+ for (const auto& instruction : computation->instructions()) {
+ for (const auto& pair : GetInstructionValueSet(instruction.get())) {
+ const ShapeIndex& index = pair.first;
+ const HloValueSet& value_set = pair.second;
+ const HloValueSet& reference_value_set =
+ reference->GetValueSet(instruction.get(), index);
+
+ auto value_in_set = [](const HloValue& v, const HloValueSet& vset) {
+ return std::find_if(vset.values().begin(), vset.values().end(),
+ [&v](const HloValue* w) { return *w == v; }) !=
+ vset.values().end();
+ };
+
+ for (const HloValue* value : value_set.values()) {
+ TF_RET_CHECK(value_in_set(*value, reference_value_set))
+ << "Value " << value->ToShortString()
+ << " does not exist in reference";
+ }
+ for (const HloValue* reference_value : reference_value_set.values()) {
+ TF_RET_CHECK(value_in_set(*reference_value, value_set))
+ << "Value " << reference_value->ToShortString()
+ << " only exists in reference";
+ }
+ }
+ }
+ }
+
+ // Verify phis resolve identically (both null, or ==) and use lists match.
+ for (const HloValue& value : values()) {
+ const HloValue& reference_value = reference->GetValueDefinedAt(
+ value.defining_instruction(), value.defining_index());  // Same defining position.
+ TF_RET_CHECK(value.is_phi() == reference_value.is_phi());
+ if (value.is_phi()) {
+ const HloValue* resolved_value = ResolvePhi(value);
+ const HloValue* reference_resolved_value =
+ reference->ResolvePhi(reference_value);
+ if (resolved_value == nullptr) {
+ TF_RET_CHECK(reference_resolved_value == nullptr);
+ } else {
+ TF_RET_CHECK(reference_resolved_value != nullptr);
+ TF_RET_CHECK(*reference_resolved_value == *resolved_value);
+ }
+ }
+
+ for (const HloUse& use : value.uses()) {  // Each use here is in the reference.
+ TF_RET_CHECK(std::find(reference_value.uses().begin(),
+ reference_value.uses().end(),
+ use) != reference_value.uses().end());
+ }
+ for (const HloUse& reference_use : reference_value.uses()) {  // And vice versa.
+ TF_RET_CHECK(std::find(value.uses().begin(), value.uses().end(),
+ reference_use) != value.uses().end());
+ }
+ }
+ return Status::OK();
+}
+
} // namespace xla