author    | 2017-11-03 13:26:47 -0700
committer | 2017-11-03 13:31:12 -0700
commit    | 456929281592f14d50443cfbdaa2f6b36167a134 (patch)
tree      | 8a1a18245a6ef9baed8bd9a9f35b7c250ab64901 /tensorflow/compiler/xla/service/copy_insertion.cc
parent    | 5b166f495ae79b6e8144bbd3a1109f4b8d9fb1aa (diff)
Roll back the copy insertion change because it results in a DCHECK failure with an internal model.
END_PUBLIC
BEGIN_PUBLIC
Automated g4 rollback of changelist 174423881
PiperOrigin-RevId: 174505237
Diffstat (limited to 'tensorflow/compiler/xla/service/copy_insertion.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/copy_insertion.cc | 1526
1 file changed, 504 insertions, 1022 deletions
diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index 8f50b29dad..0453a698a0 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -15,17 +15,15 @@ limitations under the License. #include "tensorflow/compiler/xla/service/copy_insertion.h" -#include "tensorflow/compiler/xla/service/hlo_alias_analysis.h" +#include <memory> + #include "tensorflow/compiler/xla/service/hlo_computation.h" -#include "tensorflow/compiler/xla/service/hlo_dce.h" -#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" -#include "tensorflow/compiler/xla/service/hlo_ordering.h" #include "tensorflow/compiler/xla/service/liveness_util.h" #include "tensorflow/compiler/xla/service/logical_buffer.h" -#include "tensorflow/compiler/xla/service/tuple_simplifier.h" +#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/types.h" @@ -33,1113 +31,597 @@ limitations under the License. #include "tensorflow/core/lib/gtl/flatmap.h" #include "tensorflow/core/lib/gtl/flatset.h" #include "tensorflow/core/lib/strings/str_util.h" -#include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/platform/logging.h" namespace xla { -using ::tensorflow::str_util::Join; -using ::tensorflow::strings::StrAppend; -using ::tensorflow::strings::StrCat; - namespace { -bool IsEntryParameterValue(const HloValue& value) { - const HloComputation* computation = value.defining_instruction()->parent(); - return value.defining_instruction()->opcode() == HloOpcode::kParameter && - computation == computation->parent()->entry_computation(); -} - -bool IsConstantValue(const HloValue& value) { - return value.defining_instruction()->opcode() == HloOpcode::kConstant; -} - -bool ValueIsReadOnly(const HloValue& value) { - return IsConstantValue(value) || IsEntryParameterValue(value); -} +using tensorflow::gtl::FlatMap; +using tensorflow::gtl::FlatSet; -// Deep copy the given instructions 'from' and 'to' at the ShapeIndexes given in -// 'indices_to_copy'. Add control edges from the respective kCopy instructions -// in deep copy of 'from' to the respective kCopy instruction in the deep copy -// of 'to'. -// -// Requirements: 'from' and 'to' must have compatible shapes. +// InstructionCopier encapsulates indices at which to copy 'instruction'. +// All 'instruction' users in 'copy_users' are updated to use the copy. // -// For example, suppose 'from' and 'to' are two-element tuples where index 0 is -// the only index to copy. Prior to deep-copying we have: +// Instruction copies are generated in two phases: +// 1) Recording buffer indices at which 'instruction' requires copies (i.e. +// setting 'indices_to_copy_[index]'=true). +// 2) Inserting kCopy instructions based on indices recorded in phase 1). +// *) Array instructions are copied by inserting a single kCopy instruction. +// *) Tuple-shaped instructions are copied by recursively expanding tuples +// (and tuple-shaped elements), and inserting kCopy instructions for any +// tuple elements which require a copy. As the recursion unwinds, new tuple +// instructions are added to gather the copied (and uncopied) references +// into the output tuple (i.e. 
the copy of the tuple-shaped instruction). // +// Example two-element tuple with one element that needs a copy: // -// 'from' -// | -// ... -// | -// 'to' +// original-instruction +// / \ +// GTE(0) GTE(1) +// | | +// Copy | +// \ / +// Tuple // copied-instruction // -// DeepCopyAndAddControlEdges produces: -// -// 'from' -// / \ -// GTE GTE -// | | -// Copy | -// / \ / -// | Tuple -// | | -// ctrl ... -// edge | -// | | -// | 'to' -// | / \ -// | GTE GTE -// \ | | -// Copy | -// \ / -// Tuple -// -StatusOr<std::pair<HloInstruction*, HloInstruction*>> -DeepCopyAndAddControlEdges(HloInstruction* from, HloInstruction* to, - const ShapeTree<bool>& indices_to_copy) { - DCHECK(ShapeUtil::Compatible(from->shape(), to->shape())); - // to/from_copy_tree hold the kCopy instruction produces by the deep - // copies. Elements which are not copied (indices_to_copy.element(index) == - // false) have nullptr at that index. - ShapeTree<HloInstruction*> from_copy_tree(from->shape(), - /*init_value=*/nullptr); - TF_ASSIGN_OR_RETURN(HloInstruction * from_deep_copy, - from->parent()->DeepCopyInstruction( - from, &indices_to_copy, &from_copy_tree)); - - ShapeTree<HloInstruction*> to_copy_tree(to->shape(), /*init_value=*/nullptr); - TF_ASSIGN_OR_RETURN( - HloInstruction * to_deep_copy, - to->parent()->DeepCopyInstruction(to, &indices_to_copy, &to_copy_tree)); - - // Add control edges between the respective kCopy instructions. - for (const auto& pair : from_copy_tree) { - const ShapeIndex& index = pair.first; - HloInstruction* from_copy = pair.second; - HloInstruction* to_copy = to_copy_tree.element(index); - if (from_copy == nullptr) { - TF_RET_CHECK(to_copy == nullptr); - continue; - } - TF_RET_CHECK(to_copy != nullptr); - TF_RETURN_IF_ERROR(from_copy->AddControlDependencyTo(to_copy)); +// As an optimization, if the original instruction is itself a Tuple +// instruction, we elide the unnecessary extra GTE and Tuple instructions, +// and just insert the copy into a new Tuple instruction, with control +// dependencies to ensure the copy occurs after any possible interference. +class InstructionCopier { + public: + InstructionCopier(HloInstruction* instruction, + const std::vector<HloInstruction*>& copy_users) + : instruction_(instruction), + copy_users_(copy_users), + indices_to_copy_(instruction->shape()), + control_predecessors_(instruction->shape()) {} + + // Sets indices that are read-only, and thus do not need to be copied. + void SetReadOnlyIndices(const ShapeTree<bool>& read_only_indices) { + read_only_indices_ = read_only_indices; } - return std::make_pair(from_deep_copy, to_deep_copy); -} - -// Compute the indices of the loop state which need copies in order to avoid -// live range interference. Generally, an element in the loop state does not -// need to be copied if the element is passed through transparently through the -// body. -// -// Returns whether any indices need to be copied. -bool IndicesToCopyForWhile(const HloDataflowAnalysis& dataflow, - const HloInstruction* xla_while, - ShapeTree<bool>* indices_to_copy) { - DCHECK(ShapeUtil::Compatible(indices_to_copy->shape(), xla_while->shape())); - - bool any_copies = false; - const HloInstruction* init = xla_while->operand(0); - for (auto& pair : *indices_to_copy) { - const ShapeIndex& index = pair.first; - bool& should_copy = pair.second; - // If there is any ambiguity, then loop state must be copied. 
- if (dataflow.GetValueSet(init, index).values().size() > 1 || - dataflow.GetValueSet(xla_while, index).values().size() > 1) { - should_copy = true; - } else { - // If the output of the while instruction is not the same as the init - // value of the while, then this element is not passed through the body - // transparently and must be copied. - should_copy = dataflow.GetUniqueValueAt(xla_while, index) != - dataflow.GetUniqueValueAt(init, index); - } - any_copies |= should_copy; + // Sets copy overrides, which are copy instructions to use at each index. This + // is used to share a single copy of read-only entry parameters and constants + // between multiple While loops. + void SetCopyOverrides(const ShapeTree<HloInstruction*>& copy_overrides) { + copy_overrides_ = copy_overrides; } - return any_copies; -} - -// Add kCopy instructions around the given kWhile instruction to eliminate any -// possible live range interference of HLO values assuming a dependency-based -// ordering (HloDependencyOrdering). Copies are added conservatively. There -// likely are copies which are not strictly necessary, but there are removed -// later in the pass via CopyRemover. -// -// -// Elements (each ShapeIndex) in the loop state are considered independently. A -// copy is added to each element of the loop state which is modified in the -// while body. For each such element, a total of three kCopy instructions are -// added at following locations: -// -// (1) The init value is copied before the kWhile instruction. Before: -// -// (Init) -// | -// kWhile -// | -// ... -// -// After: -// -// (Init) -// | -// kCopy -// | -// kWhile -// | -// ... -// -// This copy is necessary in case the init value is simultaneously live -// with the kWhile. -// -// (2) Copies are added to the parameter and root of the while body -// computation. Before: -// -// kParameter -// | -// ... -// | -// (body root) -// -// After: -// -// kParameter -// | -// kCopy ----------+ -// | | -// ... ctrl -// | edge -// (body root) | -// | | -// kCopy <---------+ -// -// The root kCopy becomes the new root of the computation. Both copies are -// necessary to any potential interference between the parameter value and -// the root value. The control edge prevents potential interference -// between the copies themselves. -// -// If the loop state is a tuple then the above kCopy instructions are a deep -// copy constructed of kCopy, KGetTupleElement, and kTuple instruction as -// constructed by HloInstruction::DeepCopyInstruction. -Status AddCopiesForWhile(const HloAliasAnalysis& alias_analysis, - HloInstruction* xla_while) { - VLOG(2) << "Adding copies for kWhile instruction " << xla_while->name(); - TF_RET_CHECK(xla_while->opcode() == HloOpcode::kWhile); - ShapeTree<bool> indices_to_copy(xla_while->shape()); - if (!IndicesToCopyForWhile(alias_analysis.dataflow_analysis(), xla_while, - &indices_to_copy)) { - VLOG(2) << "No copies necessary for kWhile instruction " - << xla_while->name(); - return Status::OK(); - } + // Returns true if all recorded indices are false (returns true otherwise). + bool HasAllIndicesFalse() const; - VLOG(2) << "Adding copies for " << xla_while->name() << " at indices:"; - for (auto& pair : indices_to_copy) { - if (pair.second) { - VLOG(2) << " " << pair.first; - } - } + // Records instruction buffer indices which point-to a Parameter or Constant. + Status RecordIndicesWhichPointToParamOrConstant( + const TuplePointsToAnalysis& points_to_analysis); - // Deep copy init. 
- HloInstruction* while_init = xla_while->mutable_operand(0); - TF_ASSIGN_OR_RETURN( - HloInstruction * while_init_copy, - xla_while->parent()->DeepCopyInstruction(while_init, &indices_to_copy)); - TF_RETURN_IF_ERROR(while_init->ReplaceUseWith(xla_while, while_init_copy)); + // Records instruction buffer indices to copy which are necessary to ensure: + // *) PointsToSet of 'instruction_' is unambiguous and distinct. + // *) No liveness interference between 'instruction_' and 'other_instruction'. + // + // If 'read_only_indices_out' is non-null, read-only indices are set to true. + Status RecordIndicesToCopyForColocatingBuffers( + const BufferLiveness& liveness, const HloInstruction* other_instruction, + ShapeTree<bool>* read_only_indices_out); - // Deep copy the parameter and the root. Extend a control edge from the copy - // of the parameter value to the corresponding copy value of the root. - HloComputation* body = xla_while->while_body(); - HloInstruction* param = body->parameter_instruction(0); - HloInstruction* root = body->root_instruction(); + // Records control predecessors to add for inserted copy instructions. + // 'parameter' must have the same shape as the instruction that will be + // copied, and must define all buffers in the shape. Control predecessors are + // only recorded for indices that have already been marked for copying. + Status RecordControlPredecessors( + const TuplePointsToAnalysis& points_to_analysis, + HloInstruction* parameter); - // If param is the root then all indices should have been passed through the - // while body and we should have returned early above. - TF_RET_CHECK(param != root); + // Inserts copies of 'instruction' buffers at indices in 'indices_to_copy', + // and replaces all uses for instructions in 'copy_users_' with copy. + // Returns the instruction which is a copy 'instruction'. + HloInstruction* Copy(); - // Copy users before making a deep copy of the parameter as the deep copy - // will create new users of the parameter (eg, the GTE instructions of the - // deep copy). - std::vector<HloInstruction*> param_users = param->users(); + HloInstruction* instruction() { return instruction_; } - ShapeIndex current_index; - TF_ASSIGN_OR_RETURN(auto pair, - DeepCopyAndAddControlEdges(param, root, indices_to_copy)); + const std::vector<HloInstruction*>& copy_users() const { return copy_users_; } - HloInstruction* param_copy = pair.first; - HloInstruction* root_copy = pair.second; + private: + // Does the given index represent a read-only buffer? + bool IsReadOnlyIndex(const ShapeIndex& index) const { + return !ShapeUtil::IsNil(read_only_indices_.shape()) && + read_only_indices_.element(index); + } - for (HloInstruction* user : param_users) { - TF_RETURN_IF_ERROR(param->ReplaceUseWith(user, param_copy)); + // Returns the copy override at the given index, or nullptr. + HloInstruction* GetCopyOverride(const ShapeIndex& index) const { + return ShapeUtil::IsNil(copy_overrides_.shape()) + ? nullptr + : copy_overrides_.element(index); } - body->set_root_instruction(root_copy); + // Records instruction buffer indices which have ambiguous or non-distinct + // points-to sets. + Status RecordAmbiguousOrNonDistinctIndices( + const TuplePointsToAnalysis& points_to_analysis); - return Status::OK(); -} + // Records instruction buffer indices which have interfering live ranges + // with 'other_instruction' buffers at same index. 
+ Status RecordIndicesWhichInterfereWithOtherInstruction( + const BufferLiveness& liveness, const HloInstruction* other_instruction, + ShapeTree<bool>* read_only_indices_out); -// Removes any control dependencies to or from the given instruction. -Status StripControlDependenciesFrom(HloInstruction* instruction) { - while (!instruction->control_successors().empty()) { - TF_RETURN_IF_ERROR(instruction->RemoveControlDependencyTo( - instruction->control_successors().front())); - } + // Recursively inserts copies of 'instruction' tuple elements at indices + // specified in 'indices_to_copy', and returns the copy of 'instruction'. + HloInstruction* CopyTuple(HloInstruction* instruction, ShapeIndex* index); - while (!instruction->control_predecessors().empty()) { - TF_RETURN_IF_ERROR( - instruction->control_predecessors().front()->RemoveControlDependencyTo( - instruction)); + void RecordIndex(const ShapeIndex& index) { + *indices_to_copy_.mutable_element(index) = true; } - return Status::OK(); -} - -// Add kCopy instructions to the given module to guarantee there is no -// live-range interference. Generally interference can only occur around kWhile -// instructions which have update-in-place semantics. -Status AddCopiesToResolveInterference(HloModule* module) { - TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis, - HloAliasAnalysis::Run(module)); + HloInstruction* instruction_; + const std::vector<HloInstruction*> copy_users_; + ShapeTree<bool> indices_to_copy_; + ShapeTree<std::vector<HloInstruction*>> control_predecessors_; + ShapeTree<bool> read_only_indices_; + ShapeTree<HloInstruction*> copy_overrides_; +}; - for (HloComputation* computation : module->computations()) { - for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kWhile) { - TF_RETURN_IF_ERROR(AddCopiesForWhile(*alias_analysis, instruction)); - } - } - } - return Status::OK(); +bool InstructionCopier::HasAllIndicesFalse() const { + bool all_indices_false = true; + indices_to_copy_.ForEachElement( + [&all_indices_false](const ShapeIndex& /*index*/, bool data) { + if (data) { + all_indices_false = false; + } + }); + return all_indices_false; } -// Class for removing unnecessary copies from the module. -// -// kCopy instructions are added conservatively to guarantee no live range -// interference between HLO values. This class uses a more fine-grained analysis -// to remove some of these added copies which are not strictly necessary. -class CopyRemover { - public: - CopyRemover(const HloAliasAnalysis& alias_analysis, - const HloOrdering& ordering, HloModule* module) - : module_(module), - alias_analysis_(alias_analysis), - ordering_(ordering), - buffer_value_tracker_(*module, alias_analysis, ordering) {} - - // Try to elide the given copy. The copy is elided if the instruction is not - // necessary to prevent live-range interference of HLO values. Returns true if - // copy was elided. - // - // The copy instruction is not actually removed here. Instead it is left for - // dead in the graph. Later calls to DCE will remove the instruction. 
- StatusOr<bool> TryElideCopy(HloInstruction* copy) { - if (buffer_value_tracker_.TryElideCopy(copy)) { - TF_RETURN_IF_ERROR(StripControlDependenciesFrom(copy)); - TF_RETURN_IF_ERROR(copy->ReplaceAllUsesWith(copy->mutable_operand(0))); - return true; - } - return false; +Status InstructionCopier::RecordIndicesWhichPointToParamOrConstant( + const TuplePointsToAnalysis& points_to_analysis) { + const PointsToSet& points_to = + points_to_analysis.GetPointsToSet(instruction_); + // Shallow copy the instruction if the points-to set of the top-level + // buffer is ambiguous. This is necessary because the backends must know + // statically what the top-level buffer of the result is. + if (points_to.element(/*index=*/{}).size() > 1) { + RecordIndex({}); } - string ToString() const { - string out = StrCat("CopyRemover, module ", module_->name(), "\n"); - StrAppend(&out, " Buffer values, in dependency order:\n"); - for (const HloBuffer& buffer : alias_analysis_.buffers()) { - StrAppend(&out, " HloBuffer ", buffer.id(), ":\n"); + // Multiple buffers within a parameter/constant may be live out, so collect + // a set of indices at which to copy first. + points_to.ForEachElement([this](const ShapeIndex& index, + const PointsToSet::BufferList& buffers) { + if (IsReadOnlyIndex(index)) { + return; } - return out; - } - - private: - // Class which tracks the HLO values within each HLO buffer in the module - // during copy removal. - // - // The values are held in a linked list where there is one list for each - // buffer. Removing a copy instruction merges together the values in the - // source buffer of the copy to the destination buffer of the copy. This class - // tracks these value lists as copies are removed from the graph (and value - // lists are merged). - // - // The BufferValueTracker object is initialized to match the state of - // HloAliasAnalysis. However, as copies are removed this state diverges. The - // values-to-buffer mapping is maintained outside of HloAliasAnalysis because - // a fully updatable alias analysis is very slow. - class BufferValueTracker { - public: - // The values held in a single HLO buffer are represented using a linked - // list. An element type in this list is ValueNode. - // - // This linked list is hand-rolled to enable efficient splicing of lists - // using only references to list elements without knowing which lists are - // being spliced. std::list requires a reference to the list object to - // splice. - struct ValueNode { - explicit ValueNode(const HloValue* v) : value(v) {} - - const HloValue* value; - - // The uses are maintained outside of HloValue::uses() because - // HloValue::uses() is not updatable (a fully updatable dataflow analysis - // is slow). - std::vector<const HloUse*> uses; - - // next/prev elements in the linked list. The list is circularly linked so - // these values are never null for elements in the list. - ValueNode* prev = nullptr; - ValueNode* next = nullptr; - }; - - BufferValueTracker(const HloModule& module, - const HloAliasAnalysis& alias_analysis, - const HloOrdering& ordering) - : dataflow_(alias_analysis.dataflow_analysis()), ordering_(ordering) { - // Construct a list for each HLO buffer in the alias analysis. Maintain a - // map from HloValue to the respective list element representing that - // value. The map is used to construct the copy info map below. 
- tensorflow::gtl::FlatMap<const HloValue*, ValueNode*> value_to_node; - for (const HloBuffer& buffer : alias_analysis.buffers()) { - // Verify values contained in the buffer are strictly ordered. This - // should always be the case after adding copies to eliminate - // interference. Specifically, the addition of the control flow edges - // between copies added around aliased operations (kWhile) guarantees - // this strict order. - for (const HloValue* value_a : buffer.values()) { - for (const HloValue* value_b : buffer.values()) { - if (value_a != value_b) { - DCHECK(ordering_.LiveRangeStrictlyBefore(*value_a, *value_b, - dataflow_) || - ordering_.LiveRangeStrictlyBefore(*value_b, *value_a, - dataflow_)) - << value_a->ToShortString() << " and " - << value_b->ToShortString() << " are not ordered"; - } - } - } - - std::vector<const HloValue*> values = buffer.values(); - std::sort(values.begin(), values.end(), - [this](const HloValue* a, const HloValue* b) { - return ordering_.IsDefinedBefore(*a, *b); - }); - - // Create a list containing all of the values in the buffer. - AddValueList(values, &value_to_node); + for (const LogicalBuffer* buffer : buffers) { + // pointee is the HloInstruction producing the buffer which may be + // liveout. + HloInstruction* pointee = buffer->instruction(); + if (pointee->opcode() == HloOpcode::kParameter || + pointee->opcode() == HloOpcode::kConstant) { + VLOG(2) << "Parameter or constant buffer " << buffer->ToString() + << " index: " << tensorflow::str_util::Join(index, ",") + << " may be live out of computation: " << pointee->ToString(); + RecordIndex(index); + break; } - - // Create copy_map_ which contains the source and destination values - // of all copies. - CreateCopyMap(module, value_to_node); - - XLA_VLOG_LINES(3, ToString()); - TF_DCHECK_OK(Verify()); } + }); + return Status::OK(); +} - // Add a list containing the given values to BufferValueTracker. This - // represents the values contained in a single buffer. For each value in - // 'values' an entry is created in value_to_node which indicates the - // respective ValueNode representing that value. - void AddValueList( - tensorflow::gtl::ArraySlice<const HloValue*> values, - tensorflow::gtl::FlatMap<const HloValue*, ValueNode*>* value_to_node) { - ValueNode* tail = nullptr; - ValueNode* head = nullptr; - for (const HloValue* value : values) { - auto new_node = new ValueNode(value); - (*value_to_node)[value] = new_node; - - // Copy the HLO values's uses into the ValueNode for the value. These - // uses in ValueNode are updated as copies are removed. - new_node->uses.reserve(value->uses().size()); - for (const HloUse& use : value->uses()) { - new_node->uses.push_back(&use); - } - - // Connect the new node into the linked list. - if (tail == nullptr) { - head = new_node; - } else { - tail->next = new_node; - new_node->prev = tail; - } - tail = new_node; - } - - // The linked list is circular so connect the head and tail. 
- tail->next = head; - head->prev = tail; - value_lists_.insert(head); - } +Status InstructionCopier::RecordIndicesToCopyForColocatingBuffers( + const BufferLiveness& liveness, const HloInstruction* other_instruction, + ShapeTree<bool>* read_only_indices_out) { + TF_RETURN_IF_ERROR( + RecordAmbiguousOrNonDistinctIndices(liveness.points_to_analysis())); + TF_RETURN_IF_ERROR(RecordIndicesWhichInterfereWithOtherInstruction( + liveness, other_instruction, read_only_indices_out)); + return Status::OK(); +} - // This method also fills in copy_map_ which indicates which nodes - // in the value lists corresponding to the source and destination values of - // kCopy instructions. value_to_node should map each HloValue to its - // respective ValueNode. - void CreateCopyMap( - const HloModule& module, - const tensorflow::gtl::FlatMap<const HloValue*, ValueNode*>& - value_to_node) { - for (HloComputation* computation : module.computations()) { - for (HloInstruction* instruction : computation->instructions()) { - // Add copies with unambiguous source values to the map. Copies with - // ambiguous sources are not removable. - if (instruction->opcode() == HloOpcode::kCopy) { - const HloValueSet& src_value_set = - dataflow_.GetValueSet(instruction->operand(0)); - if (src_value_set.values().size() == 1) { - CopyNodes& copy_node = copy_map_[instruction]; - copy_node.dest = - value_to_node.at(&dataflow_.GetUniqueValueAt(instruction)); - copy_node.src = value_to_node.at(&src_value_set.GetUniqueValue()); - } +Status InstructionCopier::RecordAmbiguousOrNonDistinctIndices( + const TuplePointsToAnalysis& points_to_analysis) { + const PointsToSet& points_to = + points_to_analysis.GetPointsToSet(instruction_); + // Mapping from LogicalBuffer to index (used to detect non-distinct indices). + FlatMap<const LogicalBuffer*, std::vector<ShapeIndex>> + buffer_to_source_indices; + points_to.ForEachElement( + [this, &buffer_to_source_indices]( + const ShapeIndex& index, const PointsToSet::BufferList& buffers) { + if (buffers.size() > 1) { + // Record ambiguous points-to set at 'index'. + if (!indices_to_copy_.element(index)) { + VLOG(2) << "Adding copy of buffer for instruction: " + << instruction_->name() + << " at index: " << tensorflow::str_util::Join(index, ",") + << " with ambiguous points-to set."; + RecordIndex(index); } } - } - } + // For each 'buffer': record a mapping from 'buffer' to 'index'. + for (const LogicalBuffer* buffer : buffers) { + buffer_to_source_indices[buffer].push_back(index); + } + }); - ~BufferValueTracker() { - for (const ValueNode* head : value_lists_) { - const ValueNode* p = head; - do { - const ValueNode* tmp = p->next; - delete p; - p = tmp; - } while (p != head); - } + // Record all non-distinct indices detected in 'buffer_to_source_indices'. + for (const auto& buff_to_src : buffer_to_source_indices) { + if (buff_to_src.second.size() == 1) { + continue; } - - // Verify invariants within the linked lists. - Status Verify() const { - for (const ValueNode* head : value_lists_) { - const ValueNode* p = head; - do { - // Verify links between elements are consistent. 
- TF_RET_CHECK(p->prev->next == p); - TF_RET_CHECK(p->next->prev == p); - - const HloInstruction* def = p->value->defining_instruction(); - if (def->opcode() == HloOpcode::kCopy && - ContainsKey(copy_map_, def)) { - TF_RET_CHECK(copy_map_.at(def).dest == p); - } - for (const HloUse* use : p->uses) { - if (use->instruction->opcode() == HloOpcode::kCopy && - ContainsKey(copy_map_, use->instruction)) { - TF_RET_CHECK(copy_map_.at(use->instruction).src == p); - } - } - - p = p->next; - } while (p != head); + for (const ShapeIndex& src_index : buff_to_src.second) { + // Record non-distinct points-to set at 'src_index'. + if (!indices_to_copy_.element(src_index)) { + VLOG(2) << "Adding copy of buffer for instruction: " + << instruction_->name() + << " at index: " << tensorflow::str_util::Join(src_index, ",") + << " because of non-distinct points-to set."; + RecordIndex(src_index); } - return Status::OK(); } + } + return Status::OK(); +} - // Try to elide the given copy. Elision of a copy is possible only if no - // live range interference is introduced by the copy's elimination. If - // elision is possible, then the internal state (value lists) are updated, - // and true is returned. Returns false otherwise. - bool TryElideCopy(const HloInstruction* copy) { - VLOG(2) << "Trying to remove " << copy->name(); - - if (!ContainsKey(copy_map_, copy)) { - VLOG(2) << copy->name() << " is not removable"; - return false; - } - - const CopyNodes& copy_node = copy_map_.at(copy); - ValueNode* src = copy_node.src; - ValueNode* dest = copy_node.dest; - DCHECK(src != nullptr); - DCHECK(dest != nullptr); - - auto is_live_range_before = [this](const ValueNode& a, - const ValueNode& b) { - if (LiveRangeBefore(a, b)) { - VLOG(2) << " Live range of " << a.value->ToShortString() - << " is before " << b.value->ToShortString(); - return true; - } else { - VLOG(2) << " Live range of " << a.value->ToShortString() - << " is not before " << b.value->ToShortString(); - return false; +Status InstructionCopier::RecordIndicesWhichInterfereWithOtherInstruction( + const BufferLiveness& liveness, const HloInstruction* other_instruction, + ShapeTree<bool>* read_only_indices_out) { + // Record all buffer indices for 'instruction_', which interfere with + // 'other_instruction' at the same index. + ShapeUtil::ForEachSubshape( + instruction_->shape(), + [this, &liveness, other_instruction, read_only_indices_out]( + const Shape& /*subshape*/, const ShapeIndex& index) { + if (IsReadOnlyIndex(index)) { + return; } - }; - - // A kCopy instruction copies an HLO value from a source buffer and - // defines an HLO value in a destination buffer. Most generally, the - // source and destination buffers may each hold more than one value at - // different points in the computation so we define the following: - // - // Values in source buffer: {s_0, ..., s_n} - // Values in destination buffer: {d_0, ..., d_m} - // - // A kCopy instruction between these buffers copies a value s_x in the - // source buffer and defines a value d_y in the destination buffer. The - // elision of a copy merges the source and destination buffers together, - // so the list of values for the source and destination buffers are - // merged. - // - // We handle two different cases for copy elision: - // - // (1) the kCopy defines the first value in the destination buffer (d_0). - // - // (2) the kCopy copies the last value in the source buffer (s_n). 
- // - // For the remaining case where the kCopy copies a not-last value from the - // source buffer to a not-first value of the destination buffer, the kCopy - // instruction cannot be removed. This case is generated, for example, if - // the kCopy copies a while body parameter of the loop state at one tuple - // index to a different tuple index in the while body root. Removal of the - // copy necessarily results in live range interference of values in the - // loop state at the two different tuple indices. - // - // We can only perform copy elision if the resulting merged values have - // totally ordered live ranges; otherwise the merged buffer would have - // live range interference. - if (IsHead(*dest)) { - // The copy copies an arbitrary value in the source buffer (call it s_x) - // and defines d_0, the first value in the destination buffer. After - // merging, the values in the combined buffer must be strictly ordered - // as follows** to elide the copy: - // - // {s_0, ..., s_x, d_1, ..., d_m, s_{x+1}, ..., s_n} - // - // Removing the copy eliminates d_0, and uses of d_0 become uses of - // s_x. In the above ordering, the live range of d_m must be ordered - // before the live range of s_{x+1} and the definition and all uses of - // s_x must be ordered before the definition of d_1. These conditions - // are checked below prior to elision. - // - // ** Technically it might be possible to have a non-interfering - // non-trivial interleaving of the values of the source and - // destination buffers in the resulting order. However, this case is - // slow and complicated to check and likely not worth it. So instead - // we simply check for the case where *all* values of the destination - // buffer (d_1 through d_m) are spliced into the point where the copy - // used to be. - VLOG(2) << copy->name() << " defines the first value in its buffer"; - ValueNode* next_dest = Next(*dest); - if (next_dest != nullptr) { - // Live range of 'from' value (s_x) must be before 'next_dest' (d_1); - if (!is_live_range_before(*src, *next_dest)) { - return false; - } + if (indices_to_copy_.element(index)) { + // Return if previous pass already set index. + return; } - ValueNode* next_src = Next(*src); - - if (next_src != nullptr) { - // Live range of 'last_dest' (d_m) must be before 'next_src' s_{x+1}. - ValueNode* last_dest = dest->prev; - DCHECK(IsTail(*last_dest)); - if (!is_live_range_before(*last_dest, *next_src)) { - return false; + const auto& points_to_analysis = liveness.points_to_analysis(); + // Lookup buffers for 'instruction_' and 'other_instruction'. + const auto instruction_buffers = + points_to_analysis.GetPointsToSet(instruction_).element(index); + // If 'instruction_' has ambiguous points-to-set at 'index', it would + // have been recorded in a previous pass (and we would have returned + // early at the entry to this function). As a result, here we know that + // 'instruction_' has just one buffer in its points-to-set. + CHECK_EQ(1, instruction_buffers.size()); + const LogicalBuffer* instruction_buffer = instruction_buffers[0]; + + const auto other_instruction_buffers = + points_to_analysis.GetPointsToSet(other_instruction).element(index); + // Do not insert a copy if both instructions point at the same buffer. + // This eliminates unnecessary copies of read-only tuple elements. + // If 'instruction_' and 'other_instruction' point to the same buffer, + // then that buffer is not updated on the path between the two + // instructions. 
Therefore, any other (possibly interference-causing) + // users of that buffer from 'other_instruction' will see the same data, + // irrespective of whether we insert a copy of this buffer at + // 'instruction_' or not. + if (other_instruction_buffers.size() == 1 && + other_instruction_buffers[0]->id() == instruction_buffer->id()) { + if (read_only_indices_out != nullptr) { + *read_only_indices_out->mutable_element(index) = true; } + return; } - - // Splice in destination buffer values list right after 'src'. - SpliceAfter(dest, src); - } else if (IsTail(*src)) { - // The copy copies the last value in the source buffer, s_n, and defines - // an arbitrary value in the destination buffer, d_y. After - // merging, the values in the combined buffer must be strictly ordered - // as follows** to elide the copy: - // - // {d_0, ..., d_{y-1}, s_0, ..., s_n, d_{y+1}, ..., d_m} - // - // Removing the copy eliminates d_y, and uses of d_y become uses of - // s_n. To enforce the above order, the live range of d_{y-1} must be - // before the live range of s_0, and the live range of s_n must be - // before the live range of d_{y+1}. - // - // ** See comment above in the code handling Case (1). - VLOG(2) << copy->name() << " copies the last value (" - << src->value->ToShortString() << ") in its buffer"; - - ValueNode* prev_dest = Prev(*dest); - // nullptr condition handled above in the first 'if' case. - DCHECK(prev_dest != nullptr); - ValueNode* first_src = src->next; - DCHECK(IsHead(*first_src)); - if (!is_live_range_before(*prev_dest, *first_src)) { - // Live range of value d_{y-1} is not before s_0. - return false; - } - ValueNode* next_dest = Next(*dest); - if (next_dest != nullptr) { - if (!is_live_range_before(*src, *next_dest)) { - // Live range of value s_n is not before d_{y+1}. - return false; + // We can't say anything about the ambiguity of 'other_instruction' at + // this point, so we need to check interference between the single + // buffer in the points-to set of 'instruction_' and all buffers in + // 'other_instruction_buffers'. + for (const LogicalBuffer* other_buffer : other_instruction_buffers) { + if (liveness.MayInterfere(*instruction_buffer, *other_buffer)) { + VLOG(2) << "Adding copy of buffer for instruction: " + << instruction_->name() + << " instruction_buffer: " << instruction_buffer->ToString() + << " at index: " << tensorflow::str_util::Join(index, ",") + << " because of interference with buffer: " + << other_buffer->ToString(); + RecordIndex(index); + break; } } + }); + return Status::OK(); +} - // Splice source buffer values list right after 'prev_dest'. - SpliceAfter(first_src, prev_dest); - } else { - VLOG(2) - << copy->name() - << " copies value in middle of source buffer to value in middle " - "of destination buffer"; - return false; - } - - RemoveCopyValue(dest); - - XLA_VLOG_LINES(4, ToString()); - TF_DCHECK_OK(Verify()); - - return true; - } - - // Delete the given ValueNode associated with a elided kCopy - // instruction. This should be called after splicing the value lists of the - // source and destination buffers together. - void RemoveCopyValue(ValueNode* copy_value_node) { - CHECK_EQ(copy_value_node->value->defining_instruction()->opcode(), - HloOpcode::kCopy); - ValueNode* operand_node = copy_value_node->prev; - CHECK(operand_node != copy_value_node); - - VLOG(2) << "Removing copy " << operand_node->value->ToShortString() - << " => " << copy_value_node->value->ToShortString(); - - // Splice out the copy value node. 
- operand_node->next = copy_value_node->next; - copy_value_node->next->prev = operand_node; - - // Patch up uses. Remove use of copy from operand_node uses. - auto it = - std::find_if(operand_node->uses.begin(), operand_node->uses.end(), - [copy_value_node](const HloUse* use) { - return use->instruction == - copy_value_node->value->defining_instruction(); - }); - CHECK(it != operand_node->uses.end()); - operand_node->uses.erase(it); - - // If the elided copy has any uses which are themselves kCopy instructions - // then patch up the copy info to reflect the that this kCopy instruction - // has a different operand (the operand of the elided copy). - for (const HloUse* copy_use : copy_value_node->uses) { - operand_node->uses.push_back(copy_use); - if (copy_use->instruction->opcode() == HloOpcode::kCopy) { - copy_map_.at(copy_use->instruction).src = operand_node; - } - } - - // Delete the copy info and the value node. - copy_map_.erase(copy_value_node->value->defining_instruction()); - delete copy_value_node; - } - - // Returns true if the live range of given value 'a' is before the live - // range of 'b'. - // - // We cannot use LiveRangeStrictlyBefore because HloValue::uses() is not - // updated as copies are removed. - bool LiveRangeBefore(const ValueNode& a, const ValueNode& b) { - if (a.uses.empty()) { - VLOG(2) << "Empty uses"; - return ordering_.IsDefinedBefore(*a.value, *b.value); - } - for (const HloUse* use : a.uses) { - VLOG(2) << "use: " << *use; - VLOG(2) << "is before:" << *b.value; - if (!ordering_.UseIsBeforeValueDefinition(*use, *b.value, dataflow_)) { - VLOG(2) << "Not before"; - return false; +// This is called when 'instruction_' is a while body root, and 'parameter' is +// the while body parameter. We record all users of all aliases of 'parameter' +// as control predecessors, so that when we add a copy of 'instruction_', we can +// mark the control dependencies. This is necessary because points-to and +// liveness analysis doesn't know about the aliasing between the while body root +// and param. Without these control dependencies, the copy might get scheduled +// to run at a point that interferes with users of the buffer. +Status InstructionCopier::RecordControlPredecessors( + const TuplePointsToAnalysis& points_to_analysis, + HloInstruction* parameter) { + return indices_to_copy_.ForEachElementWithStatus( + [this, &points_to_analysis, parameter](const ShapeIndex& index, + bool will_copy) { + if (will_copy) { + TF_ASSIGN_OR_RETURN( + const LogicalBuffer* buffer, + points_to_analysis.GetBufferDefinedAt(parameter, index)); + for (const BufferAlias& alias : + points_to_analysis.GetBufferAliases(*buffer)) { + for (HloInstruction* user : alias.instruction()->users()) { + if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(), + user, points_to_analysis)) { + continue; + } + + if (user != instruction_) { + control_predecessors_.mutable_element(index)->push_back(user); + } + } + } } - } - return true; - } - - // Returns whether 'node' is the last node in its list. - bool IsTail(const ValueNode& node) const { - return ContainsKey(value_lists_, node.next); - } - - // Returns whether 'node' is the first node in its list. - bool IsHead(const ValueNode& node) const { - return ContainsKey(value_lists_, &node); - } - - // Returns the next node in the list after 'node'. If 'node' is the - // tail, then nullptr is returned. 
- ValueNode* Next(const ValueNode& node) const { - if (IsTail(node)) { - return nullptr; - } else { - return node.next; - } - } - - // Returns the previous node in the list before 'node'. If 'node' - // is the head, then nullptr is returned. - ValueNode* Prev(const ValueNode& node) const { - if (IsHead(node)) { - return nullptr; - } else { - return node.prev; - } - } - - // Splices the entire linked list with 'head' as its head right after the - // node 'insert_after' in another linked list. - void SpliceAfter(ValueNode* head, ValueNode* insert_after) { - DCHECK(IsHead(*head)); - value_lists_.erase(head); - - ValueNode* tail = head->prev; - tail->next = insert_after->next; - insert_after->next->prev = tail; - - insert_after->next = head; - head->prev = insert_after; - } - - string ToString() const { - string out = StrCat("BufferValueTracker:\n"); - StrAppend(&out, " Def-use chains in each buffer:\n"); - for (const ValueNode* head : value_lists_) { - StrAppend(&out, " Buffer defined by ", head->value->ToShortString(), - ":\n"); - const ValueNode* p = head; - do { - StrAppend(&out, " ", p->value->ToShortString(), ", uses: ", - Join(p->uses, "; ", - [](string* s, const HloUse* use) { - StrAppend(s, use->ToString()); - }), - "\n"); - - p = p->next; - } while (p != head); - } - StrAppend(&out, " Potentially removable copies:\n"); - for (const auto& pair : copy_map_) { - const HloInstruction* copy = pair.first; - const CopyNodes& copy_info = pair.second; + return Status::OK(); + }); +} - StrAppend(&out, " ", copy->name(), " : ", - copy_info.src->value->ToShortString(), " => ", - copy_info.dest->value->ToShortString(), "\n"); - } - return out; +// Recursively inserts copies of 'instruction' tuple element buffers at +// indices in 'indices_to_copy_', expanding tuples as needed. +HloInstruction* InstructionCopier::CopyTuple(HloInstruction* instruction, + ShapeIndex* index) { + const int64 num_tuple_elements = + ShapeUtil::TupleElementCount(instruction->shape()); + std::vector<HloInstruction*> elem_copies(num_tuple_elements); + for (int64 i = 0; i < num_tuple_elements; ++i) { + HloInstruction* elem; + if (instruction->opcode() == HloOpcode::kTuple) { + // If the instruction is already a Tuple instruction, we know that the + // element buffers are aliased, so we can just grab the operand directly. + elem = instruction->mutable_operand(i); + } else { + // Otherwise we need to add a GTE to unpack the element out of the tuple. + elem = instruction->parent()->AddInstruction( + HloInstruction::CreateGetTupleElement( + ShapeUtil::GetSubshape(instruction->shape(), {i}), instruction, + i)); } - - private: - const HloDataflowAnalysis& dataflow_; - const HloOrdering& ordering_; - - // The heads of all the value lists. Each value list represents the HLO - // values contained in a particular HLO buffer. The values in the list are - // in dependency order. - tensorflow::gtl::FlatSet<const ValueNode*> value_lists_; - - // Copy removal requires fast access to the value list elements - // corresponding to the source and destination values of the kCopy - // instruction. This data structure holds pointers to these elements for - // each kCopy instruction in the graph. - struct CopyNodes { - // The source and destinations values of the kCopy instruction. 
- ValueNode* src = nullptr; - ValueNode* dest = nullptr; - }; - tensorflow::gtl::FlatMap<const HloInstruction*, CopyNodes> copy_map_; - }; - - HloModule* module_; - const HloAliasAnalysis& alias_analysis_; - const HloOrdering& ordering_; - - // Object tracking the HLO values contained in each HLO buffer. - BufferValueTracker buffer_value_tracker_; -}; - -// Try to remove as many copies from the module as possible without introducing -// live range interference. Copy instructions (identified by their unique id) in -// the set copies_to_exclude are not considered for removal. -Status RemoveUnnecessaryCopies( - const HloOrdering& ordering, - const tensorflow::gtl::FlatSet<HloInstruction::Id>& copies_to_exclude, - HloModule* module) { - TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis, - HloAliasAnalysis::Run(module)); - CopyRemover copy_remover(*alias_analysis, ordering, module); - XLA_VLOG_LINES(3, copy_remover.ToString()); - - tensorflow::gtl::FlatSet<HloInstruction::Id> existing_copies; - for (HloComputation* computation : module->computations()) { - for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kCopy && - !ContainsKey(copies_to_exclude, instruction->unique_id())) { - TF_RETURN_IF_ERROR(copy_remover.TryElideCopy(instruction).status()); + index->push_back(i); + if (ShapeUtil::IsTuple(elem->shape())) { + elem_copies[i] = CopyTuple(elem, index); + } else if (!indices_to_copy_.element(*index)) { + elem_copies[i] = elem; + } else if (HloInstruction* copy_override = GetCopyOverride(*index)) { + elem_copies[i] = copy_override; + } else { + HloInstruction* elem_copy = elem->parent()->AddInstruction( + HloInstruction::CreateUnary(elem->shape(), HloOpcode::kCopy, elem)); + for (HloInstruction* control_predecessor : + control_predecessors_.element(*index)) { + VLOG(2) << "Adding control dependency from " + << control_predecessor->ToString() << " to " + << elem_copy->ToString(); + TF_CHECK_OK(control_predecessor->AddControlDependencyTo(elem_copy)); } + elem_copies[i] = elem_copy; } + index->pop_back(); } - - return Status::OK(); + return instruction->parent()->AddInstruction( + HloInstruction::CreateTuple(elem_copies)); } -// Add copies to address special constraints on the roots of computations not -// related to live range interference: -// -// (1) Entry computation root must be unambiguous and distinct. -// -// (2) Any computation called by a kCall instruction must have an -// unambiguous root. -// -// (3) Constants and parameters cannot be live out of the entry computation -// -Status AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module) { - TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis, - HloAliasAnalysis::Run(module)); - - // Identify which shape indices of which instructions need to be copied. Store - // these results in 'instructions_to_copy'. - std::unordered_map<HloInstruction*, ShapeTree<bool>> instructions_to_copy; - auto add_index_to_copy = [&instructions_to_copy](HloInstruction* instruction, - const ShapeIndex& index) { - auto it = instructions_to_copy.find(instruction); - if (it == instructions_to_copy.end()) { - auto it_added = instructions_to_copy.emplace( - std::piecewise_construct, std::forward_as_tuple(instruction), - std::forward_as_tuple(instruction->shape(), /*init_value=*/false)); - it = it_added.first; - } - *it->second.mutable_element(index) = true; - }; - - // Iterate through values of all constants and entry parameters. 
These values - // are special because they are held in read-only buffers. If any of these - // values share a buffer with other values (for example, the init value of a - // while is a constant) then copy the value at its definition and replace all - // its uses with the copy. - for (const HloValue* value : alias_analysis->dataflow_analysis().values()) { - if (ValueIsReadOnly(*value) && - alias_analysis->GetBufferContainingValue(*value).values().size() > 1) { - VLOG(2) << "Value " << value->ToShortString() - << " is read only, but its buffer contains more than one value. " - "Copying."; - add_index_to_copy(value->defining_instruction(), value->defining_index()); - } +// Inserts copies of 'instruction_' buffers at indices in 'indices_to_copy_'. +HloInstruction* InstructionCopier::Copy() { + ShapeIndex index; + HloInstruction* copy; + if (ShapeUtil::IsTuple(instruction_->shape())) { + copy = CopyTuple(instruction_, &index); + } else { + copy = instruction_->parent()->AddInstruction(HloInstruction::CreateUnary( + instruction_->shape(), HloOpcode::kCopy, instruction_)); } - - // Identify copies which must be added at root instructions - for (HloComputation* computation : module->computations()) { - const CallGraphNode& node = call_graph.GetNode(computation); - if (node.context() == CallContext::kParallel) { - continue; - } - TF_RET_CHECK(node.context() == CallContext::kSequential); - - const bool is_entry = computation == module->entry_computation(); - HloInstruction* root = computation->root_instruction(); - - // Mark nondistinct/ambiguous indices. - tensorflow::gtl::FlatSet<const HloBuffer*> seen; - ShapeUtil::ForEachSubshape( - root->shape(), [&](const Shape& /*subshape*/, const ShapeIndex& index) { - std::vector<const HloBuffer*> buffers_at_index = - alias_analysis->ComputeBuffersAt(root, index); - bool buffer_seen_before = false; - for (const HloBuffer* buffer : buffers_at_index) { - buffer_seen_before |= !seen.insert(buffer).second; - } - if (buffers_at_index.size() > 1 || (buffer_seen_before && is_entry)) { - VLOG(2) << "Index " << index << " of root of computation " - << computation->name() << " (" << root->name() - << ") has ambiguous or non-distinct buffer. Copying."; - add_index_to_copy(root, index); - } - }); - - // For entry instructions, mark any parameter or constant values. - if (is_entry) { - for (const auto& pair : - alias_analysis->dataflow_analysis().GetInstructionValueSet(root)) { - const ShapeIndex& index = pair.first; - const HloValueSet& value_set = pair.second; - for (const HloValue* value : value_set.values()) { - if (ValueIsReadOnly(*value)) { - VLOG(2) << "Root of entry computation (" << root->name() - << ") has constant or entry parameter value at index " - << index << ". Copying."; - add_index_to_copy(root, index); - } - } - } - } + for (HloInstruction* user : copy_users_) { + VLOG(2) << "Adding copy between instruction: " << instruction_->name() + << " and user: " << user->name(); + TF_CHECK_OK(instruction_->ReplaceUseWith(user, copy)); } + return copy; +} - // TODO(b/62548313): Buffer assignment uses TuplePointsToAnalysis which is - // computation-scoped. This means the analysis doesn't have visibility to - // constants and entry parameters that cross computation boundaries. This can - // cause invalid buffer assignments so additional conservative copies are - // added to handle these cases. Remove this whole loop when buffer assignment - // uses alias analysis. 
- for (HloComputation* computation : module->computations()) { - const CallGraphNode& node = call_graph.GetNode(computation); - - bool is_while_body = false; - if (node.context() == CallContext::kSequential && - !node.caller_callsites().empty()) { - CHECK_EQ(node.caller_callsites().size(), 1); - const HloInstruction* calling_instruction = - node.caller_callsites()[0].instruction(); - is_while_body = calling_instruction->opcode() == HloOpcode::kWhile && - calling_instruction->while_body() == node.computation(); +// The 'read_only_indices' are initialized based on points-to analysis on the +// while body corresponding to 'while_hlo'. If the init buffer corresponding to +// a read-only index aliases with a constant, it cannot be considered read-only, +// and must be copied. This is necessary because BufferAssignment does not +// currently assign an allocation for constants (b/32248867). +// This function performs this fix-up of 'read_only_indices'. +// +// Returns a ShapeTree of copy_overrides, which implements an optimization to +// allow multiple while loops that share the same read-only constants to +// share a single copy. +StatusOr<ShapeTree<HloInstruction*>> RevertReadOnlyIndicesForConstants( + const HloInstruction* while_hlo, + const TuplePointsToAnalysis& points_to_analysis, + ShapeTree<bool>* read_only_indices, + FlatMap<const HloInstruction*, HloInstruction*>* shared_copies) { + const HloInstruction* init_hlo = while_hlo->operand(0); + const PointsToSet& points_to = points_to_analysis.GetPointsToSet(init_hlo); + + // Mapping from LogicalBuffer to index (used to detect non-distinct indices). + FlatSet<const LogicalBuffer*> buffer_set; + + ShapeTree<HloInstruction*> copy_overrides(init_hlo->shape()); + points_to.ForEachElement([init_hlo, read_only_indices, shared_copies, + &buffer_set, ©_overrides]( + const ShapeIndex& index, + const PointsToSet::BufferList& buffers) { + // Look for read-only entry parameters. + if (!read_only_indices->element(index)) { + return; } - VLOG(2) << computation->name() << " is_while_body: " << is_while_body; - HloInstruction* root = computation->root_instruction(); + for (const LogicalBuffer* buffer : buffers) { + HloInstruction* pointee = buffer->instruction(); + const bool is_constant = pointee->opcode() == HloOpcode::kConstant; + if (!is_constant) { + continue; + } - for (const auto& pair : - alias_analysis->dataflow_analysis().GetInstructionValueSet(root)) { - const ShapeIndex& index = pair.first; - const HloValueSet& value_set = pair.second; - for (const HloValue* value : value_set.values()) { - if (IsConstantValue(*value) && !is_while_body) { - VLOG(2) << "Root of computation (" << root->name() - << ") is constant at index " << index << ". Copying."; - add_index_to_copy(root, index); + // We have found an constant that is read-only in + // the while body. These buffers are managed by the caller, and cannot + // be aliased with HLO buffers. Revert this read-only index, + // to allow it to be copied. + *read_only_indices->mutable_element(index) = false; + + // Optimization to allow multiple while loops that share the same + // read-only entry constants to share a single copy. + // Only unambiguous and distinct array-shaped buffers are allowed, to + // reduce code complexity. The shape of the entry parameter must be + // identical to the shape of the init_hlo at this index, to ensure + // there were no intervening bitcast or GTE instructions, which are + // also hard to handle. 
+ const Shape& pointee_shape = pointee->shape(); + const Shape& init_shape = + ShapeUtil::GetSubshape(init_hlo->shape(), index); + if (buffers.size() == 1 && ShapeUtil::IsArray(pointee_shape) && + ShapeUtil::Equal(pointee_shape, init_shape) && + buffer_set.count(buffer) < 1) { + HloInstruction** copy = &(*shared_copies)[pointee]; + if (*copy == nullptr) { + *copy = pointee->parent()->AddInstruction(HloInstruction::CreateUnary( + pointee_shape, HloOpcode::kCopy, pointee)); } + // Add the copy as an override. + *copy_overrides.mutable_element(index) = *copy; } - } - } - // Add copy instructions indicated in 'instructions_to_copy' to the module. - for (const auto& pair : instructions_to_copy) { - HloInstruction* instruction = pair.first; - const ShapeTree<bool>& indices_to_copy = pair.second; + // Tracks whether this current buffer is distinct. + buffer_set.insert(buffer); - std::vector<HloInstruction*> users = instruction->users(); - TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy, - instruction->parent()->DeepCopyInstruction( - instruction, &indices_to_copy)); - for (HloInstruction* user : users) { - TF_RETURN_IF_ERROR(instruction->ReplaceUseWith(user, deep_copy)); + // We've already reverted the read-only index and handled the + // single-copy optimization above, so there's nothing more to do. + break; } - if (instruction == instruction->parent()->root_instruction()) { - instruction->parent()->set_root_instruction(deep_copy); - } - } - - return Status::OK(); -} - -Status VerifyNoLiveRangeInterference(HloModule* module) { - TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis, - HloAliasAnalysis::Run(module)); - DependencyHloOrdering ordering(module); - TF_RET_CHECK(!alias_analysis->HasLiveRangeInterference(ordering)); - return Status::OK(); + }); + return copy_overrides; } -void MaybeDumpModule(const string& message, const HloModule& module) { - if (VLOG_IS_ON(3)) { - VLOG(3) << message; - XLA_VLOG_LINES(3, module.ToString()); - hlo_graph_dumper::MaybeDumpHloModule(module, message); +} // anonymous namespace + +// NOTE: This is only called by gpu::CopyInsertion. It's not called here in the +// base class, since the regular CopyInsertion logic above selectively copies +// tuple elements, while this method assumes all buffers need to be deep copied. +StatusOr<HloInstruction*> CopyInsertion::FindOrInsertCopy(HloInstruction* hlo) { + auto copy_it = inserted_copies_.find(hlo); + if (copy_it == inserted_copies_.end()) { + HloInstruction* copy = hlo->parent()->DeepCopyInstruction(hlo).ValueOrDie(); + inserted_copies_.insert({hlo, copy}); + return copy; + } else { + return copy_it->second; } } -} // namespace - StatusOr<bool> CopyInsertion::Run(HloModule* module) { - // Copy insertion is performed in three steps: - // - // (1) Add copies conservatively to guarantee that there is no live-range - // interference. This is done simplistically and usually results in more - // copies than is strictly necessary. - // - // (2) Using a more fine-grained analysis, remove as many copies that were - // added in (1) as possible while ensuring no live-range interference. - // - // (3) Add copies to resolve issues not related to live range interference - // such as parameters and constants live out of the entry computation. - // - // We add copies then remove them (step (1) then (2)) rather than simply - // adding only the copies that are necessary because, in general, it is - // difficult to figure out the minimal set of copies to add once there is - // interference. 
On the other hand, it is easy to determine if removing a copy - // will introduce interference. - // - // The final copy insertion in (3) is done separately to simplify the - // implementation of copy removal in (2) which is the most complicated part of - // the pass. As is, copy removal only has to reason about live range - // interference. If all copies were added in step (1) then copy removal would - // also have to reason about things like constants and parameters live out of - // the computation. - MaybeDumpModule("before copy insertion", *module); - - std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module); - if (!call_graph->IsFlattened()) { - return FailedPrecondition( - "Call graph must be flattened before copy insertion."); - } + bool changed = false; + VLOG(2) << "CopyInsertion for module " << module->name(); - // Gather Ids of existing kCopy instructions in the module. We avoid removing - // these copies (except via DCE in TupleSimplifier) because they may have been - // added for reasons not considered by copy insertion (eg, layout assignment). - // Instruction id is used instead of HloInstruction* because the pointer - // values may be recycled. - tensorflow::gtl::FlatSet<HloInstruction::Id> existing_copies; - for (HloComputation* computation : module->computations()) { + TF_ASSIGN_OR_RETURN( + std::unique_ptr<BufferLiveness> liveness, + BufferLiveness::Run(module, MakeUnique<DependencyHloOrdering>(module))); + const auto& points_to_analysis = liveness->points_to_analysis(); + XLA_VLOG_LINES(2, points_to_analysis.ToString()); + XLA_VLOG_LINES(2, module->ToString()); + + // Gather all while body computations and while instructions. + FlatSet<const HloComputation*> while_body_computations; + std::vector<HloInstruction*> while_instructions; + for (auto* computation : module->computations()) { for (HloInstruction* instruction : computation->instructions()) { - if (instruction->opcode() == HloOpcode::kCopy) { - existing_copies.insert(instruction->unique_id()); + if (instruction->opcode() == HloOpcode::kWhile) { + while_body_computations.insert(instruction->while_body()); + while_instructions.push_back(instruction); } } } - TF_RETURN_IF_ERROR(AddCopiesToResolveInterference(module)); - - // Simplify the tuple structures introduced by the deep copies. This should be - // done before removing copies (RemoveUnnecessaryCopies) because tuple - // simplification changes dependencies in the graph which changes live range - // interference in the graph. Also run DCE to remove the dead Tuple/GTE - // instructions introduced by tuple simplification. - TupleSimplifier tuple_simplifier; - HloDCE dce; - TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status()); - TF_RETURN_IF_ERROR(dce.Run(module).status()); - - TF_DCHECK_OK(VerifyNoLiveRangeInterference(module)); - - MaybeDumpModule("after adding copies to resolve interference", *module); - - DependencyHloOrdering ordering(module); - TF_RETURN_IF_ERROR( - RemoveUnnecessaryCopies(ordering, existing_copies, module)); - - MaybeDumpModule("after removing unnecessary copies", *module); - - TF_RETURN_IF_ERROR(AddSpecialCaseCopies(*call_graph, module)); - - MaybeDumpModule("after adding special-case copies", *module); + // Collect instruction buffer indices to copy in 'instructions_to_copy'. + std::vector<InstructionCopier> instructions_to_copy; + + // Add copies of computation root instructions, if needed. 
+ FlatMap<const HloComputation*, ShapeTree<bool>> while_body_read_only_indices; + for (auto* computation : module->MakeNonfusionComputations()) { + VLOG(2) << "computation " << computation->name(); + InstructionCopier root_copier(computation->root_instruction(), + /*copy_users=*/{}); + if (while_body_computations.count(computation) > 0) { + // Record root indices to copy for while body sub-computations. We do not + // need to call RecordIndicesWhichPointToParamOrConstant for the while + // body root instruction here, because any necessary copies needed to + // avoid constants or parameters in the output are handled by while.init + // operand copy insertion below (which will share an allocation). + HloInstruction* while_body_param = computation->parameter_instruction(0); + ShapeTree<bool> read_only_indices(while_body_param->shape()); + TF_RETURN_IF_ERROR(root_copier.RecordIndicesToCopyForColocatingBuffers( + *liveness, while_body_param, &read_only_indices)); + while_body_read_only_indices[computation] = read_only_indices; + + // Mark control predecessors, based on the body param, for any copies + // we'll be inserting. This ensures the copy doesn't run too early. + TF_RETURN_IF_ERROR(root_copier.RecordControlPredecessors( + points_to_analysis, while_body_param)); + } else { + // Record root indices to copy for general computations. + TF_RETURN_IF_ERROR(root_copier.RecordIndicesWhichPointToParamOrConstant( + points_to_analysis)); + } + instructions_to_copy.push_back(root_copier); + } - TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status()); - TF_RETURN_IF_ERROR(dce.Run(module).status()); - TF_DCHECK_OK(VerifyNoLiveRangeInterference(module)); + // Add copies of while 'init' operand instructions, if needed. 'shared_copies' + // is used to ensure that multiple while loops can share a single copy of the + // same entry parameter or constant, if all loops use it read-only. + // + // TODO(b/33301720) Remove redundant while instruction copies. + FlatMap<const HloInstruction*, HloInstruction*> shared_copies; + for (HloInstruction* while_hlo : while_instructions) { + // Fix read_only_indices to account for entry constants. Also + // initialize copy_overrides, which ensures a single copy for each read-only + // constant that is used in multiple while loops. + ShapeTree<bool>* read_only_indices = + &while_body_read_only_indices[while_hlo->while_body()]; + TF_ASSIGN_OR_RETURN( + const ShapeTree<HloInstruction*> copy_overrides, + RevertReadOnlyIndicesForConstants(while_hlo, points_to_analysis, + read_only_indices, &shared_copies)); + // Create InstructionCopier for init operand of while instruction. + HloInstruction* init_hlo = while_hlo->mutable_operand(0); + InstructionCopier init_copier(init_hlo, {while_hlo}); + init_copier.SetReadOnlyIndices(*read_only_indices); + init_copier.SetCopyOverrides(copy_overrides); + // Record 'init' buffer indices which point-to a Constant or Parameter. + TF_RETURN_IF_ERROR(init_copier.RecordIndicesWhichPointToParamOrConstant( + points_to_analysis)); + // Record indices necessary to colocate while and init operand buffers. 
+    TF_RETURN_IF_ERROR(init_copier.RecordIndicesToCopyForColocatingBuffers(
+        *liveness, while_hlo, /*read_only_indices_out=*/nullptr));
+    instructions_to_copy.push_back(init_copier);
+  }

-  MaybeDumpModule("after copy insertion", *module);
+  for (InstructionCopier& to_copy : instructions_to_copy) {
+    if (to_copy.HasAllIndicesFalse()) {
+      continue;
+    }
+    changed = true;

-  if (VLOG_IS_ON(1)) {
-    int64 num_total_copies = 0;
-    for (HloComputation* computation : module->computations()) {
-      for (HloInstruction* instruction : computation->instructions()) {
-        if (instruction->opcode() == HloOpcode::kCopy) {
-          num_total_copies++;
-        }
-      }
+    // Copy instruction at recorded buffer indices.
+    HloComputation* computation = to_copy.instruction()->parent();
+    HloInstruction* copy = to_copy.Copy();
+    if (to_copy.instruction() == computation->root_instruction()) {
+      computation->set_root_instruction(copy);
     }
-    VLOG(1) << "Num copies before copy-insertion: " << existing_copies.size();
-    VLOG(1) << "Num copies after copy-insertion: " << num_total_copies;
   }
-  return true;
+  VLOG(3) << "After copy insertion for module " << module->name();
+  XLA_VLOG_LINES(3, module->ToString());
+
+  return changed;
 }

 }  // namespace xla
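To make the restored pass easier to follow, below is a small standalone sketch (plain C++17, standard library only; Node, RecordIndicesToCopy, InsertCopies, copy_cache, and arena are illustrative names, not XLA APIs) of the two ideas the code above relies on: first record, per tuple leaf index, whether a copy is needed, and only then materialize copy nodes, memoizing them in a map so several users of the same read-only constant share a single copy.

// Standalone illustration (not XLA code): two-phase copy insertion over a
// toy "tuple tree", plus a memoized copy cache shared across users.
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <vector>

struct Node {
  std::string name;
  bool is_constant = false;
  std::vector<Node*> elements;  // non-empty => tuple-shaped
};

// Phase 1: record, per leaf index, whether a copy is needed (here: leaves
// that are constants), without mutating the graph yet.
void RecordIndicesToCopy(const Node& node, std::vector<bool>* to_copy) {
  if (node.elements.empty()) {
    to_copy->push_back(node.is_constant);
    return;
  }
  for (const Node* elem : node.elements) RecordIndicesToCopy(*elem, to_copy);
}

// Phase 2: materialize copies at the recorded indices. copy_cache plays the
// role of a shared-copies map: a constant used by several loops is copied once.
Node* InsertCopies(Node* node, const std::vector<bool>& to_copy, size_t* leaf,
                   std::map<Node*, std::unique_ptr<Node>>* copy_cache,
                   std::vector<std::unique_ptr<Node>>* arena) {
  if (node->elements.empty()) {
    const bool need_copy = to_copy[(*leaf)++];
    if (!need_copy) return node;
    auto& cached = (*copy_cache)[node];
    if (cached == nullptr) {
      cached = std::make_unique<Node>(Node{"copy(" + node->name + ")"});
    }
    return cached.get();
  }
  // Tuple-shaped: rebuild a tuple gathering copied and uncopied elements.
  auto tuple = std::make_unique<Node>(Node{"tuple"});
  for (Node* elem : node->elements) {
    tuple->elements.push_back(
        InsertCopies(elem, to_copy, leaf, copy_cache, arena));
  }
  Node* result = tuple.get();
  arena->push_back(std::move(tuple));
  return result;
}

int main() {
  Node c{"const0", /*is_constant=*/true};
  Node p{"param0", /*is_constant=*/false};
  Node init{"init", false, {&c, &p}};

  std::vector<bool> to_copy;
  RecordIndicesToCopy(init, &to_copy);  // {true, false}

  std::map<Node*, std::unique_ptr<Node>> copy_cache;
  std::vector<std::unique_ptr<Node>> arena;
  size_t leaf = 0;
  Node* copied = InsertCopies(&init, to_copy, &leaf, &copy_cache, &arena);

  for (const Node* e : copied->elements) std::cout << e->name << "\n";
  // Prints copy(const0) then param0: only the constant leaf was copied.
}

The cache in this sketch plays roughly the same role the 'shared_copies' FlatMap and the copy_overrides ShapeTree play above: the first while loop that needs a copy of a read-only constant creates it, and later loops reuse that copy through the recorded override rather than inserting their own.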