aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
diff options
context:
space:
mode:
authorGravatar Mark Heffernan <meheff@google.com>2017-06-19 18:38:20 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-06-19 18:41:40 -0700
commit5b6a203c5c759656b2b7018271219916ddd85cb6 (patch)
treee4ba01c8a30d2066ee7f05a147638e5d0cbe246b /tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
parenta36488d812780e78f869a3eb2b692cf3c236f1cc (diff)
[XLA] Add live range interference querying to dataflow analysis.
Add method MayInterfere to HloDataflowAnalysis which returns whether the live ranges of two values interfere. This will replace buffer_liveness.cc. The cl includes a few related changes: (1) HloOrdering: Apply an order to the condition and body computations. Specifically, for the purposes of HLO ordering the condition is ordered before the body. This ensures that the live ranges of values in the condition do not interfere with the live ranges in the body. (2) Add a Dominates method to CallGraph for determining whether a computation dominates another in the call graph. (3) Tightened the definition of "use" in the dataflow analysis. Now an instruction which passes through a value without reading it is no longer considered a use of the value. This new definition is reflected in the HloUse objects returned by HloValue::uses(). PiperOrigin-RevId: 159509724
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc267
1 files changed, 239 insertions, 28 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
index d1b8725644..7e951721ba 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
@@ -76,7 +76,8 @@ HloValue::HloValue(HloValue::Id id, HloInstruction* instruction,
}
bool HloValue::operator==(const HloValue& other) const {
- bool equal = instruction() == other.instruction() && index() == other.index();
+ bool equal = defining_instruction() == other.defining_instruction() &&
+ defining_index() == other.defining_index();
// If the values are equal they most both be phi (or non phi).
CHECK(!(equal && is_phi() != other.is_phi()));
return equal;
@@ -87,10 +88,11 @@ bool HloValue::operator!=(const HloValue& other) const {
}
string HloValue::ToShortString() const {
- string index_str =
- ShapeUtil::IsTuple(instruction()->shape()) ? index().ToString() : "";
- return StrCat(is_phi_ ? "PHI " : "", instruction()->FullyQualifiedName(),
- index_str);
+ string index_str = ShapeUtil::IsTuple(defining_instruction()->shape())
+ ? defining_index().ToString()
+ : "";
+ return StrCat(is_phi_ ? "PHI " : "",
+ defining_instruction()->FullyQualifiedName(), index_str);
}
string HloValue::ToString(int indent) const {
@@ -106,6 +108,50 @@ string HloValue::ToString(int indent) const {
return out;
}
+namespace {
+
+// Returns true if the instruction 'user' may use the value at the given
+// ShapeIndex in the given operand. Generally, instruction which pass through
+// values transparently without reading the value are not considered to use the
+// value.
+bool MayUseOperandValue(int64 operand_number, const ShapeIndex& index,
+ const HloInstruction* user) {
+ switch (user->opcode()) {
+ case HloOpcode::kGetTupleElement:
+ case HloOpcode::kCopy:
+ // These instructions only access the top-level values of their
+ // operand. Non-top-level (nested) values are passed through
+ // transparently.
+ CHECK_EQ(operand_number, 0);
+ return index.empty();
+ case HloOpcode::kSelect:
+ // Select does not use any nested elements of its selected-from operands
+ // (operand 1 and 2)
+ CHECK_GE(operand_number, 0);
+ CHECK_LE(operand_number, 2);
+ return operand_number == 0 || index.empty();
+
+ case HloOpcode::kCall:
+ case HloOpcode::kTuple:
+ // These instructions always pass through their operands transparently.
+ return false;
+
+ case HloOpcode::kWhile:
+ // Though the while instructions passes through its operands, we return
+ // true because in SSA form there may be a Phi at the parameter of the
+ // while which is considered a use of its incoming value because the Phi
+ // input values are not passed through into the body computation. Because
+ // this function is used in both SSA and non-SSA forms of the analysis
+ // conservatively return true.
+ return true;
+
+ default:
+ return true;
+ }
+}
+
+} // namespace
+
void HloValue::AddLocation(HloInstruction* instruction,
const ShapeIndex& index) {
// The given location should not already exist in locations_.
@@ -118,7 +164,7 @@ void HloValue::AddLocation(HloInstruction* instruction,
// Update uses.
for (HloInstruction* user : instruction->users()) {
for (int64 operand_number : user->OperandIndices(instruction)) {
- if (!DoesNotUseOperandBuffer(instruction, index, user)) {
+ if (MayUseOperandValue(operand_number, index, user)) {
for (const HloUse& use : uses_) {
// Verify that this use does not already exist.
DCHECK(!(use.instruction == user &&
@@ -136,12 +182,16 @@ void HloValue::AddLocation(HloInstruction* instruction,
if (instruction == module.entry_computation()->root_instruction()) {
live_out_of_module_ = true;
}
+
+ if (instruction == instruction->parent()->root_instruction()) {
+ live_out_of_computation_ = true;
+ }
}
void HloValue::RemoveLocation(HloInstruction* instruction,
const ShapeIndex& index) {
// The defining location cannot be removed.
- CHECK(!(instruction == this->instruction() && index == this->index()));
+ CHECK(!(instruction == defining_instruction() && index == defining_index()));
int64 size_before = locations_.size();
locations_.erase(
@@ -163,19 +213,27 @@ void HloValue::RemoveLocation(HloInstruction* instruction,
}),
uses_.end());
- const HloModule& module = *instruction->parent()->parent();
- if (instruction == module.entry_computation()->root_instruction()) {
- // Value has been removed from a location in the entry root instruction.
- // Check if the value is still live out of the module by walking all
- // remaining locations.
- live_out_of_module_ = false;
+ // Returns whether this value is contained in the given instruction's output.
+ auto is_contained_in = [this](const HloInstruction* instruction) {
for (const HloLocation& location : locations()) {
- if (location.instruction ==
- module.entry_computation()->root_instruction()) {
- live_out_of_module_ = true;
- break;
+ if (location.instruction == instruction) {
+ return true;
}
}
+ return false;
+ };
+
+ const HloModule& module = *instruction->parent()->parent();
+ if (instruction == module.entry_computation()->root_instruction()) {
+ // Value has been removed from a location in the entry root instruction.
+ live_out_of_module_ =
+ is_contained_in(module.entry_computation()->root_instruction());
+ }
+ if (instruction == defining_instruction()->parent()->root_instruction()) {
+ // Value has been removed from the root of the computation the value has
+ // been defined in.
+ live_out_of_computation_ =
+ is_contained_in(defining_instruction()->parent()->root_instruction());
}
}
@@ -259,7 +317,8 @@ bool HloDataflowAnalysis::ValueIsDefinedAt(const HloInstruction* instruction,
if (value_set.value_ids().size() != 1) {
return false;
}
- return GetValue(value_set.GetUniqueValueId()).instruction() == instruction;
+ return GetValue(value_set.GetUniqueValueId()).defining_instruction() ==
+ instruction;
}
const HloValue& HloDataflowAnalysis::GetValueDefinedAt(
@@ -468,8 +527,8 @@ void HloDataflowAnalysis::UpdateLocationsOfValuesAt(
}
// Don't remove the defining location of the value.
HloValue& value = GetValue(value_id);
- if (instruction == value.instruction()) {
- CHECK_EQ(index, value.index());
+ if (instruction == value.defining_instruction()) {
+ CHECK_EQ(index, value.defining_index());
} else {
value.RemoveLocation(instruction, index);
}
@@ -482,8 +541,8 @@ void HloDataflowAnalysis::UpdateLocationsOfValuesAt(
const HloValueSet& value_set) {
for (HloValue::Id value_id : value_set.value_ids()) {
HloValue& value = GetValue(value_id);
- if (instruction == value.instruction()) {
- CHECK_EQ(index, value.index());
+ if (instruction == value.defining_instruction()) {
+ CHECK_EQ(index, value.defining_index());
} else {
value.AddLocation(instruction, index);
}
@@ -694,15 +753,24 @@ InstructionValueSet HloDataflowAnalysis::RecomputeParameterValueSet(
std::vector<const InstructionValueSet*> inputs;
bool called_from_while = false;
for (const CallSite& callsite : call_graph_node.caller_callsites()) {
- inputs.push_back(&GetInstructionValueSet(
- callsite.instruction()->operand(parameter->parameter_number())));
- if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
- // In a while instruction, the backedge is also a dataflow input to the
- // parameter instruction. This code covers the case where the parameter is
- // in the while body or the parameter is in the while condition.
+ if (callsite.instruction()->opcode() == HloOpcode::kCall) {
+ // The operand values of a call instruction are forwarded to the
+ // respective parameter instruction of the subcomputation.
+ inputs.push_back(&GetInstructionValueSet(
+ callsite.instruction()->operand(parameter->parameter_number())));
+ } else if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
+ // In a while instruction, the while operand (ie, the init value) and the
+ // backedge are dataflow inputs to the parameter instruction. This is the
+ // case for parameters of both the body and condition computations.
+ CHECK_EQ(parameter->parameter_number(), 0);
+ inputs.push_back(
+ &GetInstructionValueSet(callsite.instruction()->operand(0)));
inputs.push_back(&GetInstructionValueSet(
callsite.instruction()->while_body()->root_instruction()));
called_from_while = true;
+ } else {
+ LOG(FATAL) << "CallContext::kSequential computations should only be "
+ "called from call or while instructions";
}
}
@@ -804,6 +872,149 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
return Status::OK();
}
+bool HloDataflowAnalysis::IsDefinedBefore(const HloValue& a, const HloValue& b,
+ const HloOrdering& ordering) const {
+ // If 'b' is an entry param then 'a' cannot be defined before 'b' because 'b'
+ // is live into the module.
+ if (b.defining_instruction()->parent() == module_->entry_computation() &&
+ b.defining_instruction()->opcode() == HloOpcode::kParameter) {
+ return false;
+ }
+
+ // Phi values require special handling. Because XLA does not have a phi
+ // instruction, the definition instruction of the phis values are
+ // placeholders: either the subcomputation parameter (body or condition) or
+ // the while instruction. However, the program point where these values are
+ // logically defined does not necessarily coincide exactly with program point
+ // of these place-holder instructions. So we explicitly define the following
+ // order for phi values:
+ //
+ // body/condition parameter phi:
+ // Defined before all values defined in its computation excepting other
+ // phis.
+ //
+ // while phi:
+ // defined after all values defined in the condition or body.
+ //
+ auto is_body_or_condition_phi = [](const HloValue& v) {
+ return v.is_phi() &&
+ v.defining_instruction()->opcode() == HloOpcode::kParameter;
+ };
+ if (is_body_or_condition_phi(a) && !is_body_or_condition_phi(b) &&
+ call_graph_->InstructionIsNestedIn(b.defining_instruction(),
+ a.defining_instruction()->parent())) {
+ return true;
+ }
+ if (is_body_or_condition_phi(b) &&
+ call_graph_->InstructionIsNestedIn(a.defining_instruction(),
+ b.defining_instruction()->parent())) {
+ return false;
+ }
+
+ // If 'b' is a while phi and 'a' is in the body or condition, then 'a'
+ // executes before 'b'.
+ if (b.is_phi() && b.defining_instruction()->opcode() == HloOpcode::kWhile &&
+ (call_graph_->InstructionIsNestedIn(
+ a.defining_instruction(), b.defining_instruction()->while_body()) ||
+ call_graph_->InstructionIsNestedIn(
+ a.defining_instruction(),
+ b.defining_instruction()->while_condition()))) {
+ return true;
+ }
+
+ return ordering.ExecutesBefore(a.defining_instruction(),
+ b.defining_instruction());
+}
+
+bool HloDataflowAnalysis::UseIsBeforeValueDefinition(
+ const HloUse& use, const HloValue& value,
+ const HloOrdering& ordering) const {
+ if (ordering.ExecutesBefore(use.instruction, value.defining_instruction())) {
+ return true;
+ }
+
+ // If the use is at the instruction where the value is defined, then the use
+ // is before the def if the instruction allows buffer sharing (in place
+ // computation).
+ if (use.instruction == value.defining_instruction() &&
+ CanShareOperandBufferWithUser(
+ use.instruction->mutable_operand(use.operand_number),
+ use.operand_index, value.defining_instruction(),
+ value.defining_index())) {
+ return true;
+ }
+
+ // The use at a while is an input to a phi, and logically occurs before values
+ // are defined in the body or condition computations.
+ if (use.instruction->opcode() == HloOpcode::kWhile) {
+ const HloInstruction* xla_while = use.instruction;
+ if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
+ xla_while->while_body()) ||
+ call_graph_->InstructionIsNestedIn(value.defining_instruction(),
+ xla_while->while_condition())) {
+ return true;
+ }
+ }
+
+ // Similarly if the value is defined at a while, it logically occurs after any
+ // uses in the body or condition computations.
+ if (value.defining_instruction()->opcode() == HloOpcode::kWhile) {
+ CHECK(ssa_form_);
+ const HloInstruction* xla_while = value.defining_instruction();
+ if (call_graph_->InstructionIsNestedIn(use.instruction,
+ xla_while->while_body()) ||
+ call_graph_->InstructionIsNestedIn(use.instruction,
+ xla_while->while_condition())) {
+ return true;
+ }
+ }
+ return false;
+}
+
+bool HloDataflowAnalysis::LiveRangeStrictlyBefore(
+ const HloValue& a, const HloValue& b, const HloOrdering& ordering) const {
+ VLOG(4) << "LiveRangeStrictlyBefore(a = " << a.ToShortString()
+ << ", b = " << b.ToShortString() << ")";
+ if (!IsDefinedBefore(a, b, ordering)) {
+ VLOG(4) << "a not defined before b";
+ return false;
+ }
+
+ // Live-out values from the module can never have ranges strictly before any
+ // other value.
+ if (a.live_out_of_module()) {
+ VLOG(4) << "a is live out of module";
+ return false;
+ }
+
+ // Live-out values of computations can never have ranges strictly before any
+ // other value in the computation (including values nested in
+ // subcomputations).
+ if (a.live_out_of_computation() &&
+ call_graph_->InstructionIsNestedIn(b.defining_instruction(),
+ a.defining_instruction()->parent())) {
+ VLOG(4) << "a is live out of computation containing b";
+ return false;
+ }
+
+ // All uses of 'a' must be before 'b' is defined.
+ for (const HloUse& use : a.uses()) {
+ if (!UseIsBeforeValueDefinition(use, b, ordering)) {
+ VLOG(4) << "use of a (" << use << ") not before b is defined";
+ return false;
+ }
+ }
+
+ return true;
+}
+
+bool HloDataflowAnalysis::MayInterfere(const HloValue& a, const HloValue& b,
+ const HloOrdering& ordering) const {
+ // Buffers without disjoint liveness may interfere.
+ return !LiveRangeStrictlyBefore(a, b, ordering) &&
+ !LiveRangeStrictlyBefore(b, a, ordering);
+}
+
/* static */
StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
HloModule* module, bool ssa_form, bool bitcast_defines_value) {