aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
author Mark Heffernan <meheff@google.com> 2017-09-01 09:17:43 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2017-09-01 09:21:31 -0700
commit 73d796423348347702d43b498257f34e41fba367 (patch)
tree d9be8036c53d39d3ab4abddd7793724fc5dc52d1
parent 6e8d0c632dea30758c7cc343decdf8ab7956e59d (diff)
Rollback update-ability of dataflow and alias analysis added in cl/164923041 and cl/64778750. It did not scale as intended to large graphs when used in copy insertion. This change also includes some simplification and performance improvements to dataflow and alias analysis. Also add some value-ordering tests to HloOrderingTest using dataflow analysis to generate values.
PiperOrigin-RevId: 167283460
-rw-r--r-- tensorflow/compiler/xla/service/BUILD 1
-rw-r--r-- tensorflow/compiler/xla/service/hlo_alias_analysis.cc 601
-rw-r--r-- tensorflow/compiler/xla/service/hlo_alias_analysis.h 51
-rw-r--r-- tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc 144
-rw-r--r-- tensorflow/compiler/xla/service/hlo_buffer.cc 16
-rw-r--r-- tensorflow/compiler/xla/service/hlo_buffer.h 15
-rw-r--r-- tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc 588
-rw-r--r-- tensorflow/compiler/xla/service/hlo_dataflow_analysis.h 96
-rw-r--r-- tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc 328
-rw-r--r-- tensorflow/compiler/xla/service/hlo_ordering_test.cc 89
-rw-r--r-- tensorflow/compiler/xla/service/hlo_value.cc 6
-rw-r--r-- tensorflow/compiler/xla/service/hlo_value.h 3
12 files changed, 576 insertions, 1362 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 9d4e7fc254..610c611eee 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -849,6 +849,7 @@ cc_test(
srcs = ["hlo_ordering_test.cc"],
deps = [
":hlo",
+ ":hlo_dataflow_analysis",
":hlo_ordering",
":hlo_scheduling",
"//tensorflow/compiler/xla:shape_util",
diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc
index 0beea42379..3dd8ac6dc5 100644
--- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc
@@ -37,6 +37,230 @@ namespace xla {
using ::tensorflow::strings::StrAppend;
using ::tensorflow::strings::StrCat;
+// Data structure used to construct the alias analysis. Thrown away after alias
+// analysis is complete. This data structure keeps track of which sets of
+// HloValues must be in the same HloBuffer. This is maintained as a map from a
+// buffer identifier (BufferNumber) to a set of HloValues.
+//
+// Initially each value is its own buffer. In MergeAliasedBuffers, sets of
+// values which must share the same buffer are merged together. The end result
+// is a partitioning of all HloValues into sets where each set needs its own
+// HloBuffer. By performing this analysis without constructing HloBuffers on the
+// fly, we can after-the-fact construct a vector of contiguously numbered
+// HloBuffers after the buffer requirement has been determined.
+class BufferValueMap {
+ public:
+ // A unique identifier for a set of colocated values which must share the same
+ // buffer. This is not necessarily the same as the HloBuffer::Id which will
+ // ultimately contain the values. The reason is that HloBuffer::Id's are
+ // contiguous, while BufferNumbers may not be. BufferNumbers may not be
+ // dense because buffers may be created and destroyed during the analysis
+ // construction process.
+ using BufferNumber = int64;
+
+ explicit BufferValueMap(const HloDataflowAnalysis& dataflow)
+ : dataflow_(dataflow) {
+ buffers_.reserve(dataflow_.values().size());
+ value_to_buffer_number_.reserve(dataflow_.values().size());
+ for (const HloValue* value : dataflow_.values()) {
+ BufferNumber buffer_number = next_buffer_number_++;
+ buffers_[buffer_number].insert(value);
+ value_to_buffer_number_[value] = buffer_number;
+ }
+ }
+
+ // Merge together sets of HloValues which must be in the same HloBuffer
+ // because of aliasing rules (eg, in-place kWhile instruction).
+ void MergeAliasedBuffers() {
+ for (const HloValue* value : dataflow_.values()) {
+ VLOG(3) << "Merging colocated values, value: " << value->ToShortString();
+
+ // Gather the set of buffers with aliasing rules (eg, kWhile) which this
+ // value must be contained in.
+ std::vector<BufferNumber> aliased_buffers = ComputeAliasedBuffers(*value);
+
+ BufferNumber current_buffer = value_to_buffer_number_.at(value);
+ if (aliased_buffers.empty()) {
+ // The buffer containing 'value' aliases no other buffers. If the buffer
+ // containing 'value' already only contains 'value', then no change is
+ // necessary. If the buffer containing 'value' does contain other
+ // values, then remove 'value' from the buffer and create a new buffer
+ // containing only 'value'.
+ if (buffers_.at(current_buffer).size() == 1) {
+ CHECK_EQ(*buffers_.at(current_buffer).begin(), value);
+ } else {
+ MoveValueToNewBuffer(*value);
+ }
+ } else {
+ // If multiple buffers are aliased merge these buffers together into a
+ // single buffer (arbitrarily chosen as the first buffer in the vector).
+ if (aliased_buffers.size() > 1) {
+ for (int64 i = 1; i < aliased_buffers.size(); ++i) {
+ MergeBuffers(/*from=*/aliased_buffers[i],
+ /*to=*/aliased_buffers[0]);
+ }
+ }
+ BufferNumber new_buffer = aliased_buffers[0];
+ if (current_buffer != new_buffer) {
+ MoveValueToBuffer(*value, new_buffer);
+ }
+ }
+ }
+ }
+
+ // Compute and return a sorted vector of all BufferNumbers. Can be used to
+ // iterate through all buffers stably.
+ std::vector<BufferNumber> ComputeSortedBufferNumbers() const {
+ std::vector<BufferNumber> buffer_numbers;
+ for (const auto& pair : buffers_) {
+ buffer_numbers.push_back(pair.first);
+ }
+ std::sort(buffer_numbers.begin(), buffer_numbers.end());
+ return buffer_numbers;
+ }
+
+ // Return a set of all the values in the given buffer.
+ const tensorflow::gtl::FlatSet<const HloValue*>& GetValuesInBuffer(
+ BufferNumber buffer_number) const {
+ return buffers_.at(buffer_number);
+ }
+
+ private:
+ // Create a new buffer.
+ void NewBuffer(const HloValue& value) {
+ BufferNumber buffer_number = next_buffer_number_++;
+ buffers_[buffer_number].insert(&value);
+ value_to_buffer_number_[&value] = buffer_number;
+ }
+
+ // Move the given value into a new buffer containing only the value.
+ void MoveValueToNewBuffer(const HloValue& value) {
+ BufferNumber new_buffer_number = next_buffer_number_++;
+ buffers_[new_buffer_number];
+ MoveValueToBuffer(value, new_buffer_number);
+ }
+
+ // Move the given value into the given buffer.
+ void MoveValueToBuffer(const HloValue& value, BufferNumber buffer_number) {
+ BufferNumber old_buffer_number = value_to_buffer_number_.at(&value);
+ buffers_.at(old_buffer_number).erase(&value);
+ if (buffers_.at(old_buffer_number).empty()) {
+ buffers_.erase(old_buffer_number);
+ }
+
+ buffers_.at(buffer_number).insert(&value);
+ value_to_buffer_number_.at(&value) = buffer_number;
+ }
+
+ // Merge the buffer 'from' into the buffer 'to'.
+ void MergeBuffers(BufferNumber from, BufferNumber to) {
+ auto& from_value_set = buffers_.at(from);
+ buffers_.at(to).insert(from_value_set.begin(), from_value_set.end());
+ // NOTE: using a union-find algorithm to hold the colocated values might be
+ // faster.
+ for (const HloValue* value : from_value_set) {
+ value_to_buffer_number_.at(value) = to;
+ }
+ buffers_.erase(from);
+ }
+
+ BufferNumber GetBufferForValue(const HloValue& value) {
+ return value_to_buffer_number_.at(&value);
+ }
+
+ // Compute and return a vector of buffers that the given value must be
+ // contained in due to HLO aliasing rules.
+ std::vector<BufferNumber> ComputeAliasedBuffers(const HloValue& value) {
+ // Value is init of a while (use is while).
+ std::vector<BufferNumber> aliased_buffers;
+ for (const HloUse& use : value.uses()) {
+ VLOG(1) << "use of value " << value.ToShortString() << ": " << use;
+ if (use.instruction->opcode() == HloOpcode::kWhile) {
+ // Determine the while value that this shares a buffer with.
+ const HloValue& while_value =
+ dataflow_.GetUniqueValueAt(use.instruction, use.operand_index);
+ aliased_buffers.push_back(GetBufferForValue(while_value));
+ VLOG(3) << " value is init value to a while; must share buffer with "
+ "while value "
+ << while_value.ToShortString();
+ }
+ }
+
+ // Value is a parameter of a while body/condition.
+ if (value.defining_instruction()->opcode() == HloOpcode::kParameter) {
+ const HloComputation* computation =
+ value.defining_instruction()->parent();
+ const CallGraphNode& call_graph_node =
+ dataflow_.call_graph().GetNode(computation);
+ for (const CallSite& callsite : call_graph_node.caller_callsites()) {
+ if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
+ // Call graph must have been flattened.
+ CHECK_EQ(call_graph_node.caller_callsites().size(), 1);
+
+ const HloValue& while_value = dataflow_.GetUniqueValueAt(
+ callsite.instruction(), value.defining_index());
+ VLOG(3) << " value is parameter value of the body or condition of a "
+ "while; must share buffer with while value "
+ << while_value.ToShortString();
+ aliased_buffers.push_back(GetBufferForValue(while_value));
+ }
+ }
+ }
+
+ // Value is the root of a while body.
+ for (const HloPosition& position : value.positions()) {
+ const HloComputation* computation = position.instruction->parent();
+ const CallGraphNode& call_graph_node =
+ dataflow_.call_graph().GetNode(computation);
+ if (position.instruction == computation->root_instruction()) {
+ for (const CallSite& callsite : call_graph_node.caller_callsites()) {
+ if (callsite.instruction()->opcode() == HloOpcode::kWhile &&
+ callsite.instruction()->while_body() == computation) {
+ // Call graph must have been flattened.
+ CHECK_EQ(call_graph_node.caller_callsites().size(), 1);
+
+ const HloValue& while_value = dataflow_.GetUniqueValueAt(
+ callsite.instruction(), position.index);
+ VLOG(3) << " value is root the body computation of a while; must "
+ "share buffer with while value "
+ << while_value.ToShortString();
+ aliased_buffers.push_back(GetBufferForValue(while_value));
+ }
+ }
+ }
+ }
+
+ // Value is the output of the while instruction itself.
+ if (value.defining_instruction()->opcode() == HloOpcode::kWhile) {
+ VLOG(3) << " value is output of a while instruction";
+ aliased_buffers.push_back(GetBufferForValue(value));
+ }
+
+ // Uniquify aliased buffers.
+ std::sort(aliased_buffers.begin(), aliased_buffers.end());
+ aliased_buffers.erase(
+ std::unique(aliased_buffers.begin(), aliased_buffers.end()),
+ aliased_buffers.end());
+
+ return aliased_buffers;
+ }
+
+ // Dataflow analysis used to construct the buffer map.
+ const HloDataflowAnalysis& dataflow_;
+
+ // A map containing the set of values contained in each buffer.
+ tensorflow::gtl::FlatMap<BufferNumber,
+ tensorflow::gtl::FlatSet<const HloValue*>>
+ buffers_;
+
+ // A map indicating which buffer each value is contained in.
+ tensorflow::gtl::FlatMap<const HloValue*, BufferNumber>
+ value_to_buffer_number_;
+
+ // The buffer number of the next buffer to be created.
+ BufferNumber next_buffer_number_ = 0;
+};
+
HloAliasAnalysis::HloAliasAnalysis(HloModule* module) : module_(module) {}
const HloBuffer& HloAliasAnalysis::GetUniqueBufferAt(
@@ -99,10 +323,11 @@ bool HloAliasAnalysis::InstructionBuffersAreDistinct(
}
} else {
// It's possible for multiple values at this index to have the same
- // HloBuffer. This does not result in non-distictness. To account for this
- // case, add all of the buffers at this index after checking whether each
- // buffer exists at an earlier index. This is a corner case, however, as
- // the number of values at an index is almost always one.
+ // HloBuffer. This does not result in non-distinctness. To account for
+ // this case, add all of the buffers at this index after checking
+ // whether each buffer exists at an earlier index. This is a corner
+ // case, however, as the number of values at an index is almost always
+ // one.
std::vector<const HloBuffer*> buffers_at_this_index;
for (const HloValue* value : value_set.values()) {
const HloBuffer* buffer = &GetBufferContainingValue(*value);
@@ -118,15 +343,6 @@ bool HloAliasAnalysis::InstructionBuffersAreDistinct(
return true;
}
-void HloAliasAnalysis::InitializeBufferSets() {
- // Initially define a buffer for every HloValue in the module.
- for (const HloValue& value : dataflow_analysis_->values()) {
- HloBuffer& buffer = NewHloBuffer();
- buffer.AddValue(value);
- value_to_buffer_[&value] = &buffer;
- }
-}
-
Status HloAliasAnalysis::Verify() const {
// Verify consistency between the value_to_buffer_ map and
// HloBuffer::values().
@@ -137,9 +353,8 @@ Status HloAliasAnalysis::Verify() const {
value) != buffer.values().end());
}
- for (const auto& pair : buffers_) {
- const HloBuffer::Id id = pair.first;
- const HloBuffer& buffer = pair.second;
+ for (HloBuffer::Id id = 0; id < buffers_.size(); ++id) {
+ const HloBuffer& buffer = buffers_[id];
TF_RET_CHECK(buffer.id() == id);
HloValue::Id last_value_id = -1;
@@ -152,116 +367,9 @@ Status HloAliasAnalysis::Verify() const {
}
}
- if (!buffers_vector_.empty()) {
- // buffers_vector_ should be a vector of all HloBuffers sorted by id.
- std::vector<const HloBuffer*> buffers;
- for (const auto& id_buffer : buffers_) {
- buffers.push_back(&id_buffer.second);
- }
- std::sort(buffers.begin(), buffers.end(), HloBuffer::IdLessThan);
- TF_RET_CHECK(buffers_vector_ == buffers);
- }
-
- return Status::OK();
-}
-
-Status HloAliasAnalysis::VerifyAgainstReference() const {
- TF_RETURN_IF_ERROR(Verify());
-
- TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> reference,
- Run(module_));
- TF_RETURN_IF_ERROR(reference->Verify());
-
- VLOG(2) << "This analysis:";
- XLA_VLOG_LINES(2, ToString());
- VLOG(2) << "Reference:";
- XLA_VLOG_LINES(2, reference->ToString());
-
- // Create map from HloValue in the reference analysis to HloValue in this
- // analysis and vice versa.
- tensorflow::gtl::FlatMap<const HloValue*, const HloValue*> reference_to_this;
- tensorflow::gtl::FlatMap<const HloValue*, const HloValue*> this_to_reference;
- for (const HloValue& value : dataflow_analysis().values()) {
- const HloValue& reference_value =
- reference->dataflow_analysis().GetValueDefinedAt(
- value.defining_instruction(), value.defining_index());
- reference_to_this[&reference_value] = &value;
- this_to_reference[&value] = &reference_value;
- }
-
- TF_RET_CHECK(buffers_.size() == reference->buffers_.size())
- << "Different number of buffers (" << buffers_.size()
- << " != " << reference->buffers_.size() << ")";
- for (const auto& pair : reference->buffers_) {
- const HloBuffer& reference_buffer = pair.second;
-
- // Find the corresponding buffer in the reference by taking the first value
- // in the buffer, finding the corresponding value in the reference, and then
- // finding the buffer holding that value.
- TF_RET_CHECK(!reference_buffer.values().empty());
- const HloValue* reference_value = reference_buffer.values()[0];
- const HloValue* value = reference_to_this.at(reference_value);
- const HloBuffer& buffer = GetBufferContainingValue(*value);
-
- // The buffer and the reference should have the exact same values. To make
- // comparison easy, sort the values in the reference buffer identically to
- // the values in the non-reference buffer (ie, by the corresponding id of
- // the non-reference value).
- std::vector<const HloValue*> reference_values = reference_buffer.values();
- std::sort(reference_values.begin(), reference_values.end(),
- [&reference_to_this](const HloValue* a, const HloValue* b) {
- return reference_to_this.at(a)->id() <
- reference_to_this.at(b)->id();
- });
- TF_RET_CHECK(reference_values.size() == buffer.values().size());
- for (int i = 0; i < buffer.values().size(); ++i) {
- TF_RET_CHECK(*reference_values[i] == *buffer.values()[i])
- << "Buffer:\n " << buffer
- << "\ndoes not have the same values as reference buffer:\n "
- << reference_buffer;
- }
- }
-
return Status::OK();
}
-HloBuffer& HloAliasAnalysis::NewHloBuffer() {
- HloBuffer::Id buffer_id = next_buffer_id_++;
- auto emplaced = buffers_.emplace(std::piecewise_construct,
- std::forward_as_tuple(buffer_id),
- std::forward_as_tuple(buffer_id));
- CHECK(emplaced.second);
-
- buffers_vector_.clear();
-
- return emplaced.first->second;
-}
-
-void HloAliasAnalysis::MoveValueToNewBuffer(const HloValue& value) {
- HloBuffer& new_buffer = NewHloBuffer();
- MoveValueToBuffer(value, &new_buffer);
-
- VLOG(3) << "Moved value " << value.ToShortString() << " into new buffer "
- << new_buffer.id();
-}
-
-void HloAliasAnalysis::MoveValueToBuffer(const HloValue& value,
- HloBuffer* buffer) {
- HloBuffer& old_buffer = GetBufferContainingValue(value);
- CHECK_NE(buffer, &old_buffer);
- VLOG(3) << "Moved value " << value.ToShortString() << " from buffer "
- << old_buffer.id() << " into buffer " << buffer->id();
- old_buffer.RemoveValue(value);
- if (old_buffer.values().empty()) {
- VLOG(3) << "Buffer " << old_buffer.id() << " now empty. Removing.";
- buffers_.erase(old_buffer.id());
- buffers_vector_.clear();
- }
-
- buffer->AddValue(value);
- value_to_buffer_[&value] = buffer;
-}
-
string HloAliasAnalysis::ToString() const {
string out = StrCat("HloAliasAnalysis, module ", module_->name(), "\n");
StrAppend(&out, " Buffers at each position:\n");
@@ -290,10 +398,10 @@ string HloAliasAnalysis::ToString() const {
}
StrAppend(&out, " Buffers:\n");
- for (const HloBuffer* buffer : buffers()) {
- StrAppend(&out, " ", buffer->ToString(), "\n");
+ for (const HloBuffer& buffer : buffers()) {
+ StrAppend(&out, " ", buffer.ToString(), "\n");
StrAppend(&out, " positions:\n");
- for (const HloPosition& position : buffer->ComputePositions()) {
+ for (const HloPosition& position : buffer.ComputePositions()) {
StrAppend(&out, " ", position.ToString(), "\n");
}
}
@@ -301,217 +409,6 @@ string HloAliasAnalysis::ToString() const {
return out;
}
-const std::vector<const HloBuffer*>& HloAliasAnalysis::buffers() const {
- if (buffers_vector_.empty()) {
- // Lazily construct vector of buffers.
- buffers_vector_.reserve(buffers_.size());
- for (auto& pair : buffers_) {
- buffers_vector_.push_back(&pair.second);
- }
- std::sort(buffers_vector_.begin(), buffers_vector_.end(),
- HloBuffer::IdLessThan);
- } else {
- CHECK_EQ(buffers_vector_.size(), buffers_.size());
- for (const HloBuffer* buffer : buffers_vector_) {
- DCHECK(ContainsKey(buffers_, buffer->id()));
- DCHECK(&GetBuffer(buffer->id()) == buffer);
- }
- }
- return buffers_vector_;
-}
-
-void HloAliasAnalysis::UpdateAtInstructions(
- tensorflow::gtl::ArraySlice<const HloInstruction*> instructions) {
- VLOG(4) << "Updated HLO module:";
- XLA_VLOG_LINES(4, module_->ToString());
-
- VLOG(3) << "Before update:";
- XLA_VLOG_LINES(3, ToString());
-
- std::vector<const HloValue*> values_to_update;
- for (const HloInstruction* instruction : instructions) {
- for (auto& pair : dataflow_analysis().GetInstructionValueSet(instruction)) {
- for (const HloValue* value : pair.second.values()) {
- values_to_update.push_back(value);
- }
- }
- }
-
- UpdateBuffersForValues(values_to_update);
-
- VLOG(3) << "After update:";
- XLA_VLOG_LINES(3, ToString());
-}
-
-void HloAliasAnalysis::UpdateAfterChangingOperand(HloInstruction* instruction,
- HloInstruction* old_operand,
- HloInstruction* new_operand) {
- VLOG(1) << "UpdateAfterChangingOperand(" << instruction->name() << ", "
- << old_operand->name() << " => " << new_operand->name() << ")";
-
- dataflow_analysis_->UpdateAfterChangingOperand(instruction, old_operand,
- new_operand);
- TF_DCHECK_OK(dataflow_analysis_->VerifyAgainstReference());
-
- VLOG(4) << "Updated dataflow:";
- XLA_VLOG_LINES(4, dataflow_analysis_->ToString());
-
- UpdateAtInstructions({instruction, old_operand, new_operand});
-}
-
-void HloAliasAnalysis::UpdateAfterChangingRoot(HloInstruction* old_root,
- HloInstruction* new_root) {
- VLOG(1) << "UpdateAfterChangingRoot(" << old_root->name() << " => "
- << new_root->name() << ")";
-
- dataflow_analysis_->UpdateAfterChangingRoot(old_root, new_root);
- TF_DCHECK_OK(dataflow_analysis_->VerifyAgainstReference());
-
- VLOG(4) << "Updated dataflow:";
- XLA_VLOG_LINES(4, dataflow_analysis_->ToString());
-
- UpdateAtInstructions({old_root, new_root});
-}
-
-std::vector<HloBuffer*> HloAliasAnalysis::ComputeAliasedBuffers(
- const HloValue& value) {
- std::vector<HloBuffer*> aliased_buffers;
-
- // Value is init of a while (use is while).
- for (const HloUse& use : value.uses()) {
- VLOG(1) << "use of value " << value.ToShortString() << ": " << use;
- if (use.instruction->opcode() == HloOpcode::kWhile) {
- // Determine the while value that this shares a buffer with.
- const HloValue& while_value = dataflow_analysis().GetUniqueValueAt(
- use.instruction, use.operand_index);
- aliased_buffers.push_back(&GetBufferContainingValue(while_value));
- VLOG(3) << " value is init value to a while; must share buffer with "
- "while value "
- << while_value.ToShortString();
- }
- }
-
- // Value is a parameter of a while body/condition.
- if (value.defining_instruction()->opcode() == HloOpcode::kParameter) {
- const HloComputation* computation = value.defining_instruction()->parent();
- const CallGraphNode& call_graph_node =
- dataflow_analysis().call_graph().GetNode(computation);
- for (const CallSite& callsite : call_graph_node.caller_callsites()) {
- if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
- // Call graph must have been flattened.
- CHECK_EQ(call_graph_node.caller_callsites().size(), 1);
-
- const HloValue& while_value = dataflow_analysis().GetUniqueValueAt(
- callsite.instruction(), value.defining_index());
- VLOG(3) << " value is parameter value of the body or condition of a "
- "while; must share buffer with while value "
- << while_value.ToShortString();
- aliased_buffers.push_back(&GetBufferContainingValue(while_value));
- }
- }
- }
-
- // Value is the root of a while body.
- for (const HloPosition& position : value.positions()) {
- const HloComputation* computation = position.instruction->parent();
- const CallGraphNode& call_graph_node =
- dataflow_analysis().call_graph().GetNode(computation);
- if (position.instruction == computation->root_instruction()) {
- for (const CallSite& callsite : call_graph_node.caller_callsites()) {
- if (callsite.instruction()->opcode() == HloOpcode::kWhile &&
- callsite.instruction()->while_body() == computation) {
- // Call graph must have been flattened.
- CHECK_EQ(call_graph_node.caller_callsites().size(), 1);
-
- // If the value appears in the root of a while body, then
- // necessarily the value is defined in the body as well.
- CHECK_EQ(value.defining_instruction()->parent(), computation);
-
- const HloValue& while_value = dataflow_analysis().GetUniqueValueAt(
- callsite.instruction(), position.index);
- VLOG(3) << " value is root the body computation of a while; must "
- "share buffer with while value "
- << while_value.ToShortString();
- aliased_buffers.push_back(&GetBufferContainingValue(while_value));
- }
- }
- }
- }
-
- // Value is in the while instruction itself.
- if (value.defining_instruction()->opcode() == HloOpcode::kWhile) {
- VLOG(3) << " value is output of a while instruction";
- aliased_buffers.push_back(&GetUniqueBufferAt(value.defining_instruction(),
- value.defining_index()));
- }
-
- // Uniquify aliased buffers.
- std::sort(aliased_buffers.begin(), aliased_buffers.end(),
- HloBuffer::IdLessThan);
- aliased_buffers.erase(
- std::unique(aliased_buffers.begin(), aliased_buffers.end()),
- aliased_buffers.end());
-
- return aliased_buffers;
-}
-
-// This method recomputes the HloBuffer for each of the given HloValues. The
-// method does not necessarily update the HloBuffer of values which share a
-// buffer with the given values, but are not explicitly passed in
-// 'values'. Therefore, the caller must pass in all values which may require an
-// update according to the kind of HLO graph change which occurred: operand
-// changed (UpdateAfterChangingOperand), or root of computation changed
-// (UpdateAfterChangingRoot).
-void HloAliasAnalysis::UpdateBuffersForValues(
- tensorflow::gtl::ArraySlice<const HloValue*> values) {
- for (const HloValue* value : values) {
- VLOG(3) << "Updating buffer for value: " << value->ToShortString();
-
- // Gather the set of buffer with aliasing rules (eg, kWhile) which this
- // value must be contained in due.
- std::vector<HloBuffer*> aliased_buffers = ComputeAliasedBuffers(*value);
-
- HloBuffer& current_buffer = GetBufferContainingValue(*value);
- if (aliased_buffers.empty()) {
- // The buffer containing 'value' aliases no other buffers. If the buffer
- // containing 'value' already only contains 'value', then no change is
- // necessary. If the buffer containing 'value' does contain other values,
- // then remove 'value' from the buffer and create a new buffer containing
- // only 'value'
- if (current_buffer.values().size() == 1) {
- CHECK_EQ(current_buffer.values()[0], value);
- } else {
- MoveValueToNewBuffer(*value);
- }
- } else {
- // If multiple buffers are aliased merge these buffers together into a
- // single buffer (arbitrarily chosen as the first buffer in the vector).
- if (aliased_buffers.size() > 1) {
- for (int64 i = 1; i < aliased_buffers.size(); ++i) {
- // Make copy of values vector because MoveValueToBuffer invalidates
- // the values iterator. The could be done more efficiently by moving
- // all values and once.
- std::vector<const HloValue*> values = aliased_buffers[i]->values();
- for (const HloValue* value : values) {
- MoveValueToBuffer(*value, aliased_buffers[0]);
- }
- }
- aliased_buffers.resize(1);
- }
-
- CHECK_EQ(aliased_buffers.size(), 1);
- HloBuffer* new_buffer = aliased_buffers[0];
-
- if (&current_buffer != new_buffer) {
- MoveValueToBuffer(*value, new_buffer);
- }
- }
-
- VLOG(4) << "Analysis after update:";
- XLA_VLOG_LINES(4, ToString());
- }
-}
-
/* static */
StatusOr<std::unique_ptr<HloAliasAnalysis>> HloAliasAnalysis::Run(
HloModule* module) {
@@ -524,18 +421,28 @@ StatusOr<std::unique_ptr<HloAliasAnalysis>> HloAliasAnalysis::Run(
HloDataflowAnalysis::Run(module, /*ssa_form=*/true,
/*bitcast_defines_value=*/false));
- alias_analysis->InitializeBufferSets();
-
- VLOG(3) << "After initialization:";
- XLA_VLOG_LINES(3, alias_analysis->ToString());
-
- std::vector<const HloValue*> all_values;
- for (const HloValue& value : alias_analysis->dataflow_analysis().values()) {
- all_values.push_back(&value);
+ BufferValueMap buffer_map(alias_analysis->dataflow_analysis());
+ buffer_map.MergeAliasedBuffers();
+
+ // Create a vector of HloBuffers, one for each set of values in the
+ // BufferValueMap. Create the HloBuffers as a vector of contiguously numbered
+ // buffers.
+ std::vector<BufferValueMap::BufferNumber> sorted_buffer_numbers =
+ buffer_map.ComputeSortedBufferNumbers();
+ alias_analysis->buffers_.reserve(sorted_buffer_numbers.size());
+ HloBuffer::Id next_id = 0;
+ for (BufferValueMap::BufferNumber buffer_number : sorted_buffer_numbers) {
+ auto& value_set = buffer_map.GetValuesInBuffer(buffer_number);
+ std::vector<const HloValue*> sorted_values(value_set.begin(),
+ value_set.end());
+ std::sort(sorted_values.begin(), sorted_values.end(), HloValue::IdLessThan);
+ alias_analysis->buffers_.emplace_back(next_id++, sorted_values);
+ for (const HloValue* value : sorted_values) {
+ alias_analysis->value_to_buffer_[value] =
+ &alias_analysis->buffers_.back();
+ }
}
- alias_analysis->UpdateBuffersForValues(all_values);
-
TF_DCHECK_OK(alias_analysis->Verify());
XLA_VLOG_LINES(1, alias_analysis->ToString());
diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.h b/tensorflow/compiler/xla/service/hlo_alias_analysis.h
index 1b538f6d1c..39554e4664 100644
--- a/tensorflow/compiler/xla/service/hlo_alias_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.h
@@ -74,7 +74,7 @@ class HloAliasAnalysis {
// Return a vector of all HloBuffers stabily sorted by HloBuffer::Id. This
// vector is lazily computed. Mutating operations on HloAliasAnalysis may
// invalidate the underlying vector requiring recomputation.
- const std::vector<const HloBuffer*>& buffers() const;
+ const std::vector<HloBuffer>& buffers() const { return buffers_; }
// Returns the underlying dataflow analysis used by this alias analysis.
const HloDataflowAnalysis& dataflow_analysis() const {
@@ -90,50 +90,13 @@ class HloAliasAnalysis {
// output of the given instruction.
bool InstructionBuffersAreDistinct(const HloInstruction* instruction) const;
- // Updates the analysis after the operands of 'instruction' have changed or if
- // 'instruction' has been made the root of a computation. Analysis update is
- // not possible if instructions have been added or removed from the graph.
- void UpdateAfterChangingOperand(HloInstruction* instruction,
- HloInstruction* old_operand,
- HloInstruction* new_operand);
- void UpdateAfterChangingRoot(HloInstruction* old_root,
- HloInstruction* new_root);
-
// Compare the dataflow analysis against a clean recomputation of the
// analysis. Returns an error status if there is a mismatch. Useful for
// verifying the correctness after updates to the analysis.
Status VerifyAgainstReference() const;
protected:
- HloAliasAnalysis(HloModule* module);
-
- // Create a new empty HloBuffer.
- HloBuffer& NewHloBuffer();
-
- // Move the given value to the given buffer. The value is removed from it's
- // current buffer.
- void MoveValueToBuffer(const HloValue& value, HloBuffer* buffer);
-
- // Move the given value to a newly created buffer. The value is removed from
- // it's current buffer.
- void MoveValueToNewBuffer(const HloValue& value);
-
- // Construct the initial set of buffer sets where an HloBuffer is created for
- // each HloValue in the module.
- void InitializeBufferSets();
-
- // Compute and return the buffers with aliasing rules (eg, kWhile) which the
- // given value must be contained in.
- std::vector<HloBuffer*> ComputeAliasedBuffers(const HloValue& value);
-
- // Recompute the HloBuffers for the given values.
- void UpdateBuffersForValues(
- tensorflow::gtl::ArraySlice<const HloValue*> values);
-
- // Recompute the HloBuffers for all the values which appear in the output of
- // the given instructions.
- void UpdateAtInstructions(
- tensorflow::gtl::ArraySlice<const HloInstruction*> instructions);
+ explicit HloAliasAnalysis(HloModule* module);
// Verify various invariants of the alias analysis.
Status Verify() const;
@@ -143,20 +106,12 @@ class HloAliasAnalysis {
// The underlying dataflow analysis used by this alias analysis.
std::unique_ptr<HloDataflowAnalysis> dataflow_analysis_;
- // The map of all HloBuffers in the module. We pass around pointers to the
- // mapped HloBuffers, so the underlying container must keep them valid despite
- // mutations touching other map entries.
- std::unordered_map<HloBuffer::Id, HloBuffer> buffers_;
-
// A map indicating which buffer a value is contained in.
tensorflow::gtl::FlatMap<const HloValue*, HloBuffer*> value_to_buffer_;
// A lazily constructed vector containing all HloBuffers sorted by
// HloBuffer::Id.
- mutable std::vector<const HloBuffer*> buffers_vector_;
-
- // The Id to use for the next HloBuffer.
- int64 next_buffer_id_ = 0;
+ std::vector<HloBuffer> buffers_;
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc
index e2815d6e64..6e311e25fb 100644
--- a/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_alias_analysis_test.cc
@@ -87,14 +87,13 @@ class HloAliasAnalysisTest : public HloTestBase {
// constructed.
bool AnyValuesInSameBufferInterfere() {
DependencyHloOrdering ordering(module_.get());
- for (const HloBuffer* buffer : analysis_->buffers()) {
- for (const HloValue* value_a : buffer->values()) {
- for (const HloValue* value_b : buffer->values()) {
+ for (const HloBuffer& buffer : analysis_->buffers()) {
+ for (const HloValue* value_a : buffer.values()) {
+ for (const HloValue* value_b : buffer.values()) {
if (*value_a != *value_b &&
- analysis_->dataflow_analysis().MayInterfere(*value_a, *value_b,
- ordering)) {
+ ordering.MayInterfere(*value_a, *value_b)) {
VLOG(1) << *value_a << " interferes with " << *value_b
- << " in buffer: " << *buffer;
+ << " in buffer: " << buffer;
return true;
}
}
@@ -384,10 +383,7 @@ TEST_F(HloAliasAnalysisTest, SingleWhile) {
EXPECT_THAT(
GetValuesInBuffer(analysis.GetUniqueBufferAt(xla_while, /*index=*/{0})),
- UnorderedElementsAre(GetValueDefinedAt(xla_while, /*index=*/{0}),
- GetValueDefinedAt(body_param, /*index=*/{0}),
- GetValueDefinedAt(cond_param, /*index=*/{0}),
- GetValueDefinedAt(constant1)));
+ UnorderedElementsAre(GetValueDefinedAt(constant1)));
EXPECT_THAT(
GetValuesInBuffer(analysis.GetUniqueBufferAt(xla_while, /*index=*/{1})),
UnorderedElementsAre(GetValueDefinedAt(constant2),
@@ -631,9 +627,9 @@ TEST_F(HloAliasAnalysisTest, SwizzlingWhile) {
// HloBuffers.
EXPECT_THAT(
analysis.buffers(),
- UnorderedElementsAre(&analysis.GetUniqueBufferAt(constant1),
- &analysis.GetUniqueBufferAt(tuple, /*index=*/{}),
- &analysis.GetUniqueBufferAt(cond_constant)));
+ UnorderedElementsAre(analysis.GetUniqueBufferAt(constant1),
+ analysis.GetUniqueBufferAt(tuple, /*index=*/{}),
+ analysis.GetUniqueBufferAt(cond_constant)));
// The tuple elements of the while and the three constant inputs should all be
// smooshed into the same buffer.
@@ -820,127 +816,5 @@ TEST_F(HloAliasAnalysisTest, Bitcast) {
analysis.GetUniqueBufferAt(bitcast));
}
-TEST_F(HloAliasAnalysisTest, UpdateAnalysisForWhile) {
- // Test updating alias analysis after modifying a module with an array shaped
- // while:
- //
- // body(F32[] %param):
- // %negate = Negate(%param)
- //
- // condition(F32[] %param):
- // return Constant(false)
- //
- // entry:
- // %constant = Constant(1.0)
- // %exp = Exp(%constant)
- // return While(%exp, body, condition)
- //
- auto body_builder = HloComputation::Builder("body");
- auto body_param = body_builder.AddInstruction(
- HloInstruction::CreateParameter(0, scalar_shape_, "param"));
- auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary(
- scalar_shape_, HloOpcode::kNegate, body_param));
- HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());
-
- // Condition computation trivially returns a constant "false".
- auto cond_builder = HloComputation::Builder("condition");
- auto cond_param = cond_builder.AddInstruction(
- HloInstruction::CreateParameter(0, scalar_shape_, "param"));
- cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
- HloComputation* condition =
- module_->AddEmbeddedComputation(cond_builder.Build());
-
- auto builder = HloComputation::Builder(TestName());
- auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
- auto exp = builder.AddInstruction(
- HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kExp, constant));
- auto xla_while = builder.AddInstruction(
- HloInstruction::CreateWhile(scalar_shape_, condition, body, exp));
- module_->AddEntryComputation(builder.Build());
-
- HloAliasAnalysis& analysis = RunAnalysis();
-
- // Sanity check some alias information.
- EXPECT_EQ(analysis.GetUniqueBufferAt(exp),
- analysis.GetUniqueBufferAt(body_param));
- EXPECT_EQ(analysis.GetUniqueBufferAt(exp),
- analysis.GetUniqueBufferAt(cond_param));
- EXPECT_EQ(analysis.GetUniqueBufferAt(exp),
- analysis.GetUniqueBufferAt(negate));
- EXPECT_EQ(analysis.GetUniqueBufferAt(exp),
- analysis.GetUniqueBufferAt(xla_while));
-
- // Set the body root to the body_param. Previously it was Negate(body_param).
- body->set_root_instruction(body_param);
-
- // Prior to updating, verify that the analysis is no longer valid.
- Status verify_status = analysis.VerifyAgainstReference();
- EXPECT_FALSE(verify_status.ok());
-
- analysis.UpdateAfterChangingRoot(/*old_root=*/negate,
- /*new_root*/ body_param);
-
- // Analysis should be valid after the update.
- TF_ASSERT_OK(analysis.VerifyAgainstReference());
-
- // The exponential should now pass through the body transparently.
- EXPECT_EQ(analysis.GetUniqueBufferAt(exp),
- analysis.GetUniqueBufferAt(body_param));
- EXPECT_EQ(analysis.GetUniqueBufferAt(exp),
- analysis.GetUniqueBufferAt(cond_param));
- EXPECT_NE(analysis.GetUniqueBufferAt(exp),
- analysis.GetUniqueBufferAt(negate));
- EXPECT_EQ(analysis.GetUniqueBufferAt(exp),
- analysis.GetUniqueBufferAt(xla_while));
-
- // Now replace the operand of the while with %constant (was %exp).
- TF_ASSERT_OK(exp->ReplaceUseWith(xla_while, constant));
- analysis.UpdateAfterChangingOperand(xla_while, /*old_operand=*/exp,
- /*new_operand=*/constant);
-
- // Analysis should be valid after the update.
- TF_ASSERT_OK(analysis.VerifyAgainstReference());
-
- EXPECT_EQ(analysis.GetUniqueBufferAt(constant),
- analysis.GetUniqueBufferAt(body_param));
- EXPECT_EQ(analysis.GetUniqueBufferAt(constant),
- analysis.GetUniqueBufferAt(cond_param));
- EXPECT_EQ(analysis.GetUniqueBufferAt(constant),
- analysis.GetUniqueBufferAt(xla_while));
- EXPECT_NE(analysis.GetUniqueBufferAt(constant),
- analysis.GetUniqueBufferAt(exp));
- EXPECT_NE(analysis.GetUniqueBufferAt(constant),
- analysis.GetUniqueBufferAt(negate));
-
- // And finally make the negate the root of the body again.
- body->set_root_instruction(negate);
- analysis.UpdateAfterChangingRoot(/*old_root=*/body_param,
- /*new_root*/ negate);
-
- // Analysis should be valid after the update.
- TF_ASSERT_OK(analysis.VerifyAgainstReference());
-
- EXPECT_EQ(analysis.GetUniqueBufferAt(negate),
- analysis.GetUniqueBufferAt(body_param));
- EXPECT_EQ(analysis.GetUniqueBufferAt(negate),
- analysis.GetUniqueBufferAt(cond_param));
- EXPECT_EQ(analysis.GetUniqueBufferAt(negate),
- analysis.GetUniqueBufferAt(xla_while));
- EXPECT_EQ(analysis.GetUniqueBufferAt(constant),
- analysis.GetUniqueBufferAt(negate));
-
- auto value_of = [&analysis](const HloInstruction* instruction) {
- return &analysis.dataflow_analysis().GetValueDefinedAt(instruction);
- };
- EXPECT_THAT(analysis.GetUniqueBufferAt(negate).values(),
- UnorderedElementsAre(value_of(body_param), value_of(cond_param),
- value_of(negate), value_of(constant),
- value_of(xla_while)));
-}
-
-// Test update tuple element.
-
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_buffer.cc b/tensorflow/compiler/xla/service/hlo_buffer.cc
index 2bfdd9156a..e16413f361 100644
--- a/tensorflow/compiler/xla/service/hlo_buffer.cc
+++ b/tensorflow/compiler/xla/service/hlo_buffer.cc
@@ -36,22 +36,6 @@ namespace xla {
using ::tensorflow::str_util::Join;
using ::tensorflow::strings::StrCat;
-void HloBuffer::AddValue(const HloValue& value) {
- values_.push_back(&value);
- // Sort vector and remove duplicates.
- std::sort(values_.begin(), values_.end(), HloValue::IdLessThan);
- values_.erase(std::unique(values_.begin(), values_.end(), HloValue::IdEqual),
- values_.end());
-}
-
-void HloBuffer::RemoveValue(const HloValue& value) {
- // The values are sorted, so finding the value could be done in log(n) time
- // with a binary search.
- auto it = std::find(values_.begin(), values_.end(), &value);
- CHECK(it != values_.end());
- values_.erase(it);
-}
-
bool HloBuffer::operator==(const HloBuffer& other) const {
bool equal = id() == other.id();
if (equal) {
diff --git a/tensorflow/compiler/xla/service/hlo_buffer.h b/tensorflow/compiler/xla/service/hlo_buffer.h
index cb961e1601..4873463b2e 100644
--- a/tensorflow/compiler/xla/service/hlo_buffer.h
+++ b/tensorflow/compiler/xla/service/hlo_buffer.h
@@ -84,22 +84,15 @@ class HloBuffer {
return a->id() == b->id();
}
- HloBuffer(Id id) : id_(id) {}
+ HloBuffer(Id id, tensorflow::gtl::ArraySlice<const HloValue*> values)
+ : id_(id), values_(values.begin(), values.end()) {}
// Return the unique identifier for this HloBuffer.
Id id() const { return id_; }
- // Add a value to the set of values held by this buffer. Also adds the
- // HloPositions of the value to the positions vector of the buffer. If the
- // buffer already contains this value, then this method is a nop.
- void AddValue(const HloValue& value);
- void RemoveValue(const HloValue& value);
-
// Return all values contained in this buffer.
const std::vector<const HloValue*>& values() const { return values_; }
- std::vector<HloPosition> ComputePositions() const;
-
// Return the unique HLO value in the buffer. CHECK fails if the buffer does
// not contain exactly one value.
const HloValue& GetUniqueValue() const {
@@ -107,6 +100,8 @@ class HloBuffer {
return *values_[0];
}
+ std::vector<HloPosition> ComputePositions() const;
+
string ToString() const;
bool operator==(const HloBuffer& other) const;
@@ -118,7 +113,7 @@ class HloBuffer {
// The set of values contained in this buffer. Vector contains no duplicates
// and is sorted stably by HloValue::Id.
- std::vector<const HloValue*> values_;
+ const std::vector<const HloValue*> values_;
};
std::ostream& operator<<(std::ostream& out, const HloBuffer& buffer);
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
index ea8b239e10..2be1645f1b 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
@@ -67,6 +67,22 @@ HloValue& HloDataflowAnalysis::GetValueDefinedAt(
return GetUniqueValueAt(instruction, index);
}
+HloValue* HloDataflowAnalysis::NewHloValue(HloInstruction* instruction,
+ const ShapeIndex& index,
+ bool is_phi) {
+ const int64 value_id = next_value_id_++;
+ auto emplaced = values_.emplace(
+ std::piecewise_construct, std::forward_as_tuple(value_id),
+ std::forward_as_tuple(value_id, instruction, index, is_phi));
+ CHECK(emplaced.second);
+
+ return &emplaced.first->second;
+}
+
+void HloDataflowAnalysis::DeleteHloValue(HloValue::Id value_id) {
+ values_.erase(value_id);
+}
+
string HloDataflowAnalysis::ToString() const {
string out = StrCat("HloDataflowAnalysis, module ", module_->name(), "\n");
StrAppend(&out, " Instruction value sets:\n");
@@ -99,20 +115,96 @@ string HloDataflowAnalysis::ToString() const {
}
}
StrAppend(&out, " HloValues:\n");
- for (const HloValue& value : values()) {
- StrAppend(&out, value.ToString(/*indent=*/4));
+ for (const HloValue* value : values()) {
+ StrAppend(&out, value->ToString(/*indent=*/4));
+ }
+ return out;
+}
+
+bool HloDataflowAnalysis::Phi(
+ HloInstruction* instruction,
+ tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs) {
+ CHECK(ssa_form_);
+
+ for (const InstructionValueSet* input : inputs) {
+ DCHECK(ShapeUtil::Compatible(instruction->shape(), input->shape()));
}
- StrAppend(&out, " Phi resolutions:\n");
- for (const HloValue& value : values()) {
- if (value.is_phi()) {
- const HloValue* resolved_value = ResolvePhi(value);
- StrAppend(&out, " ", value.ToShortString(), " => ",
- resolved_value == nullptr ? "UNKNOWN"
- : resolved_value->ToShortString(),
- "\n");
+
+ bool changed = false;
+ for (auto& pair : GetInstructionValueSet(instruction)) {
+ const ShapeIndex& index = pair.first;
+ HloValueSet& value_set = pair.second;
+
+ // Positions with phi values should never have more than one value in the
+ // value set.
+ CHECK_LE(value_set.values().size(), 1);
+ const HloValue* current_value =
+ value_set.values().size() == 1 ? value_set.values()[0] : nullptr;
+
+ // Construct a vector of unique value IDs of the inputs.
+ std::vector<HloValue::Id> input_value_ids;
+ for (const InstructionValueSet* input : inputs) {
+ for (const HloValue* value : input->element(index).values()) {
+ input_value_ids.push_back(value->id());
+ }
+ }
+ std::sort(input_value_ids.begin(), input_value_ids.end());
+ input_value_ids.erase(
+ std::unique(input_value_ids.begin(), input_value_ids.end()),
+ input_value_ids.end());
+
+ // Remove the existing phi value (if it exists). The phi can be its own
+ // input, for example, in while body parameters where the body passes
+ // through the parameter value.
+ bool current_value_defined_here =
+ (current_value != nullptr &&
+ current_value->defining_instruction() == instruction &&
+ current_value->defining_index() == index);
+ if (current_value_defined_here) {
+ CHECK(current_value->is_phi());
+ auto it = std::find(input_value_ids.begin(), input_value_ids.end(),
+ current_value->id());
+ if (it != input_value_ids.end()) {
+ input_value_ids.erase(it);
+ }
+ }
+
+ if (input_value_ids.empty()) {
+ // A value set which has at least one element should never have its value
+ // set reduced to zero elements. During dataflow, value sets can only go
+ // from empty to non-empty, not the reverse.
+ CHECK_EQ(value_set.values().size(), 0)
+ << "Instruction " << instruction->name() << " at index " << index
+ << " previously had non-empty value set. Value set: " << value_set;
+ } else if (input_value_ids.size() == 1) {
+ // Only a single value reaches this point. There should be no phi, and
+ // this value set should contain this single value.
+ const HloValue& new_value = GetValue(input_value_ids[0]);
+ if (current_value == nullptr) {
+ value_set.Clear();
+ value_set.AddValue(&new_value);
+ changed = true;
+ } else if (current_value != &new_value) {
+ if (current_value_defined_here) {
+ // Remove the existing phi.
+ DeleteHloValue(current_value->id());
+ }
+ value_set.Clear();
+ value_set.AddValue(&new_value);
+ changed = true;
+ }
+ } else {
+ // Multiple distinct values reach this point. A phi value is
+ // necessary.
+ CHECK_GT(input_value_ids.size(), 1);
+ if (current_value == nullptr || !current_value->is_phi()) {
+ value_set.Clear();
+ value_set.AddValue(NewHloValue(instruction, index, /*is_phi=*/true));
+ changed = true;
+ }
}
}
- return out;
+ return changed;
}
const HloValue& HloDataflowAnalysis::GetValue(HloValue::Id value_id) const {
@@ -142,129 +234,6 @@ HloValueSet& HloDataflowAnalysis::GetValueSet(const HloPosition& position) {
return GetValueSet(position.instruction, position.index);
}
-void HloDataflowAnalysis::UpdateAfterChangingOperand(
- HloInstruction* instruction, HloInstruction* old_operand,
- HloInstruction* new_operand) {
- CHECK(std::find(instruction->operands().begin(),
- instruction->operands().end(),
- new_operand) != instruction->operands().end());
- VLOG(1) << "UpdateAfterChangingOperand(" << instruction->name() << ", "
- << old_operand->name() << " => " << new_operand->name() << ")";
-
- std::vector<HloInstruction*> to_update = {instruction};
-
- // If the instruction calls any computations then add the parameters of called
- // computation to capture any changes to the dataflow into the subcomputation
- // introduced by the new operand.
- for (HloComputation* computation : instruction->called_computations()) {
- to_update.insert(to_update.end(),
- computation->parameter_instructions().begin(),
- computation->parameter_instructions().end());
- }
-
- UpdateInstructionsAndPropagate(to_update);
-
- // The uses of the values in the old and new operand may have changed. Uses of
- // other HloValues are updated in UpdateInstructionsAndPropagate.
- for (auto& pair : GetInstructionValueSet(old_operand)) {
- for (const HloValue* value : pair.second.values()) {
- GetValue(value->id()).RecomputeUses();
- }
- }
- for (auto& pair : GetInstructionValueSet(new_operand)) {
- for (const HloValue* value : pair.second.values()) {
- GetValue(value->id()).RecomputeUses();
- }
- }
-
- TF_DCHECK_OK(VerifyAgainstReference());
-}
-
-void HloDataflowAnalysis::UpdateAfterChangingRoot(HloInstruction* old_root,
- HloInstruction* new_root) {
- VLOG(1) << "UpdateAfterChangingRoot(" << old_root->name() << " => "
- << new_root->name() << ")";
-
- CHECK_EQ(new_root, new_root->parent()->root_instruction());
- CHECK_EQ(new_root->parent(), old_root->parent());
-
- std::vector<HloInstruction*> to_update = {old_root, new_root};
-
- const CallGraphNode& call_graph_node =
- call_graph_->GetNode(new_root->parent());
- for (const CallSite& callsite : call_graph_node.caller_callsites()) {
- if (callsite.instruction()->opcode() == HloOpcode::kCall) {
- to_update.push_back(callsite.instruction());
- } else if (callsite.instruction()->opcode() == HloOpcode::kWhile) {
- // Add the while itself, and the body and condition parameters.
- to_update.push_back(callsite.instruction());
- to_update.push_back(
- callsite.instruction()->while_body()->parameter_instruction(0));
- to_update.push_back(
- callsite.instruction()->while_condition()->parameter_instruction(0));
- }
- }
-
- UpdateInstructionsAndPropagate(to_update);
-
- TF_DCHECK_OK(VerifyAgainstReference());
-}
-
-const HloValue* HloDataflowAnalysis::ResolvePhi(const HloValue& phi) const {
- CHECK(phi.is_phi());
-
- tensorflow::gtl::FlatSet<const HloValue*> visited;
- std::queue<const HloValue*> worklist;
- auto add_to_worklist = [&worklist, &visited](const HloValue* v) {
- if (visited.insert(v).second) {
- // 'v' was not previously in visited.
- worklist.push(v);
- }
- };
- add_to_worklist(&phi);
-
- const HloValue* resolved_value = nullptr;
- while (!worklist.empty()) {
- const HloValue* value = worklist.front();
- worklist.pop();
-
- if (!value->is_phi()) {
- if (resolved_value == nullptr) {
- resolved_value = value;
- } else if (resolved_value != value) {
- return nullptr;
- }
- } else {
- for (const HloValue* input : phi_inputs_.at(value)) {
- add_to_worklist(input);
- }
- }
- }
- return resolved_value;
-}
-
-void HloDataflowAnalysis::UpdatePhiInputs(
- const HloInstruction* instruction,
- tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs) {
- CHECK(ssa_form_);
- for (auto& pair : GetInstructionValueSet(instruction)) {
- const ShapeIndex& index = pair.first;
- const HloValue& phi_value = GetUniqueValueAt(instruction, index);
- auto& phi_inputs = phi_inputs_.at(&phi_value);
- phi_inputs.clear();
- for (const InstructionValueSet* input : inputs) {
- for (const HloValue* value : input->element(index).values()) {
- // The number of phi inputs is typically 2, and virtually always very
- // small.
- if (std::find(phi_inputs.begin(), phi_inputs.end(), value) ==
- phi_inputs.end()) {
- phi_inputs.push_back(value);
- }
- }
- }
- }
-}
-
bool HloDataflowAnalysis::UpdateBitcastValueSet(HloInstruction* bitcast) {
CHECK_EQ(bitcast->opcode(), HloOpcode::kBitcast);
const InstructionValueSet& operand_set =
@@ -380,8 +349,7 @@ bool HloDataflowAnalysis::UpdateParameterValueSet(HloInstruction* parameter) {
}
if (ssa_form_ && called_from_while) {
- UpdatePhiInputs(parameter, inputs);
- return false;
+ return Phi(parameter, inputs);
} else {
return GetInstructionValueSet(parameter).AssignUnionOf(inputs);
}
@@ -439,8 +407,7 @@ bool HloDataflowAnalysis::UpdateWhileValueSet(HloInstruction* xla_while) {
&GetInstructionValueSet(xla_while->while_body()->root_instruction()),
&GetInstructionValueSet(xla_while->operand(0))};
if (ssa_form_) {
- UpdatePhiInputs(xla_while, inputs);
- return false;
+ return Phi(xla_while, inputs);
} else {
return GetInstructionValueSet(xla_while).AssignUnionOf(inputs);
}
@@ -487,38 +454,7 @@ void HloDataflowAnalysis::UpdateInstructionsAndPropagate(
VLOG(3) << "Worklist top: " << instruction->name();
VLOG(3) << ToString();
- // The updating of the instruction value set below in
- // UpdateInstructionValueSet does not update HloValue::positions(). To
- // perform the positions() update remove all positions in 'instruction' from
- // the HloValues in 'instruction's value set prior to the update, then after
- // the update add the new positions back in. There is likely a more
- // efficient way of doing this.
- for (auto& pair : GetInstructionValueSet(instruction)) {
- const ShapeIndex& index = pair.first;
- HloValueSet& value_set = pair.second;
- for (const HloValue* value : value_set.values()) {
- if (value->defining_instruction() != instruction) {
- // Use GetValue for a non-const HloValue reference.
- GetValue(value->id()).RemovePosition(instruction, index);
- }
- }
- }
-
- bool changed = UpdateInstructionValueSet(instruction);
-
- // Add the positions back in.
- for (auto& pair : GetInstructionValueSet(instruction)) {
- const ShapeIndex& index = pair.first;
- HloValueSet& value_set = pair.second;
- for (const HloValue* value : value_set.values()) {
- if (value->defining_instruction() != instruction) {
- // Use GetValue for a non-const HloValue reference.
- GetValue(value->id()).AddPosition(instruction, index);
- }
- }
- }
-
- if (!changed) {
+ if (!UpdateInstructionValueSet(instruction)) {
// No change to the instruction's value set.
VLOG(4) << "No change.";
continue;
@@ -531,12 +467,16 @@ void HloDataflowAnalysis::UpdateInstructionsAndPropagate(
for (HloInstruction* user : instruction->users()) {
worklist.push(user);
- // If user calls a computation, then the respective parameter(s) of the
- // computation need to be updated.
+ // If user sequentially calls a computation, then the respective
+ // parameter(s) of the computation need to be updated.
for (HloComputation* called_computation : user->called_computations()) {
- for (int64 operand_number : user->OperandIndices(instruction)) {
- worklist.push(
- called_computation->parameter_instruction(operand_number));
+ const CallGraphNode& call_graph_node =
+ call_graph_->GetNode(called_computation);
+ if (call_graph_node.context() == CallContext::kSequential) {
+ for (int64 operand_number : user->OperandIndices(instruction)) {
+ worklist.push(
+ called_computation->parameter_instruction(operand_number));
+ }
}
}
}
@@ -574,25 +514,10 @@ InstructionValueSet& HloDataflowAnalysis::GetInstructionValueSet(
}
Status HloDataflowAnalysis::InitializeInstructionValueSets() {
- // Gather the values to create before creating them. This is done because we
- // want to allocate the vector of values only once so references to elements
- // are stable.
- struct ValueToCreate {
- HloInstruction* instruction;
- ShapeIndex index;
- bool is_phi;
- };
- std::vector<ValueToCreate> values_to_create;
-
for (const std::unique_ptr<HloComputation>& computation :
module_->computations()) {
const CallGraphNode& call_graph_node =
call_graph_->GetNode(computation.get());
- bool called_from_while = std::any_of(
- call_graph_node.caller_callsites().begin(),
- call_graph_node.caller_callsites().end(), [](const CallSite& cs) {
- return cs.instruction()->opcode() == HloOpcode::kWhile;
- });
for (const std::unique_ptr<HloInstruction>& instruction :
computation->instructions()) {
@@ -603,20 +528,22 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
// Lambda to set the value set to define all values in the output of the
// instruction.
- auto define_all_values = [this, &instruction,
- &values_to_create](bool is_phi = false) {
+ auto define_all_values = [this, &instruction](bool is_phi = false) {
for (auto& pair : GetInstructionValueSet(instruction.get())) {
const ShapeIndex& index = pair.first;
- values_to_create.push_back({instruction.get(), index, is_phi});
+ HloValue* value =
+ NewHloValue(instruction.get(), index, /*is_phi=*/false);
+ GetValueSet(instruction.get(), index).AddValue(value);
}
};
// Lambda to set the value set to define only the top-level buffer in the
// output of the instruction. Any other values flow from the operands of
// the instruction (or from cross-computation dataflow).
- auto define_top_level_only = [this, &instruction, &values_to_create]() {
- values_to_create.push_back(
- {instruction.get(), /*index=*/{}, /*is_phi=*/false});
+ auto define_top_level_only = [this, &instruction]() {
+ HloValue* value =
+ NewHloValue(instruction.get(), /*index=*/{}, /*is_phi=*/false);
+ GetValueSet(instruction.get(), /*index=*/{}).AddValue(value);
};
switch (instruction->opcode()) {
@@ -626,10 +553,6 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
}
break;
case HloOpcode::kWhile:
- if (ssa_form_) {
- define_all_values(/*is_phi=*/true);
- }
- break;
case HloOpcode::kCall:
case HloOpcode::kGetTupleElement:
// These instructions define no values. The values in their output
@@ -654,10 +577,6 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
// values in their output. Otherwise the values of the parameter
// come from the caller (eg, operands to the kCall instruction).
define_all_values();
- } else if (call_graph_node.context() == CallContext::kSequential &&
- called_from_while && ssa_form_) {
- // Parameters of while bodies and conditions are phis.
- define_all_values(/*is_phi=*/true);
}
break;
case HloOpcode::kCopy:
@@ -674,164 +593,9 @@ Status HloDataflowAnalysis::InitializeInstructionValueSets() {
}
}
- // Reserve the vector ahead of time so references to elements are stable.
- values_.reserve(values_to_create.size());
- for (int64 i = 0; i < values_to_create.size(); ++i) {
- const ValueToCreate& to_create = values_to_create[i];
- values_.emplace_back(/*id=*/i, to_create.instruction, to_create.index,
- to_create.is_phi);
- const HloValue& value = values_.back();
- GetValueSet(to_create.instruction, to_create.index).AddValue(&value);
- if (value.is_phi()) {
- phi_inputs_[&value] = {};
- }
- }
return Status::OK();
}
-bool HloDataflowAnalysis::IsDefinedBefore(const HloValue& a, const HloValue& b,
- const HloOrdering& ordering) const {
- // If 'b' is an entry param then 'a' cannot be defined before 'b' because 'b'
- // is live into the module.
- if (b.defining_instruction()->parent() == module_->entry_computation() &&
- b.defining_instruction()->opcode() == HloOpcode::kParameter) {
- return false;
- }
-
- // Phi values require special handling. Because XLA does not have a phi
- // instruction, the definition instruction of the phis values are
- // placeholders: either the subcomputation parameter (body or condition) or
- // the while instruction. However, the program point where these values are
- // logically defined does not necessarily coincide exactly with program point
- // of these place-holder instructions. So we explicitly define the following
- // order for phi values:
- //
- // body/condition parameter phi:
- // Defined before all values defined in its computation excepting other
- // phis.
- //
- // while phi:
- // defined after all values defined in the condition or body.
- //
- auto is_body_or_condition_phi = [](const HloValue& v) {
- return v.is_phi() &&
- v.defining_instruction()->opcode() == HloOpcode::kParameter;
- };
- if (is_body_or_condition_phi(a) && !is_body_or_condition_phi(b) &&
- call_graph_->InstructionIsNestedIn(b.defining_instruction(),
- a.defining_instruction()->parent())) {
- return true;
- }
- if (is_body_or_condition_phi(b) &&
- call_graph_->InstructionIsNestedIn(a.defining_instruction(),
- b.defining_instruction()->parent())) {
- return false;
- }
-
- // If 'b' is a while phi and 'a' is in the body or condition, then 'a'
- // executes before 'b'.
- if (b.is_phi() && b.defining_instruction()->opcode() == HloOpcode::kWhile &&
- (call_graph_->InstructionIsNestedIn(
- a.defining_instruction(), b.defining_instruction()->while_body()) ||
- call_graph_->InstructionIsNestedIn(
- a.defining_instruction(),
- b.defining_instruction()->while_condition()))) {
- return true;
- }
-
- return ordering.ExecutesBefore(a.defining_instruction(),
- b.defining_instruction());
-}
-
-bool HloDataflowAnalysis::UseIsBeforeValueDefinition(
- const HloUse& use, const HloValue& value,
- const HloOrdering& ordering) const {
- if (ordering.ExecutesBefore(use.instruction, value.defining_instruction())) {
- return true;
- }
-
- // If the use is at the instruction where the value is defined, then the use
- // is before the def if the instruction allows buffer sharing (in place
- // computation).
- if (use.instruction == value.defining_instruction() &&
- CanShareOperandBufferWithUser(
- use.instruction->mutable_operand(use.operand_number),
- use.operand_index, value.defining_instruction(),
- value.defining_index())) {
- return true;
- }
-
- // The use at a while is an input to a phi, and logically occurs before values
- // are defined in the body or condition computations.
- if (use.instruction->opcode() == HloOpcode::kWhile) {
- const HloInstruction* xla_while = use.instruction;
- if (call_graph_->InstructionIsNestedIn(value.defining_instruction(),
- xla_while->while_body()) ||
- call_graph_->InstructionIsNestedIn(value.defining_instruction(),
- xla_while->while_condition())) {
- return true;
- }
- }
-
- // Similarly if the value is defined at a while, it logically occurs after any
- // uses in the body or condition computations.
- if (value.defining_instruction()->opcode() == HloOpcode::kWhile) {
- CHECK(ssa_form_);
- const HloInstruction* xla_while = value.defining_instruction();
- if (call_graph_->InstructionIsNestedIn(use.instruction,
- xla_while->while_body()) ||
- call_graph_->InstructionIsNestedIn(use.instruction,
- xla_while->while_condition())) {
- return true;
- }
- }
- return false;
-}
-
-bool HloDataflowAnalysis::LiveRangeStrictlyBefore(
- const HloValue& a, const HloValue& b, const HloOrdering& ordering) const {
- VLOG(4) << "LiveRangeStrictlyBefore(a = " << a.ToShortString()
- << ", b = " << b.ToShortString() << ")";
- if (!IsDefinedBefore(a, b, ordering)) {
- VLOG(4) << "a not defined before b";
- return false;
- }
-
- // Live-out values from the module can never have ranges strictly before any
- // other value.
- if (a.live_out_of_module()) {
- VLOG(4) << "a is live out of module";
- return false;
- }
-
- // Live-out values of computations can never have ranges strictly before any
- // other value in the computation (including values nested in
- // subcomputations).
- if (a.live_out_of_computation() &&
- call_graph_->InstructionIsNestedIn(b.defining_instruction(),
- a.defining_instruction()->parent())) {
- VLOG(4) << "a is live out of computation containing b";
- return false;
- }
-
- // All uses of 'a' must be before 'b' is defined.
- for (const HloUse& use : a.uses()) {
- if (!UseIsBeforeValueDefinition(use, b, ordering)) {
- VLOG(4) << "use of a (" << use << ") not before b is defined";
- return false;
- }
- }
-
- return true;
-}
-
-bool HloDataflowAnalysis::MayInterfere(const HloValue& a, const HloValue& b,
- const HloOrdering& ordering) const {
- // Buffers without disjoint liveness may interfere.
- return !LiveRangeStrictlyBefore(a, b, ordering) &&
- !LiveRangeStrictlyBefore(b, a, ordering);
-}
-
/* static */
StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
HloModule* module, bool ssa_form, bool bitcast_defines_value) {
@@ -855,6 +619,33 @@ StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
}
dataflow_analysis->UpdateInstructionsAndPropagate(all_instructions);
+ // Add in positions to all values.
+ for (const std::unique_ptr<HloComputation>& computation :
+ module->computations()) {
+ for (const std::unique_ptr<HloInstruction>& instruction :
+ computation->instructions()) {
+ for (const auto& pair :
+ dataflow_analysis->GetInstructionValueSet(instruction.get())) {
+ const ShapeIndex& index = pair.first;
+ const HloValueSet& value_set = pair.second;
+ for (const HloValue* value : value_set.values()) {
+ if (value->defining_instruction() != instruction.get()) {
+ dataflow_analysis->GetValue(value->id())
+ .AddPosition(instruction.get(), index);
+ }
+ }
+ }
+ }
+ }
+
+ // Construct vector of values.
+ dataflow_analysis->values_vector_.reserve(dataflow_analysis->values_.size());
+ for (auto& pair : dataflow_analysis->values_) {
+ dataflow_analysis->values_vector_.push_back(&pair.second);
+ }
+ std::sort(dataflow_analysis->values_vector_.begin(),
+ dataflow_analysis->values_vector_.end(), HloValue::IdLessThan);
+
TF_DCHECK_OK(dataflow_analysis->Verify());
XLA_VLOG_LINES(1, dataflow_analysis->ToString());
@@ -865,14 +656,14 @@ StatusOr<std::unique_ptr<HloDataflowAnalysis>> HloDataflowAnalysis::Run(
Status HloDataflowAnalysis::Verify() const {
// Verify each HloValue appears in the value sets that the value's positions()
// indicate.
- for (const HloValue& value : values()) {
- for (const HloPosition& position : value.positions()) {
+ for (const HloValue* value : values()) {
+ for (const HloPosition& position : value->positions()) {
const HloValueSet& value_set = GetValueSet(position);
TF_RET_CHECK(std::find(value_set.values().begin(),
value_set.values().end(),
- &value) != value_set.values().end())
+ value) != value_set.values().end())
<< "Value set at position " << position << " does not contain value "
- << value.ToShortString();
+ << value->ToShortString();
}
}
@@ -898,75 +689,4 @@ Status HloDataflowAnalysis::Verify() const {
return Status::OK();
}
-Status HloDataflowAnalysis::VerifyAgainstReference() const {
- TF_RETURN_IF_ERROR(Verify());
-
- TF_ASSIGN_OR_RETURN(std::unique_ptr<HloDataflowAnalysis> reference,
- Run(module_, ssa_form_, bitcast_defines_value_));
- TF_RETURN_IF_ERROR(reference->Verify());
-
- VLOG(2) << "This analysis:";
- XLA_VLOG_LINES(2, ToString());
- VLOG(2) << "Reference:";
- XLA_VLOG_LINES(2, reference->ToString());
-
- // Verify value sets in each position are identical.
- for (const auto& computation : module_->computations()) {
- for (const auto& instruction : computation->instructions()) {
- for (const auto& pair : GetInstructionValueSet(instruction.get())) {
- const ShapeIndex& index = pair.first;
- const HloValueSet& value_set = pair.second;
- const HloValueSet& reference_value_set =
- reference->GetValueSet(instruction.get(), index);
-
- auto value_in_set = [](const HloValue& v, const HloValueSet& vset) {
- return std::find_if(vset.values().begin(), vset.values().end(),
- [&v](const HloValue* w) { return *w == v; }) !=
- vset.values().end();
- };
-
- for (const HloValue* value : value_set.values()) {
- TF_RET_CHECK(value_in_set(*value, reference_value_set))
- << "Value " << value->ToShortString()
- << " does not exist in reference";
- }
- for (const HloValue* reference_value : reference_value_set.values()) {
- TF_RET_CHECK(value_in_set(*reference_value, value_set))
- << "Value " << reference_value->ToShortString()
- << " only exists in reference";
- }
- }
- }
- }
-
- // Verify all phis resolve identically and uses are identical.
- for (const HloValue& value : values()) {
- const HloValue& reference_value = reference->GetValueDefinedAt(
- value.defining_instruction(), value.defining_index());
- TF_RET_CHECK(value.is_phi() == reference_value.is_phi());
- if (value.is_phi()) {
- const HloValue* resolved_value = ResolvePhi(value);
- const HloValue* reference_resolved_value =
- reference->ResolvePhi(reference_value);
- if (resolved_value == nullptr) {
- TF_RET_CHECK(reference_resolved_value == nullptr);
- } else {
- TF_RET_CHECK(reference_resolved_value != nullptr);
- TF_RET_CHECK(*reference_resolved_value == *resolved_value);
- }
- }
-
- for (const HloUse& use : value.uses()) {
- TF_RET_CHECK(std::find(reference_value.uses().begin(),
- reference_value.uses().end(),
- use) != reference_value.uses().end());
- }
- for (const HloUse& reference_use : reference_value.uses()) {
- TF_RET_CHECK(std::find(value.uses().begin(), value.uses().end(),
- reference_use) != value.uses().end());
- }
- }
- return Status::OK();
-}
-
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
index 7781cc58a3..aae257dd09 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.h
@@ -88,10 +88,10 @@ class HloDataflowAnalysis {
// given position.
const HloValueSet& GetValueSet(const HloInstruction* instruction,
const ShapeIndex& index = {}) const;
- HloValueSet& GetValueSet(const HloInstruction* instruction,
- const ShapeIndex& index = {});
const HloValueSet& GetValueSet(const HloPosition& position) const;
HloValueSet& GetValueSet(const HloPosition& position);
+ HloValueSet& GetValueSet(const HloInstruction* instruction,
+ const ShapeIndex& index = {});
// Return the unique value in the HloValueSet at the given instruction and
// shape index. CHECKs if the value set does not contain exactly one value.
@@ -108,49 +108,11 @@ class HloDataflowAnalysis {
const HloValue& GetValue(HloValue::Id value_id) const;
HloValue& GetValue(HloValue::Id value_id);
- // Returns whether the given values interfere assuming the given HLO
- // ordering. Two values interfere if they may both be simultaneously live.
- bool MayInterfere(const HloValue& a, const HloValue& b,
- const HloOrdering& ordering) const;
-
- // Overload which takes HloValue:Ids.
- bool MayInterfere(HloValue::Id a, HloValue::Id b,
- const HloOrdering& ordering) const {
- return MayInterfere(GetValue(a), GetValue(b), ordering);
- }
-
// Return the total number of HloValues.
int64 value_count() const { return values_.size(); }
- // Return a vector of all HloValues.
- const std::vector<HloValue>& values() const { return values_; }
-
- // Updates the dataflow after the changing an operand of
- // 'instruction'. Dataflow update is not possible if instructions have been
- // added or removed from the graph.
- void UpdateAfterChangingOperand(HloInstruction* instruction,
- HloInstruction* old_operand,
- HloInstruction* new_operand);
-
- // Updates the dataflow after the changing the root of a computation from
- // 'old_root' to 'new_root'.
- void UpdateAfterChangingRoot(HloInstruction* old_root,
- HloInstruction* new_root);
-
- // Returns the non-phi HloValue that is the unique (transitive) input to the
- // given phi. If no such HloValue exists (there are multiple inputs to the
- // phi) then nullptr is returned. This is computed by all walking the inputs
- // of the given phi value until non-phi HloValue(s) are encountered.
- const HloValue* ResolvePhi(const HloValue& phi) const;
- const HloValue* ResolvePhi(const HloInstruction* instruction,
- const ShapeIndex& index = {}) const {
- return ResolvePhi(GetValueDefinedAt(instruction, index));
- }
-
- // Compare the dataflow analysis against a clean recomputation of the
- // analysis. Returns an error status if there is a mismatch. Useful for
- // verifying the correctness after updates to the analysis.
- Status VerifyAgainstReference() const;
+ // Return a vector of all HloValues stably sorted by HloValue::Id.
+ const std::vector<const HloValue*>& values() const { return values_vector_; }
// Return the call graph used for computing the dataflow.
const CallGraph& call_graph() const { return *call_graph_; }
@@ -161,6 +123,13 @@ class HloDataflowAnalysis {
HloDataflowAnalysis(HloModule* module, bool ssa_form,
bool bitcast_defines_value = false);
+ // Returns a new HloValue defined at the given instruction and shape index.
+ HloValue* NewHloValue(HloInstruction* instruction, const ShapeIndex& index,
+ bool is_phi = false);
+
+ // Delete the HloValue with the given ID.
+ void DeleteHloValue(HloValue::Id value_id);
+
// Constructs and initializes the InstructionValueSets of all instructions to
// contain exactly the HloValues defined by each instruction. These values can
// then be propagated throughout the HLO graph by calling
@@ -187,10 +156,11 @@ class HloDataflowAnalysis {
void UpdateInstructionsAndPropagate(
tensorflow::gtl::ArraySlice<HloInstruction*> instructions);
- // Sets the inputs of the given phi to given value(s).
- void UpdatePhiInputs(
- const HloInstruction* instruction,
- tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs);
+ // Return the result of the SSA Phi function applied to the given inputs at
+ // the given instruction. NOTE(review): the 'skip_top_level' flag (top level of
+ // the value set left unmodified) is not a parameter here; confirm intent.
+ bool Phi(HloInstruction* instruction,
+ tensorflow::gtl::ArraySlice<const InstructionValueSet*> inputs);
// Updates the positions of the HloValues in the output of the given
// instruction. This should be called after the instruction value set of
@@ -203,20 +173,6 @@ class HloDataflowAnalysis {
HloInstruction* instruction, const InstructionValueSet& new_value_set,
const InstructionValueSet* prev_value_set = nullptr);
- // Returns true if the live range of the given value 'a' is strictly before
- // the live range of value 'b' using the given HLO ordering.
- bool LiveRangeStrictlyBefore(const HloValue& a, const HloValue& b,
- const HloOrdering& ordering) const;
-
- // Returns whether the value 'a' is defined before the value 'b' under the
- // given ordering.
- bool IsDefinedBefore(const HloValue& a, const HloValue& b,
- const HloOrdering& ordering) const;
-
- // Returns whether the given use is before the given value definition.
- bool UseIsBeforeValueDefinition(const HloUse& use, const HloValue& value,
- const HloOrdering& ordering) const;
-
// Verify various invariants of the dataflow analysis.
Status Verify() const;
@@ -226,19 +182,19 @@ class HloDataflowAnalysis {
std::unique_ptr<CallGraph> call_graph_;
- // Array of all values in the module. This is allocated once at analysis
- // construction time so HloValue references are stable. Updates to the
- // analysis via UpdateAfterChangingOperand and UpdateAfterChangingRoot do not
- // result in the creation or destruction of any HloValues.
- std::vector<HloValue> values_;
-
- // Map hold the inputs to each phi value in the module. Used by ResolvePhi.
- tensorflow::gtl::FlatMap<const HloValue*,
- tensorflow::gtl::InlinedVector<const HloValue*, 2>>
- phi_inputs_;
+ // The map of all HloValues in the module. We pass around pointers to the
+ // mapped HloValues, so the underlying container must keep them valid despite
+ // mutations touching other map entries.
+ std::unordered_map<HloValue::Id, HloValue> values_;
// A map from instruction to InstructionValueSet.
std::unordered_map<const HloInstruction*, InstructionValueSet> value_sets_;
+
+ // A vector containing all HloValues sorted by HloValue::Id.
+ std::vector<const HloValue*> values_vector_;
+
+ // The Id to use for the next HloValue.
+ HloValue::Id next_value_id_ = 0;
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
index 9f3dd539ef..ef0fa1d745 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis_test.cc
@@ -26,7 +26,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test.h"
@@ -44,8 +43,8 @@ class HloDataflowAnalysisTest : public HloTestBase,
// Run dataflow analysis on the member module. For convenience returns a
// reference to the generated analysis stored in analysis_.
- HloDataflowAnalysis& RunAnalysis(bool ssa_form,
- bool bitcast_defines_value = false) {
+ const HloDataflowAnalysis& RunAnalysis(bool ssa_form,
+ bool bitcast_defines_value = false) {
analysis_ =
HloDataflowAnalysis::Run(module_.get(), ssa_form, bitcast_defines_value)
.ConsumeValueOrDie();
@@ -71,8 +70,8 @@ class HloDataflowAnalysisTest : public HloTestBase,
const HloInstruction* b) {
EXPECT_FALSE(ShapeUtil::IsTuple(a->shape()));
EXPECT_FALSE(ShapeUtil::IsTuple(b->shape()));
- return analysis_->MayInterfere(analysis_->GetValueDefinedAt(a),
- analysis_->GetValueDefinedAt(b), ordering);
+ return ordering.MayInterfere(analysis_->GetValueDefinedAt(a),
+ analysis_->GetValueDefinedAt(b));
}
std::unique_ptr<HloModule> module_;
@@ -499,37 +498,26 @@ TEST_P(HloDataflowAnalysisTest, SingleWhile) {
EXPECT_FALSE(analysis.GetValueDefinedAt(cond_constant).live_out_of_module());
if (ssa_form) {
- // While instruction should define phi values. The value at index {0} is a
- // degenerate phi with a single input 'constant1'.
- EXPECT_TRUE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{0}));
- EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{0}).is_phi());
- EXPECT_EQ(analysis.ResolvePhi(xla_while, /*index=*/{0}),
- &analysis.GetValueDefinedAt(constant1));
- EXPECT_TRUE(analysis.ValueIsDefinedAt(body_param, /*index=*/{0}));
- EXPECT_TRUE(analysis.GetValueDefinedAt(body_param, /*index=*/{0}).is_phi());
- EXPECT_EQ(analysis.ResolvePhi(body_param, /*index=*/{0}),
- &analysis.GetValueDefinedAt(constant1));
- EXPECT_TRUE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{0}));
- EXPECT_TRUE(analysis.GetValueDefinedAt(cond_param, /*index=*/{0}).is_phi());
- EXPECT_EQ(analysis.ResolvePhi(cond_param, /*index=*/{0}),
- &analysis.GetValueDefinedAt(constant1));
+ // Element 0 of the tuple passed through the body so no phi value is
+ // defined.
+ EXPECT_FALSE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{0}));
+ EXPECT_FALSE(analysis.ValueIsDefinedAt(body_param, /*index=*/{0}));
+ EXPECT_FALSE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{0}));
+ // Element 1 of the tuple should be a phi value.
EXPECT_TRUE(analysis.ValueIsDefinedAt(xla_while, /*index=*/{1}));
EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1}).is_phi());
- EXPECT_EQ(analysis.ResolvePhi(xla_while, /*index=*/{1}), nullptr);
EXPECT_TRUE(analysis.ValueIsDefinedAt(body_param, /*index=*/{1}));
EXPECT_TRUE(analysis.GetValueDefinedAt(body_param, /*index=*/{1}).is_phi());
- EXPECT_EQ(analysis.ResolvePhi(body_param, /*index=*/{1}), nullptr);
EXPECT_TRUE(analysis.ValueIsDefinedAt(cond_param, /*index=*/{1}));
EXPECT_TRUE(analysis.GetValueDefinedAt(cond_param, /*index=*/{1}).is_phi());
- EXPECT_EQ(analysis.ResolvePhi(cond_param, /*index=*/{1}), nullptr);
- EXPECT_THAT(analysis.GetValueDefinedAt(constant1).uses(),
- UnorderedElementsAre(HloUse{xla_while, 0, {0}}));
+ EXPECT_THAT(
+ analysis.GetValueDefinedAt(constant1).uses(),
+ UnorderedElementsAre(HloUse{add, 0, {}}, HloUse{xla_while, 0, {0}}));
- EXPECT_FALSE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
- EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{0})
- .live_out_of_module());
+ // Constant1 passes through the body and out of the module.
+ EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while, /*index=*/{1})
.live_out_of_module());
@@ -613,20 +601,15 @@ TEST_P(HloDataflowAnalysisTest, SequentialWhiles) {
bool ssa_form = GetParam();
const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
- if (ssa_form) {
- EXPECT_TRUE(analysis.GetValueDefinedAt(xla_while2).live_out_of_module());
- EXPECT_FALSE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
- } else {
- // Element 0 is passed through all the while instructions and out of the
- // module.
- EXPECT_EQ(analysis.GetUniqueValueAt(xla_while0, /*index=*/{0}),
- analysis.GetValueDefinedAt(constant1));
- EXPECT_EQ(analysis.GetUniqueValueAt(xla_while1, /*index=*/{0}),
- analysis.GetValueDefinedAt(constant1));
- EXPECT_EQ(analysis.GetUniqueValueAt(xla_while2, /*index=*/{0}),
- analysis.GetValueDefinedAt(constant1));
- EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
- }
+ // Element 0 is passed through all the while instructions and out of the
+ // module.
+ EXPECT_EQ(analysis.GetUniqueValueAt(xla_while0, /*index=*/{0}),
+ analysis.GetValueDefinedAt(constant1));
+ EXPECT_EQ(analysis.GetUniqueValueAt(xla_while1, /*index=*/{0}),
+ analysis.GetValueDefinedAt(constant1));
+ EXPECT_EQ(analysis.GetUniqueValueAt(xla_while2, /*index=*/{0}),
+ analysis.GetValueDefinedAt(constant1));
+ EXPECT_TRUE(analysis.GetValueDefinedAt(constant1).live_out_of_module());
}
TEST_P(HloDataflowAnalysisTest, NestedWhiles) {
@@ -705,13 +688,18 @@ TEST_P(HloDataflowAnalysisTest, NestedWhiles) {
bool ssa_form = GetParam();
const HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
+ EXPECT_THAT(HloValuesAt(inner_param, /*index=*/{0}),
+ UnorderedElementsAre(analysis.GetValueDefinedAt(negate)));
if (ssa_form) {
EXPECT_TRUE(analysis.ValueIsDefinedAt(inner_param, /*index=*/{1}));
EXPECT_TRUE(
analysis.GetValueDefinedAt(inner_param, /*index=*/{1}).is_phi());
- EXPECT_TRUE(analysis.ValueIsDefinedAt(nested_while, /*index=*/{0}));
- EXPECT_TRUE(
- analysis.GetValueDefinedAt(inner_param, /*index=*/{1}).is_phi());
+
+ // Element 0 of the nested while is %negate.
+ EXPECT_FALSE(analysis.ValueIsDefinedAt(nested_while, /*index=*/{0}));
+ EXPECT_THAT(HloValuesAt(inner_param, /*index=*/{0}),
+ UnorderedElementsAre(analysis.GetValueDefinedAt(negate)));
+ // Element 1 is a phi value (join of %add and %constant2).
EXPECT_TRUE(analysis.ValueIsDefinedAt(nested_while, /*index=*/{1}));
EXPECT_TRUE(
analysis.GetValueDefinedAt(nested_while, /*index=*/{1}).is_phi());
@@ -724,8 +712,6 @@ TEST_P(HloDataflowAnalysisTest, NestedWhiles) {
EXPECT_TRUE(
analysis.GetValueDefinedAt(entry_while, /*index=*/{1}).is_phi());
} else {
- EXPECT_THAT(HloValuesAt(inner_param, /*index=*/{0}),
- UnorderedElementsAre(analysis.GetValueDefinedAt(negate)));
EXPECT_THAT(HloValuesAt(inner_param, /*index=*/{1}),
UnorderedElementsAre(analysis.GetValueDefinedAt(add),
analysis.GetValueDefinedAt(constant2)));
@@ -1496,256 +1482,6 @@ TEST_P(HloDataflowAnalysisTest, EmbeddedComputationInterference) {
EXPECT_TRUE(InstructionsMayInterfere(ordering, negate, embedded_log));
}
-TEST_P(HloDataflowAnalysisTest, UpdateAnalysisForWhile) {
- // Test updating dataflow after modifying a module with an array shaped while:
- //
- // body(F32[] %param):
- // %negate = Negate(%param)
- //
- // condition(F32[] %param):
- // return Constant(false)
- //
- // entry:
- // %constant = Constant(1.0)
- // %exp = Exp(%constant)
- // return While(%exp, body, condition)
- //
- auto body_builder = HloComputation::Builder("body");
- auto body_param = body_builder.AddInstruction(
- HloInstruction::CreateParameter(0, scalar_shape_, "param"));
- auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary(
- scalar_shape_, HloOpcode::kNegate, body_param));
- HloComputation* body = module_->AddEmbeddedComputation(body_builder.Build());
-
- // Condition computation trivially returns a constant "false".
- auto cond_builder = HloComputation::Builder("condition");
- auto cond_param = cond_builder.AddInstruction(
- HloInstruction::CreateParameter(0, scalar_shape_, "param"));
- cond_builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
- HloComputation* condition =
- module_->AddEmbeddedComputation(cond_builder.Build());
-
- auto builder = HloComputation::Builder(TestName());
- auto constant = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
- auto exp = builder.AddInstruction(
- HloInstruction::CreateUnary(scalar_shape_, HloOpcode::kExp, constant));
- auto xla_while = builder.AddInstruction(
- HloInstruction::CreateWhile(scalar_shape_, condition, body, exp));
- module_->AddEntryComputation(builder.Build());
-
- bool ssa_form = GetParam();
- HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
-
- // Sanity check the initial dataflow analysis before transforming the HLO
- // graph.
- if (ssa_form) {
- EXPECT_TRUE(analysis.ValueIsDefinedAt(body_param));
- EXPECT_TRUE(analysis.GetValueDefinedAt(body_param).is_phi());
- EXPECT_EQ(analysis.ResolvePhi(body_param), nullptr);
-
- EXPECT_TRUE(analysis.ValueIsDefinedAt(cond_param));
- EXPECT_TRUE(analysis.GetValueDefinedAt(cond_param).is_phi());
- EXPECT_EQ(analysis.ResolvePhi(cond_param), nullptr);
-
- EXPECT_FALSE(analysis.GetValueDefinedAt(exp).live_out_of_module());
- EXPECT_FALSE(analysis.GetValueDefinedAt(negate).live_out_of_module());
- } else {
- EXPECT_THAT(HloValuesAt(body_param),
- UnorderedElementsAre(analysis.GetValueDefinedAt(exp),
- analysis.GetValueDefinedAt(negate)));
- EXPECT_THAT(HloValuesAt(cond_param),
- UnorderedElementsAre(analysis.GetValueDefinedAt(exp),
- analysis.GetValueDefinedAt(negate)));
- EXPECT_THAT(HloValuesAt(xla_while),
- UnorderedElementsAre(analysis.GetValueDefinedAt(exp),
- analysis.GetValueDefinedAt(negate)));
-
- EXPECT_TRUE(analysis.GetValueDefinedAt(negate).live_out_of_module());
- EXPECT_TRUE(analysis.GetValueDefinedAt(exp).live_out_of_module());
- }
-
- // Set the body root to the body_param. Previously it was Negate(body_param).
- body->set_root_instruction(body_param);
-
- // Prior to updating, verify that the dataflow analysis is no longer valid.
- Status verify_status = analysis.VerifyAgainstReference();
- EXPECT_FALSE(verify_status.ok());
-
- analysis.UpdateAfterChangingRoot(/*old_root=*/negate,
- /*new_root=*/body_param);
-
- // Analysis should be valid after the update.
- TF_EXPECT_OK(analysis.VerifyAgainstReference());
-
- if (ssa_form) {
- // The phis should now be resolvable as 'exp' is passed through the body
- // transparently.
- EXPECT_EQ(analysis.ResolvePhi(body_param),
- &analysis.GetValueDefinedAt(exp));
- EXPECT_EQ(analysis.ResolvePhi(cond_param),
- &analysis.GetValueDefinedAt(exp));
- EXPECT_EQ(analysis.ResolvePhi(xla_while), &analysis.GetValueDefinedAt(exp));
- EXPECT_FALSE(analysis.GetValueDefinedAt(exp).live_out_of_module());
- } else {
- EXPECT_THAT(HloValuesAt(body_param),
- UnorderedElementsAre(analysis.GetValueDefinedAt(exp)));
- EXPECT_THAT(HloValuesAt(cond_param),
- UnorderedElementsAre(analysis.GetValueDefinedAt(exp)));
- EXPECT_THAT(HloValuesAt(xla_while),
- UnorderedElementsAre(analysis.GetValueDefinedAt(exp)));
- EXPECT_TRUE(analysis.GetValueDefinedAt(exp).live_out_of_module());
- }
- EXPECT_FALSE(analysis.GetValueDefinedAt(negate).live_out_of_module());
-
- // Now replace the operand of the while with %constant (was %exp).
- TF_ASSERT_OK(exp->ReplaceUseWith(xla_while, constant));
- analysis.UpdateAfterChangingOperand(xla_while, /*old_operand=*/exp,
- /*new_operand=*/constant);
-
- // Verify that the dataflow is correct.
- TF_ASSERT_OK(analysis.VerifyAgainstReference());
-
- if (ssa_form) {
- // The phis now resolve to 'constant'.
- EXPECT_EQ(analysis.ResolvePhi(body_param),
- &analysis.GetValueDefinedAt(constant));
- EXPECT_EQ(analysis.ResolvePhi(cond_param),
- &analysis.GetValueDefinedAt(constant));
- EXPECT_EQ(analysis.ResolvePhi(xla_while),
- &analysis.GetValueDefinedAt(constant));
- } else {
- EXPECT_THAT(HloValuesAt(body_param),
- UnorderedElementsAre(analysis.GetValueDefinedAt(constant)));
- EXPECT_THAT(HloValuesAt(cond_param),
- UnorderedElementsAre(analysis.GetValueDefinedAt(constant)));
- EXPECT_THAT(HloValuesAt(xla_while),
- UnorderedElementsAre(analysis.GetValueDefinedAt(constant)));
- EXPECT_TRUE(analysis.GetValueDefinedAt(constant).live_out_of_module());
- }
-
- // And finally make the negate the root of the body again.
- body->set_root_instruction(negate);
- analysis.UpdateAfterChangingRoot(/*old_root=*/body_param,
- /*new_root=*/negate);
-
- // Verify that the dataflow is correct.
- TF_ASSERT_OK(analysis.VerifyAgainstReference());
-
- if (ssa_form) {
- // Phis should no longer be resolvable.
- EXPECT_EQ(analysis.ResolvePhi(body_param), nullptr);
- EXPECT_EQ(analysis.ResolvePhi(cond_param), nullptr);
- EXPECT_EQ(analysis.ResolvePhi(xla_while), nullptr);
- } else {
- EXPECT_THAT(HloValuesAt(body_param),
- UnorderedElementsAre(analysis.GetValueDefinedAt(constant),
- analysis.GetValueDefinedAt(negate)));
- EXPECT_THAT(HloValuesAt(cond_param),
- UnorderedElementsAre(analysis.GetValueDefinedAt(constant),
- analysis.GetValueDefinedAt(negate)));
- EXPECT_THAT(HloValuesAt(xla_while),
- UnorderedElementsAre(analysis.GetValueDefinedAt(constant),
- analysis.GetValueDefinedAt(negate)));
-
- EXPECT_FALSE(analysis.GetValueDefinedAt(exp).live_out_of_module());
- EXPECT_TRUE(analysis.GetValueDefinedAt(negate).live_out_of_module());
- EXPECT_TRUE(analysis.GetValueDefinedAt(constant).live_out_of_module());
- }
-
- // After the updates, verify that the dataflow is correct.
- TF_ASSERT_OK(analysis.VerifyAgainstReference());
-}
-
-TEST_P(HloDataflowAnalysisTest, UpdateOfATupleSelect) {
- // Test changing the operands of kSelects of a tuple value and updating the
- // dataflow.
- auto builder = HloComputation::Builder(TestName());
- auto pred = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<bool>(false)));
- auto a = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
- auto b = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(2.0)));
- auto c = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(3.0)));
- auto d = builder.AddInstruction(
- HloInstruction::CreateConstant(Literal::CreateR0<float>(4.0)));
- auto tuple_a = builder.AddInstruction(HloInstruction::CreateTuple({a}));
- auto tuple_b = builder.AddInstruction(HloInstruction::CreateTuple({b}));
- auto tuple_c = builder.AddInstruction(HloInstruction::CreateTuple({c}));
- auto tuple_d = builder.AddInstruction(HloInstruction::CreateTuple({d}));
- const Shape tuple_shape = tuple_a->shape();
- auto select_aa = builder.AddInstruction(HloInstruction::CreateTernary(
- tuple_shape, HloOpcode::kSelect, pred, tuple_a, tuple_a));
- auto select_ab = builder.AddInstruction(HloInstruction::CreateTernary(
- tuple_shape, HloOpcode::kSelect, pred, tuple_a, tuple_b));
- auto select_cd = builder.AddInstruction(HloInstruction::CreateTernary(
- tuple_shape, HloOpcode::kSelect, pred, tuple_c, tuple_d));
- auto select_abcd = builder.AddInstruction(HloInstruction::CreateTernary(
- tuple_shape, HloOpcode::kSelect, pred, select_ab, select_cd));
-
- module_->AddEntryComputation(builder.Build());
-
- bool ssa_form = GetParam();
- HloDataflowAnalysis& analysis = RunAnalysis(ssa_form);
-
- // Sanity check dataflow before changing the graph and updating.
- EXPECT_THAT(HloValuesAt(select_aa, /*index=*/{0}),
- UnorderedElementsAre(analysis.GetValueDefinedAt(a)));
- EXPECT_THAT(HloValuesAt(select_ab, /*index=*/{0}),
- UnorderedElementsAre(analysis.GetValueDefinedAt(a),
- analysis.GetValueDefinedAt(b)));
- EXPECT_THAT(HloValuesAt(select_cd, /*index=*/{0}),
- UnorderedElementsAre(analysis.GetValueDefinedAt(c),
- analysis.GetValueDefinedAt(d)));
- EXPECT_THAT(HloValuesAt(select_abcd, /*index=*/{0}),
- UnorderedElementsAre(analysis.GetValueDefinedAt(a),
- analysis.GetValueDefinedAt(b),
- analysis.GetValueDefinedAt(c),
- analysis.GetValueDefinedAt(d)));
- EXPECT_TRUE(analysis.GetValueDefinedAt(a).live_out_of_module());
- EXPECT_TRUE(analysis.GetValueDefinedAt(b).live_out_of_module());
- EXPECT_TRUE(analysis.GetValueDefinedAt(c).live_out_of_module());
- EXPECT_TRUE(analysis.GetValueDefinedAt(d).live_out_of_module());
-
- // Set the rhs of 'select_aa' to be 'd'.
- TF_ASSERT_OK(select_aa->ReplaceOperandWith(2, tuple_d));
- analysis.UpdateAfterChangingOperand(select_aa, /*old_operand=*/tuple_a,
- /*new_operand=*/tuple_d);
-
- // Verify that the dataflow is correct.
- TF_ASSERT_OK(analysis.VerifyAgainstReference());
-
- EXPECT_THAT(HloValuesAt(select_aa, /*index=*/{0}),
- UnorderedElementsAre(analysis.GetValueDefinedAt(a),
- analysis.GetValueDefinedAt(d)));
-
- // Set the lhs of 'select_cd' to be 'a'.
- TF_ASSERT_OK(select_cd->ReplaceOperandWith(1, tuple_a));
- analysis.UpdateAfterChangingOperand(select_cd, /*old_operand=*/tuple_c,
- /*new_operand=*/tuple_a);
-
- // Verify that the dataflow is correct.
- TF_ASSERT_OK(analysis.VerifyAgainstReference());
-
- EXPECT_THAT(HloValuesAt(select_cd, /*index=*/{0}),
- UnorderedElementsAre(analysis.GetValueDefinedAt(a),
- analysis.GetValueDefinedAt(d)));
- EXPECT_THAT(HloValuesAt(select_abcd, /*index=*/{0}),
- UnorderedElementsAre(analysis.GetValueDefinedAt(a),
- analysis.GetValueDefinedAt(b),
- analysis.GetValueDefinedAt(d)));
- EXPECT_TRUE(analysis.GetValueDefinedAt(a).live_out_of_module());
- EXPECT_TRUE(analysis.GetValueDefinedAt(b).live_out_of_module());
- EXPECT_FALSE(analysis.GetValueDefinedAt(c).live_out_of_module());
- EXPECT_TRUE(analysis.GetValueDefinedAt(d).live_out_of_module());
-
- // After the updates, verify that the dataflow is correct.
- TF_ASSERT_OK(analysis.VerifyAgainstReference());
-}
-
INSTANTIATE_TEST_CASE_P(HloDataflowAnalysisInstantiation,
HloDataflowAnalysisTest,
::testing::Values(false, true));
diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc
index ad6070a9c1..c95e44bd5d 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <string>
#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_scheduling.h"
@@ -218,6 +219,94 @@ TEST_F(HloOrderingTest, InstructionsInWhileComputations) {
EXPECT_FALSE(ordering.ExecutesBefore(body_param, cond_param));
}
+TEST_F(HloOrderingTest, ValuesInWhileComputations) {
+ // Tests the ordering of values (defined by dataflow analysis) in the body and
+ // condition of a while instruction. HLO code:
+ //
+ // body(F32[] %param):
+ // %negate = Negate(%param)
+ //
+ // condition(F32[] %param):
+ // %convert = Convert<PRED>(%param)
+ //
+ // entry:
+ // %constant = Constant(1.0)
+ // %while = While(%constant, body, condition)
+ // %add = Add(%constant, %while)
+ //
+ auto module = CreateNewModule();
+ const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
+
+ auto body_builder = HloComputation::Builder("body");
+ auto body_param = body_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape, "body_param"));
+ auto negate = body_builder.AddInstruction(HloInstruction::CreateUnary(
+ scalar_shape, HloOpcode::kNegate, body_param));
+ HloComputation* body = module->AddEmbeddedComputation(body_builder.Build());
+
+ auto cond_builder = HloComputation::Builder("condition");
+ auto cond_param = cond_builder.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape, "cond_param"));
+ auto convert = cond_builder.AddInstruction(HloInstruction::CreateConvert(
+ ShapeUtil::MakeShape(xla::PRED, {}), cond_param));
+ HloComputation* condition =
+ module->AddEmbeddedComputation(cond_builder.Build());
+
+ auto builder = HloComputation::Builder(TestName());
+ auto constant = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
+ auto xla_while = builder.AddInstruction(
+ HloInstruction::CreateWhile(scalar_shape, condition, body, constant));
+ auto add = builder.AddInstruction(HloInstruction::CreateBinary(
+ scalar_shape, HloOpcode::kAdd, constant, xla_while));
+ module->AddEntryComputation(builder.Build());
+
+ TF_ASSERT_OK_AND_ASSIGN(
+ auto dataflow, HloDataflowAnalysis::Run(module.get(), /*ssa_form=*/true));
+ DependencyHloOrdering ordering(module.get());
+
+ // Init value is defined before the while, but live range is not before the
+ // while because of the use of the init value in the add.
+ EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(constant),
+ dataflow->GetValueDefinedAt(xla_while)));
+ EXPECT_FALSE(
+ ordering.LiveRangeStrictlyBefore(dataflow->GetValueDefinedAt(constant),
+ dataflow->GetValueDefinedAt(xla_while)));
+ EXPECT_TRUE(ordering.MayInterfere(dataflow->GetValueDefinedAt(constant),
+ dataflow->GetValueDefinedAt(xla_while)));
+
+ // Any value defined in the body or condition is defined before the while, and
+ // has a live range strictly before the while.
+ EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(negate),
+ dataflow->GetValueDefinedAt(xla_while)));
+ EXPECT_TRUE(
+ ordering.LiveRangeStrictlyBefore(dataflow->GetValueDefinedAt(negate),
+ dataflow->GetValueDefinedAt(xla_while)));
+ EXPECT_FALSE(ordering.MayInterfere(dataflow->GetValueDefinedAt(negate),
+ dataflow->GetValueDefinedAt(xla_while)));
+
+ EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(convert),
+ dataflow->GetValueDefinedAt(xla_while)));
+ EXPECT_TRUE(
+ ordering.LiveRangeStrictlyBefore(dataflow->GetValueDefinedAt(convert),
+ dataflow->GetValueDefinedAt(xla_while)));
+ EXPECT_FALSE(ordering.MayInterfere(dataflow->GetValueDefinedAt(convert),
+ dataflow->GetValueDefinedAt(xla_while)));
+
+ // The live range of the while should be before the add.
+ EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(xla_while),
+ dataflow->GetValueDefinedAt(add)));
+ ASSERT_EQ(dataflow->GetValueDefinedAt(xla_while).uses().size(), 1);
+
+ const HloUse& while_use = dataflow->GetValueDefinedAt(xla_while).uses()[0];
+ EXPECT_EQ(while_use.instruction, add);
+ EXPECT_TRUE(ordering.UseIsBeforeValueDefinition(
+ while_use, dataflow->GetValueDefinedAt(add)));
+ EXPECT_TRUE(
+ ordering.LiveRangeStrictlyBefore(dataflow->GetValueDefinedAt(xla_while),
+ dataflow->GetValueDefinedAt(add)));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_value.cc b/tensorflow/compiler/xla/service/hlo_value.cc
index f85d8ec50d..e6cf0d37b8 100644
--- a/tensorflow/compiler/xla/service/hlo_value.cc
+++ b/tensorflow/compiler/xla/service/hlo_value.cc
@@ -159,12 +159,6 @@ void HloValue::AddPosition(HloInstruction* instruction,
for (const HloPosition& position : positions_) {
DCHECK_NE(position, new_position);
}
- // The shape of the new position must match existing positions.
- if (!positions_.empty()) {
- CHECK(
- ShapeUtil::Compatible(positions_.front().shape(), new_position.shape()))
- << "front: " << positions_.front() << " new: " << new_position;
- }
positions_.push_back(std::move(new_position));
diff --git a/tensorflow/compiler/xla/service/hlo_value.h b/tensorflow/compiler/xla/service/hlo_value.h
index 63ecc25020..6872bc76a8 100644
--- a/tensorflow/compiler/xla/service/hlo_value.h
+++ b/tensorflow/compiler/xla/service/hlo_value.h
@@ -225,6 +225,9 @@ class HloValueSet {
// already exist in the set.
bool AddValue(const HloValue* value);
+ // Clear all values from the set.
+ void Clear() { values_.clear(); }
+
// Return the unique HLO value in the set. CHECKs if the set does not contain
// exactly one value.
const HloValue& GetUniqueValue() const {