author	Mark Heffernan <meheff@google.com>	2017-11-03 13:26:47 -0700
committer	TensorFlower Gardener <gardener@tensorflow.org>	2017-11-03 13:31:12 -0700
commit	456929281592f14d50443cfbdaa2f6b36167a134 (patch)
tree	8a1a18245a6ef9baed8bd9a9f35b7c250ab64901 /tensorflow/compiler/xla/service/copy_insertion.cc
parent	5b166f495ae79b6e8144bbd3a1109f4b8d9fb1aa (diff)
Rollback copy insertion change because it results in a DCHECK failure with an internal model.
END_PUBLIC
BEGIN_PUBLIC
Automated g4 rollback of changelist 174423881

PiperOrigin-RevId: 174505237
Diffstat (limited to 'tensorflow/compiler/xla/service/copy_insertion.cc')
-rw-r--r--	tensorflow/compiler/xla/service/copy_insertion.cc	1526
1 file changed, 504 insertions, 1022 deletions
diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc
index 8f50b29dad..0453a698a0 100644
--- a/tensorflow/compiler/xla/service/copy_insertion.cc
+++ b/tensorflow/compiler/xla/service/copy_insertion.cc
@@ -15,17 +15,15 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/copy_insertion.h"
-#include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
+#include <memory>
+
#include "tensorflow/compiler/xla/service/hlo_computation.h"
-#include "tensorflow/compiler/xla/service/hlo_dce.h"
-#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
-#include "tensorflow/compiler/xla/service/hlo_ordering.h"
#include "tensorflow/compiler/xla/service/liveness_util.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
-#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
+#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
@@ -33,1113 +31,597 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
-using ::tensorflow::str_util::Join;
-using ::tensorflow::strings::StrAppend;
-using ::tensorflow::strings::StrCat;
-
namespace {
-bool IsEntryParameterValue(const HloValue& value) {
- const HloComputation* computation = value.defining_instruction()->parent();
- return value.defining_instruction()->opcode() == HloOpcode::kParameter &&
- computation == computation->parent()->entry_computation();
-}
-
-bool IsConstantValue(const HloValue& value) {
- return value.defining_instruction()->opcode() == HloOpcode::kConstant;
-}
-
-bool ValueIsReadOnly(const HloValue& value) {
- return IsConstantValue(value) || IsEntryParameterValue(value);
-}
+using tensorflow::gtl::FlatMap;
+using tensorflow::gtl::FlatSet;
-// Deep copy the given instructions 'from' and 'to' at the ShapeIndexes given in
-// 'indices_to_copy'. Add control edges from the respective kCopy instructions
-// in deep copy of 'from' to the respective kCopy instruction in the deep copy
-// of 'to'.
-//
-// Requirements: 'from' and 'to' must have compatible shapes.
+// InstructionCopier encapsulates indices at which to copy 'instruction'.
+// All 'instruction' users in 'copy_users' are updated to use the copy.
//
-// For example, suppose 'from' and 'to' are two-element tuples where index 0 is
-// the only index to copy. Prior to deep-copying we have:
+// Instruction copies are generated in two phases:
+// 1) Recording buffer indices at which 'instruction' requires copies (i.e.
+// setting 'indices_to_copy_[index]'=true).
+// 2) Inserting kCopy instructions based on indices recorded in phase 1).
+// *) Array instructions are copied by inserting a single kCopy instruction.
+// *) Tuple-shaped instructions are copied by recursively expanding tuples
+// (and tuple-shaped elements), and inserting kCopy instructions for any
+// tuple elements which require a copy. As the recursion unwinds, new tuple
+// instructions are added to gather the copied (and uncopied) references
+// into the output tuple (i.e. the copy of the tuple-shaped instruction).
//
+// Example two-element tuple with one element that needs a copy:
//
-// 'from'
-// |
-// ...
-// |
-// 'to'
+// original-instruction
+// / \
+// GTE(0) GTE(1)
+// | |
+// Copy |
+// \ /
+// Tuple // copied-instruction
//
-// DeepCopyAndAddControlEdges produces:
-//
-// 'from'
-// / \
-// GTE GTE
-// | |
-// Copy |
-// / \ /
-// | Tuple
-// | |
-// ctrl ...
-// edge |
-// | |
-// | 'to'
-// | / \
-// | GTE GTE
-// \ | |
-// Copy |
-// \ /
-// Tuple
-//
-StatusOr<std::pair<HloInstruction*, HloInstruction*>>
-DeepCopyAndAddControlEdges(HloInstruction* from, HloInstruction* to,
- const ShapeTree<bool>& indices_to_copy) {
- DCHECK(ShapeUtil::Compatible(from->shape(), to->shape()));
- // to/from_copy_tree hold the kCopy instructions produced by the deep
- // copies. Elements which are not copied (indices_to_copy.element(index) ==
- // false) have nullptr at that index.
- ShapeTree<HloInstruction*> from_copy_tree(from->shape(),
- /*init_value=*/nullptr);
- TF_ASSIGN_OR_RETURN(HloInstruction * from_deep_copy,
- from->parent()->DeepCopyInstruction(
- from, &indices_to_copy, &from_copy_tree));
-
- ShapeTree<HloInstruction*> to_copy_tree(to->shape(), /*init_value=*/nullptr);
- TF_ASSIGN_OR_RETURN(
- HloInstruction * to_deep_copy,
- to->parent()->DeepCopyInstruction(to, &indices_to_copy, &to_copy_tree));
-
- // Add control edges between the respective kCopy instructions.
- for (const auto& pair : from_copy_tree) {
- const ShapeIndex& index = pair.first;
- HloInstruction* from_copy = pair.second;
- HloInstruction* to_copy = to_copy_tree.element(index);
- if (from_copy == nullptr) {
- TF_RET_CHECK(to_copy == nullptr);
- continue;
- }
- TF_RET_CHECK(to_copy != nullptr);
- TF_RETURN_IF_ERROR(from_copy->AddControlDependencyTo(to_copy));
+// As an optimization, if the original instruction is itself a Tuple
+// instruction, we elide the unnecessary extra GTE and Tuple instructions,
+// and just insert the copy into a new Tuple instruction, with control
+// dependencies to ensure the copy occurs after any possible interference.
+class InstructionCopier {
+ public:
+ InstructionCopier(HloInstruction* instruction,
+ const std::vector<HloInstruction*>& copy_users)
+ : instruction_(instruction),
+ copy_users_(copy_users),
+ indices_to_copy_(instruction->shape()),
+ control_predecessors_(instruction->shape()) {}
+
+ // Sets indices that are read-only, and thus do not need to be copied.
+ void SetReadOnlyIndices(const ShapeTree<bool>& read_only_indices) {
+ read_only_indices_ = read_only_indices;
}
- return std::make_pair(from_deep_copy, to_deep_copy);
-}
-
-// Compute the indices of the loop state which need copies in order to avoid
-// live range interference. Generally, an element in the loop state does not
-// need to be copied if the element is passed through transparently through the
-// body.
-//
-// Returns whether any indices need to be copied.
-bool IndicesToCopyForWhile(const HloDataflowAnalysis& dataflow,
- const HloInstruction* xla_while,
- ShapeTree<bool>* indices_to_copy) {
- DCHECK(ShapeUtil::Compatible(indices_to_copy->shape(), xla_while->shape()));
-
- bool any_copies = false;
- const HloInstruction* init = xla_while->operand(0);
- for (auto& pair : *indices_to_copy) {
- const ShapeIndex& index = pair.first;
- bool& should_copy = pair.second;
- // If there is any ambiguity, then loop state must be copied.
- if (dataflow.GetValueSet(init, index).values().size() > 1 ||
- dataflow.GetValueSet(xla_while, index).values().size() > 1) {
- should_copy = true;
- } else {
- // If the output of the while instruction is not the same as the init
- // value of the while, then this element is not passed through the body
- // transparently and must be copied.
- should_copy = dataflow.GetUniqueValueAt(xla_while, index) !=
- dataflow.GetUniqueValueAt(init, index);
- }
- any_copies |= should_copy;
+ // Sets copy overrides, which are copy instructions to use at each index. This
+ // is used to share a single copy of read-only entry parameters and constants
+ // between multiple While loops.
+ void SetCopyOverrides(const ShapeTree<HloInstruction*>& copy_overrides) {
+ copy_overrides_ = copy_overrides;
}
- return any_copies;
-}
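// A minimal sketch of the pass-through test above, assuming a while loop
// whose state is an (s32[], s32[]) tuple and whose body rewrites only
// element {1} (names below are hypothetical):
//
//   ShapeTree<bool> indices_to_copy(xla_while->shape());
//   if (IndicesToCopyForWhile(dataflow, xla_while, &indices_to_copy)) {
//     // Element {0} flows through the body unchanged, so
//     // indices_to_copy.element({0}) == false; element {1} is redefined
//     // in the body, so indices_to_copy.element({1}) == true.
//   }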
-
-// Add kCopy instructions around the given kWhile instruction to eliminate any
-// possible live range interference of HLO values assuming a dependency-based
-// ordering (HloDependencyOrdering). Copies are added conservatively. There
- likely are copies which are not strictly necessary, but these are removed
-// later in the pass via CopyRemover.
-//
-//
-// Elements (each ShapeIndex) in the loop state are considered independently. A
-// copy is added to each element of the loop state which is modified in the
-// while body. For each such element, a total of three kCopy instructions are
-// added at following locations:
-//
-// (1) The init value is copied before the kWhile instruction. Before:
-//
-// (Init)
-// |
-// kWhile
-// |
-// ...
-//
-// After:
-//
-// (Init)
-// |
-// kCopy
-// |
-// kWhile
-// |
-// ...
-//
-// This copy is necessary in case the init value is simultaneously live
-// with the kWhile.
-//
-// (2) Copies are added to the parameter and root of the while body
-// computation. Before:
-//
-// kParameter
-// |
-// ...
-// |
-// (body root)
-//
-// After:
-//
-// kParameter
-// |
-// kCopy ----------+
-// | |
-// ... ctrl
-// | edge
-// (body root) |
-// | |
-// kCopy <---------+
-//
-// The root kCopy becomes the new root of the computation. Both copies are
- // necessary to avoid any potential interference between the parameter value
-// the root value. The control edge prevents potential interference
-// between the copies themselves.
-//
-// If the loop state is a tuple then the above kCopy instructions are a deep
- // copy constructed of kCopy, kGetTupleElement, and kTuple instructions as
-// constructed by HloInstruction::DeepCopyInstruction.
-Status AddCopiesForWhile(const HloAliasAnalysis& alias_analysis,
- HloInstruction* xla_while) {
- VLOG(2) << "Adding copies for kWhile instruction " << xla_while->name();
- TF_RET_CHECK(xla_while->opcode() == HloOpcode::kWhile);
- ShapeTree<bool> indices_to_copy(xla_while->shape());
- if (!IndicesToCopyForWhile(alias_analysis.dataflow_analysis(), xla_while,
- &indices_to_copy)) {
- VLOG(2) << "No copies necessary for kWhile instruction "
- << xla_while->name();
- return Status::OK();
- }
+ // Returns true if all recorded indices are false (false otherwise).
+ bool HasAllIndicesFalse() const;
- VLOG(2) << "Adding copies for " << xla_while->name() << " at indices:";
- for (auto& pair : indices_to_copy) {
- if (pair.second) {
- VLOG(2) << " " << pair.first;
- }
- }
+ // Records instruction buffer indices which point to a Parameter or Constant.
+ Status RecordIndicesWhichPointToParamOrConstant(
+ const TuplePointsToAnalysis& points_to_analysis);
- // Deep copy init.
- HloInstruction* while_init = xla_while->mutable_operand(0);
- TF_ASSIGN_OR_RETURN(
- HloInstruction * while_init_copy,
- xla_while->parent()->DeepCopyInstruction(while_init, &indices_to_copy));
- TF_RETURN_IF_ERROR(while_init->ReplaceUseWith(xla_while, while_init_copy));
+ // Records instruction buffer indices to copy which are necessary to ensure:
+ // *) PointsToSet of 'instruction_' is unambiguous and distinct.
+ // *) No liveness interference between 'instruction_' and 'other_instruction'.
+ //
+ // If 'read_only_indices_out' is non-null, read-only indices are set to true.
+ Status RecordIndicesToCopyForColocatingBuffers(
+ const BufferLiveness& liveness, const HloInstruction* other_instruction,
+ ShapeTree<bool>* read_only_indices_out);
- // Deep copy the parameter and the root. Extend a control edge from the copy
- // of the parameter value to the corresponding copy value of the root.
- HloComputation* body = xla_while->while_body();
- HloInstruction* param = body->parameter_instruction(0);
- HloInstruction* root = body->root_instruction();
+ // Records control predecessors to add for inserted copy instructions.
+ // 'parameter' must have the same shape as the instruction that will be
+ // copied, and must define all buffers in the shape. Control predecessors are
+ // only recorded for indices that have already been marked for copying.
+ Status RecordControlPredecessors(
+ const TuplePointsToAnalysis& points_to_analysis,
+ HloInstruction* parameter);
- // If param is the root then all indices should have been passed through the
- // while body and we should have returned early above.
- TF_RET_CHECK(param != root);
+ // Inserts copies of 'instruction' buffers at indices in 'indices_to_copy',
+ // and replaces all uses for instructions in 'copy_users_' with copy.
+ // Returns the instruction which is a copy of 'instruction'.
+ HloInstruction* Copy();
- // Copy users before making a deep copy of the parameter as the deep copy
- // will create new users of the parameter (eg, the GTE instructions of the
- // deep copy).
- std::vector<HloInstruction*> param_users = param->users();
+ HloInstruction* instruction() { return instruction_; }
- ShapeIndex current_index;
- TF_ASSIGN_OR_RETURN(auto pair,
- DeepCopyAndAddControlEdges(param, root, indices_to_copy));
+ const std::vector<HloInstruction*>& copy_users() const { return copy_users_; }
- HloInstruction* param_copy = pair.first;
- HloInstruction* root_copy = pair.second;
+ private:
+ // Does the given index represent a read-only buffer?
+ bool IsReadOnlyIndex(const ShapeIndex& index) const {
+ return !ShapeUtil::IsNil(read_only_indices_.shape()) &&
+ read_only_indices_.element(index);
+ }
- for (HloInstruction* user : param_users) {
- TF_RETURN_IF_ERROR(param->ReplaceUseWith(user, param_copy));
+ // Returns the copy override at the given index, or nullptr.
+ HloInstruction* GetCopyOverride(const ShapeIndex& index) const {
+ return ShapeUtil::IsNil(copy_overrides_.shape())
+ ? nullptr
+ : copy_overrides_.element(index);
}
- body->set_root_instruction(root_copy);
+ // Records instruction buffer indices which have ambiguous or non-distinct
+ // points-to sets.
+ Status RecordAmbiguousOrNonDistinctIndices(
+ const TuplePointsToAnalysis& points_to_analysis);
- return Status::OK();
-}
+ // Records instruction buffer indices which have interfering live ranges
+ // with 'other_instruction' buffers at same index.
+ Status RecordIndicesWhichInterfereWithOtherInstruction(
+ const BufferLiveness& liveness, const HloInstruction* other_instruction,
+ ShapeTree<bool>* read_only_indices_out);
-// Removes any control dependencies to or from the given instruction.
-Status StripControlDependenciesFrom(HloInstruction* instruction) {
- while (!instruction->control_successors().empty()) {
- TF_RETURN_IF_ERROR(instruction->RemoveControlDependencyTo(
- instruction->control_successors().front()));
- }
+ // Recursively inserts copies of 'instruction' tuple elements at indices
+ // specified in 'indices_to_copy', and returns the copy of 'instruction'.
+ HloInstruction* CopyTuple(HloInstruction* instruction, ShapeIndex* index);
- while (!instruction->control_predecessors().empty()) {
- TF_RETURN_IF_ERROR(
- instruction->control_predecessors().front()->RemoveControlDependencyTo(
- instruction));
+ void RecordIndex(const ShapeIndex& index) {
+ *indices_to_copy_.mutable_element(index) = true;
}
- return Status::OK();
-}
-
-// Add kCopy instructions to the given module to guarantee there is no
-// live-range interference. Generally interference can only occur around kWhile
-// instructions which have update-in-place semantics.
-Status AddCopiesToResolveInterference(HloModule* module) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
- HloAliasAnalysis::Run(module));
+ HloInstruction* instruction_;
+ const std::vector<HloInstruction*> copy_users_;
+ ShapeTree<bool> indices_to_copy_;
+ ShapeTree<std::vector<HloInstruction*>> control_predecessors_;
+ ShapeTree<bool> read_only_indices_;
+ ShapeTree<HloInstruction*> copy_overrides_;
+};
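// A minimal usage sketch of the two-phase flow described above, using only
// the InstructionCopier interface declared here; 'while_hlo', 'while_user',
// and 'liveness' are assumed to be supplied by the surrounding pass:
//
//   InstructionCopier copier(while_hlo, {while_user});
//   // Phase 1: record the ShapeIndexes that require copies.
//   TF_RETURN_IF_ERROR(copier.RecordIndicesWhichPointToParamOrConstant(
//       liveness.points_to_analysis()));
//   // Phase 2: insert kCopy instructions and redirect users to the copy.
//   if (!copier.HasAllIndicesFalse()) {
//     copier.Copy();
//   }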
- for (HloComputation* computation : module->computations()) {
- for (HloInstruction* instruction : computation->instructions()) {
- if (instruction->opcode() == HloOpcode::kWhile) {
- TF_RETURN_IF_ERROR(AddCopiesForWhile(*alias_analysis, instruction));
- }
- }
- }
- return Status::OK();
+bool InstructionCopier::HasAllIndicesFalse() const {
+ bool all_indices_false = true;
+ indices_to_copy_.ForEachElement(
+ [&all_indices_false](const ShapeIndex& /*index*/, bool data) {
+ if (data) {
+ all_indices_false = false;
+ }
+ });
+ return all_indices_false;
}
-// Class for removing unnecessary copies from the module.
-//
-// kCopy instructions are added conservatively to guarantee no live range
-// interference between HLO values. This class uses a more fine-grained analysis
-// to remove some of these added copies which are not strictly necessary.
-class CopyRemover {
- public:
- CopyRemover(const HloAliasAnalysis& alias_analysis,
- const HloOrdering& ordering, HloModule* module)
- : module_(module),
- alias_analysis_(alias_analysis),
- ordering_(ordering),
- buffer_value_tracker_(*module, alias_analysis, ordering) {}
-
- // Try to elide the given copy. The copy is elided if the instruction is not
- // necessary to prevent live-range interference of HLO values. Returns true if
- // copy was elided.
- //
- // The copy instruction is not actually removed here. Instead it is left for
- // dead in the graph. Later calls to DCE will remove the instruction.
- StatusOr<bool> TryElideCopy(HloInstruction* copy) {
- if (buffer_value_tracker_.TryElideCopy(copy)) {
- TF_RETURN_IF_ERROR(StripControlDependenciesFrom(copy));
- TF_RETURN_IF_ERROR(copy->ReplaceAllUsesWith(copy->mutable_operand(0)));
- return true;
- }
- return false;
+Status InstructionCopier::RecordIndicesWhichPointToParamOrConstant(
+ const TuplePointsToAnalysis& points_to_analysis) {
+ const PointsToSet& points_to =
+ points_to_analysis.GetPointsToSet(instruction_);
+ // Shallow copy the instruction if the points-to set of the top-level
+ // buffer is ambiguous. This is necessary because the backends must know
+ // statically what the top-level buffer of the result is.
+ if (points_to.element(/*index=*/{}).size() > 1) {
+ RecordIndex({});
}
- string ToString() const {
- string out = StrCat("CopyRemover, module ", module_->name(), "\n");
- StrAppend(&out, " Buffer values, in dependency order:\n");
- for (const HloBuffer& buffer : alias_analysis_.buffers()) {
- StrAppend(&out, " HloBuffer ", buffer.id(), ":\n");
+ // Multiple buffers within a parameter/constant may be live out, so collect
+ // a set of indices at which to copy first.
+ points_to.ForEachElement([this](const ShapeIndex& index,
+ const PointsToSet::BufferList& buffers) {
+ if (IsReadOnlyIndex(index)) {
+ return;
}
- return out;
- }
-
- private:
- // Class which tracks the HLO values within each HLO buffer in the module
- // during copy removal.
- //
- // The values are held in a linked list where there is one list for each
- // buffer. Removing a copy instruction merges together the values in the
- // source buffer of the copy to the destination buffer of the copy. This class
- // tracks these value lists as copies are removed from the graph (and value
- // lists are merged).
- //
- // The BufferValueTracker object is initialized to match the state of
- // HloAliasAnalysis. However, as copies are removed this state diverges. The
- // values-to-buffer mapping is maintained outside of HloAliasAnalysis because
- // a fully updatable alias analysis is very slow.
- class BufferValueTracker {
- public:
- // The values held in a single HLO buffer are represented using a linked
- // list. An element type in this list is ValueNode.
- //
- // This linked list is hand-rolled to enable efficient splicing of lists
- // using only references to list elements without knowing which lists are
- // being spliced. std::list requires a reference to the list object to
- // splice.
- struct ValueNode {
- explicit ValueNode(const HloValue* v) : value(v) {}
-
- const HloValue* value;
-
- // The uses are maintained outside of HloValue::uses() because
- // HloValue::uses() is not updatable (a fully updatable dataflow analysis
- // is slow).
- std::vector<const HloUse*> uses;
-
- // next/prev elements in the linked list. The list is circularly linked so
- // these values are never null for elements in the list.
- ValueNode* prev = nullptr;
- ValueNode* next = nullptr;
- };
-
- BufferValueTracker(const HloModule& module,
- const HloAliasAnalysis& alias_analysis,
- const HloOrdering& ordering)
- : dataflow_(alias_analysis.dataflow_analysis()), ordering_(ordering) {
- // Construct a list for each HLO buffer in the alias analysis. Maintain a
- // map from HloValue to the respective list element representing that
- // value. The map is used to construct the copy info map below.
- tensorflow::gtl::FlatMap<const HloValue*, ValueNode*> value_to_node;
- for (const HloBuffer& buffer : alias_analysis.buffers()) {
- // Verify values contained in the buffer are strictly ordered. This
- // should always be the case after adding copies to eliminate
- // interference. Specifically, the addition of the control flow edges
- // between copies added around aliased operations (kWhile) guarantees
- // this strict order.
- for (const HloValue* value_a : buffer.values()) {
- for (const HloValue* value_b : buffer.values()) {
- if (value_a != value_b) {
- DCHECK(ordering_.LiveRangeStrictlyBefore(*value_a, *value_b,
- dataflow_) ||
- ordering_.LiveRangeStrictlyBefore(*value_b, *value_a,
- dataflow_))
- << value_a->ToShortString() << " and "
- << value_b->ToShortString() << " are not ordered";
- }
- }
- }
-
- std::vector<const HloValue*> values = buffer.values();
- std::sort(values.begin(), values.end(),
- [this](const HloValue* a, const HloValue* b) {
- return ordering_.IsDefinedBefore(*a, *b);
- });
-
- // Create a list containing all of the values in the buffer.
- AddValueList(values, &value_to_node);
+ for (const LogicalBuffer* buffer : buffers) {
+ // pointee is the HloInstruction producing the buffer which may be
+ // live out.
+ HloInstruction* pointee = buffer->instruction();
+ if (pointee->opcode() == HloOpcode::kParameter ||
+ pointee->opcode() == HloOpcode::kConstant) {
+ VLOG(2) << "Parameter or constant buffer " << buffer->ToString()
+ << " index: " << tensorflow::str_util::Join(index, ",")
+ << " may be live out of computation: " << pointee->ToString();
+ RecordIndex(index);
+ break;
}
-
- // Create copy_map_ which contains the source and destination values
- // of all copies.
- CreateCopyMap(module, value_to_node);
-
- XLA_VLOG_LINES(3, ToString());
- TF_DCHECK_OK(Verify());
}
+ });
+ return Status::OK();
+}
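// For example (a sketch, assuming an entry computation whose root is
// Tuple(param0, const0)): both tuple elements point to read-only buffers
// that may be live out of the computation, so indices {0} and {1} are both
// recorded here and later materialized as kCopy instructions by Copy().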
- // Add a list containing the given values to BufferValueTracker. This
- // represents the values contained in a single buffer. For each value in
- // 'values' an entry is created in value_to_node which indicates the
- // respective ValueNode representing that value.
- void AddValueList(
- tensorflow::gtl::ArraySlice<const HloValue*> values,
- tensorflow::gtl::FlatMap<const HloValue*, ValueNode*>* value_to_node) {
- ValueNode* tail = nullptr;
- ValueNode* head = nullptr;
- for (const HloValue* value : values) {
- auto new_node = new ValueNode(value);
- (*value_to_node)[value] = new_node;
-
- // Copy the HLO values's uses into the ValueNode for the value. These
- // uses in ValueNode are updated as copies are removed.
- new_node->uses.reserve(value->uses().size());
- for (const HloUse& use : value->uses()) {
- new_node->uses.push_back(&use);
- }
-
- // Connect the new node into the linked list.
- if (tail == nullptr) {
- head = new_node;
- } else {
- tail->next = new_node;
- new_node->prev = tail;
- }
- tail = new_node;
- }
-
- // The linked list is circular so connect the head and tail.
- tail->next = head;
- head->prev = tail;
- value_lists_.insert(head);
- }
+Status InstructionCopier::RecordIndicesToCopyForColocatingBuffers(
+ const BufferLiveness& liveness, const HloInstruction* other_instruction,
+ ShapeTree<bool>* read_only_indices_out) {
+ TF_RETURN_IF_ERROR(
+ RecordAmbiguousOrNonDistinctIndices(liveness.points_to_analysis()));
+ TF_RETURN_IF_ERROR(RecordIndicesWhichInterfereWithOtherInstruction(
+ liveness, other_instruction, read_only_indices_out));
+ return Status::OK();
+}
- // This method also fills in copy_map_ which indicates which nodes
- // in the value lists correspond to the source and destination values of
- // kCopy instructions. value_to_node should map each HloValue to its
- // respective ValueNode.
- void CreateCopyMap(
- const HloModule& module,
- const tensorflow::gtl::FlatMap<const HloValue*, ValueNode*>&
- value_to_node) {
- for (HloComputation* computation : module.computations()) {
- for (HloInstruction* instruction : computation->instructions()) {
- // Add copies with unambiguous source values to the map. Copies with
- // ambiguous sources are not removable.
- if (instruction->opcode() == HloOpcode::kCopy) {
- const HloValueSet& src_value_set =
- dataflow_.GetValueSet(instruction->operand(0));
- if (src_value_set.values().size() == 1) {
- CopyNodes& copy_node = copy_map_[instruction];
- copy_node.dest =
- value_to_node.at(&dataflow_.GetUniqueValueAt(instruction));
- copy_node.src = value_to_node.at(&src_value_set.GetUniqueValue());
- }
+Status InstructionCopier::RecordAmbiguousOrNonDistinctIndices(
+ const TuplePointsToAnalysis& points_to_analysis) {
+ const PointsToSet& points_to =
+ points_to_analysis.GetPointsToSet(instruction_);
+ // Mapping from LogicalBuffer to index (used to detect non-distinct indices).
+ FlatMap<const LogicalBuffer*, std::vector<ShapeIndex>>
+ buffer_to_source_indices;
+ points_to.ForEachElement(
+ [this, &buffer_to_source_indices](
+ const ShapeIndex& index, const PointsToSet::BufferList& buffers) {
+ if (buffers.size() > 1) {
+ // Record ambiguous points-to set at 'index'.
+ if (!indices_to_copy_.element(index)) {
+ VLOG(2) << "Adding copy of buffer for instruction: "
+ << instruction_->name()
+ << " at index: " << tensorflow::str_util::Join(index, ",")
+ << " with ambiguous points-to set.";
+ RecordIndex(index);
}
}
- }
- }
+ // For each 'buffer': record a mapping from 'buffer' to 'index'.
+ for (const LogicalBuffer* buffer : buffers) {
+ buffer_to_source_indices[buffer].push_back(index);
+ }
+ });
- ~BufferValueTracker() {
- for (const ValueNode* head : value_lists_) {
- const ValueNode* p = head;
- do {
- const ValueNode* tmp = p->next;
- delete p;
- p = tmp;
- } while (p != head);
- }
+ // Record all non-distinct indices detected in 'buffer_to_source_indices'.
+ for (const auto& buff_to_src : buffer_to_source_indices) {
+ if (buff_to_src.second.size() == 1) {
+ continue;
}
-
- // Verify invariants within the linked lists.
- Status Verify() const {
- for (const ValueNode* head : value_lists_) {
- const ValueNode* p = head;
- do {
- // Verify links between elements are consistent.
- TF_RET_CHECK(p->prev->next == p);
- TF_RET_CHECK(p->next->prev == p);
-
- const HloInstruction* def = p->value->defining_instruction();
- if (def->opcode() == HloOpcode::kCopy &&
- ContainsKey(copy_map_, def)) {
- TF_RET_CHECK(copy_map_.at(def).dest == p);
- }
- for (const HloUse* use : p->uses) {
- if (use->instruction->opcode() == HloOpcode::kCopy &&
- ContainsKey(copy_map_, use->instruction)) {
- TF_RET_CHECK(copy_map_.at(use->instruction).src == p);
- }
- }
-
- p = p->next;
- } while (p != head);
+ for (const ShapeIndex& src_index : buff_to_src.second) {
+ // Record non-distinct points-to set at 'src_index'.
+ if (!indices_to_copy_.element(src_index)) {
+ VLOG(2) << "Adding copy of buffer for instruction: "
+ << instruction_->name()
+ << " at index: " << tensorflow::str_util::Join(src_index, ",")
+ << " because of non-distinct points-to set.";
+ RecordIndex(src_index);
}
- return Status::OK();
}
+ }
+ return Status::OK();
+}
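// For example (a sketch, assuming an array-shaped instruction 'x'): if
// instruction_ is Tuple(x, x), the same LogicalBuffer appears at indices {0}
// and {1}, so buffer_to_source_indices maps that buffer to {{0}, {1}} and
// both indices are recorded as non-distinct and copied.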
- // Try to elide the given copy. Elision of a copy is possible only if no
- // live range interference is introduced by the copy's elimination. If
- // elision is possible, then the internal state (value lists) are updated,
- // and true is returned. Returns false otherwise.
- bool TryElideCopy(const HloInstruction* copy) {
- VLOG(2) << "Trying to remove " << copy->name();
-
- if (!ContainsKey(copy_map_, copy)) {
- VLOG(2) << copy->name() << " is not removable";
- return false;
- }
-
- const CopyNodes& copy_node = copy_map_.at(copy);
- ValueNode* src = copy_node.src;
- ValueNode* dest = copy_node.dest;
- DCHECK(src != nullptr);
- DCHECK(dest != nullptr);
-
- auto is_live_range_before = [this](const ValueNode& a,
- const ValueNode& b) {
- if (LiveRangeBefore(a, b)) {
- VLOG(2) << " Live range of " << a.value->ToShortString()
- << " is before " << b.value->ToShortString();
- return true;
- } else {
- VLOG(2) << " Live range of " << a.value->ToShortString()
- << " is not before " << b.value->ToShortString();
- return false;
+Status InstructionCopier::RecordIndicesWhichInterfereWithOtherInstruction(
+ const BufferLiveness& liveness, const HloInstruction* other_instruction,
+ ShapeTree<bool>* read_only_indices_out) {
+ // Record all buffer indices for 'instruction_', which interfere with
+ // 'other_instruction' at the same index.
+ ShapeUtil::ForEachSubshape(
+ instruction_->shape(),
+ [this, &liveness, other_instruction, read_only_indices_out](
+ const Shape& /*subshape*/, const ShapeIndex& index) {
+ if (IsReadOnlyIndex(index)) {
+ return;
}
- };
-
- // A kCopy instruction copies an HLO value from a source buffer and
- // defines an HLO value in a destination buffer. Most generally, the
- // source and destination buffers may each hold more than one value at
- // different points in the computation so we define the following:
- //
- // Values in source buffer: {s_0, ..., s_n}
- // Values in destination buffer: {d_0, ..., d_m}
- //
- // A kCopy instruction between these buffers copies a value s_x in the
- // source buffer and defines a value d_y in the destination buffer. The
- // elision of a copy merges the source and destination buffers together,
- // so the list of values for the source and destination buffers are
- // merged.
- //
- // We handle two different cases for copy elision:
- //
- // (1) the kCopy defines the first value in the destination buffer (d_0).
- //
- // (2) the kCopy copies the last value in the source buffer (s_n).
- //
- // For the remaining case where the kCopy copies a not-last value from the
- // source buffer to a not-first value of the destination buffer, the kCopy
- // instruction cannot be removed. This case is generated, for example, if
- // the kCopy copies a while body parameter of the loop state at one tuple
- // index to a different tuple index in the while body root. Removal of the
- // copy necessarily results in live range interference of values in the
- // loop state at the two different tuple indices.
- //
- // We can only perform copy elision if the resulting merged values have
- // totally ordered live ranges; otherwise the merged buffer would have
- // live range interference.
- if (IsHead(*dest)) {
- // The copy copies an arbitrary value in the source buffer (call it s_x)
- // and defines d_0, the first value in the destination buffer. After
- // merging, the values in the combined buffer must be strictly ordered
- // as follows** to elide the copy:
- //
- // {s_0, ..., s_x, d_1, ..., d_m, s_{x+1}, ..., s_n}
- //
- // Removing the copy eliminates d_0, and uses of d_0 become uses of
- // s_x. In the above ordering, the live range of d_m must be ordered
- // before the live range of s_{x+1} and the definition and all uses of
- // s_x must be ordered before the definition of d_1. These conditions
- // are checked below prior to elision.
- //
- // ** Technically it might be possible to have a non-interfering
- // non-trivial interleaving of the values of the source and
- // destination buffers in the resulting order. However, this case is
- // slow and complicated to check and likely not worth it. So instead
- // we simply check for the case where *all* values of the destination
- // buffer (d_1 through d_m) are spliced into the point where the copy
- // used to be.
- VLOG(2) << copy->name() << " defines the first value in its buffer";
- ValueNode* next_dest = Next(*dest);
- if (next_dest != nullptr) {
- // Live range of 'from' value (s_x) must be before 'next_dest' (d_1);
- if (!is_live_range_before(*src, *next_dest)) {
- return false;
- }
+ if (indices_to_copy_.element(index)) {
+ // Return if previous pass already set index.
+ return;
}
- ValueNode* next_src = Next(*src);
-
- if (next_src != nullptr) {
- // Live range of 'last_dest' (d_m) must be before 'next_src' s_{x+1}.
- ValueNode* last_dest = dest->prev;
- DCHECK(IsTail(*last_dest));
- if (!is_live_range_before(*last_dest, *next_src)) {
- return false;
+ const auto& points_to_analysis = liveness.points_to_analysis();
+ // Lookup buffers for 'instruction_' and 'other_instruction'.
+ const auto instruction_buffers =
+ points_to_analysis.GetPointsToSet(instruction_).element(index);
+ // If 'instruction_' has ambiguous points-to-set at 'index', it would
+ // have been recorded in a previous pass (and we would have returned
+ // early at the entry to this function). As a result, here we know that
+ // 'instruction_' has just one buffer in its points-to-set.
+ CHECK_EQ(1, instruction_buffers.size());
+ const LogicalBuffer* instruction_buffer = instruction_buffers[0];
+
+ const auto other_instruction_buffers =
+ points_to_analysis.GetPointsToSet(other_instruction).element(index);
+ // Do not insert a copy if both instructions point at the same buffer.
+ // This eliminates unnecessary copies of read-only tuple elements.
+ // If 'instruction_' and 'other_instruction' point to the same buffer,
+ // then that buffer is not updated on the path between the two
+ // instructions. Therefore, any other (possibly interference-causing)
+ // users of that buffer from 'other_instruction' will see the same data,
+ // irrespective of whether we insert a copy of this buffer at
+ // 'instruction_' or not.
+ if (other_instruction_buffers.size() == 1 &&
+ other_instruction_buffers[0]->id() == instruction_buffer->id()) {
+ if (read_only_indices_out != nullptr) {
+ *read_only_indices_out->mutable_element(index) = true;
}
+ return;
}
-
- // Splice in destination buffer values list right after 'src'.
- SpliceAfter(dest, src);
- } else if (IsTail(*src)) {
- // The copy copies the last value in the source buffer, s_n, and defines
- // an arbitrary value in the destination buffer, d_y. After
- // merging, the values in the combined buffer must be strictly ordered
- // as follows** to elide the copy:
- //
- // {d_0, ..., d_{y-1}, s_0, ..., s_n, d_{y+1}, ..., d_m}
- //
- // Removing the copy eliminates d_y, and uses of d_y become uses of
- // s_n. To enforce the above order, the live range of d_{y-1} must be
- // before the live range of s_0, and the live range of s_n must be
- // before the live range of d_{y+1}.
- //
- // ** See comment above in the code handling Case (1).
- VLOG(2) << copy->name() << " copies the last value ("
- << src->value->ToShortString() << ") in its buffer";
-
- ValueNode* prev_dest = Prev(*dest);
- // nullptr condition handled above in the first 'if' case.
- DCHECK(prev_dest != nullptr);
- ValueNode* first_src = src->next;
- DCHECK(IsHead(*first_src));
- if (!is_live_range_before(*prev_dest, *first_src)) {
- // Live range of value d_{y-1} is not before s_0.
- return false;
- }
- ValueNode* next_dest = Next(*dest);
- if (next_dest != nullptr) {
- if (!is_live_range_before(*src, *next_dest)) {
- // Live range of value s_n is not before d_{y+1}.
- return false;
+ // We can't say anything about the ambiguity of 'other_instruction' at
+ // this point, so we need to check interference between the single
+ // buffer in the points-to set of 'instruction_' and all buffers in
+ // 'other_instruction_buffers'.
+ for (const LogicalBuffer* other_buffer : other_instruction_buffers) {
+ if (liveness.MayInterfere(*instruction_buffer, *other_buffer)) {
+ VLOG(2) << "Adding copy of buffer for instruction: "
+ << instruction_->name()
+ << " instruction_buffer: " << instruction_buffer->ToString()
+ << " at index: " << tensorflow::str_util::Join(index, ",")
+ << " because of interference with buffer: "
+ << other_buffer->ToString();
+ RecordIndex(index);
+ break;
}
}
+ });
+ return Status::OK();
+}
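// For example (a sketch): if the while body only reads loop-state element
// {2}, the body parameter and the body root point at the same LogicalBuffer
// at {2}; no copy is recorded there, and {2} is reported through
// 'read_only_indices_out' so later passes can skip it as well.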
- // Splice source buffer values list right after 'prev_dest'.
- SpliceAfter(first_src, prev_dest);
- } else {
- VLOG(2)
- << copy->name()
- << " copies value in middle of source buffer to value in middle "
- "of destination buffer";
- return false;
- }
-
- RemoveCopyValue(dest);
-
- XLA_VLOG_LINES(4, ToString());
- TF_DCHECK_OK(Verify());
-
- return true;
- }
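// A worked sketch of case (1) above, assuming a source buffer holding {s_0}
// and a destination buffer holding {d_0, d_1}, where the copy defines d_0:
// elision requires the live range of s_0 (src) to end before d_1 (next_dest)
// is defined. If so, SpliceAfter merges the lists into {s_0, d_1} and
// RemoveCopyValue deletes the d_0 node, leaving uses of d_0 pointing at s_0.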
-
- // Delete the given ValueNode associated with an elided kCopy
- // instruction. This should be called after splicing the value lists of the
- // source and destination buffers together.
- void RemoveCopyValue(ValueNode* copy_value_node) {
- CHECK_EQ(copy_value_node->value->defining_instruction()->opcode(),
- HloOpcode::kCopy);
- ValueNode* operand_node = copy_value_node->prev;
- CHECK(operand_node != copy_value_node);
-
- VLOG(2) << "Removing copy " << operand_node->value->ToShortString()
- << " => " << copy_value_node->value->ToShortString();
-
- // Splice out the copy value node.
- operand_node->next = copy_value_node->next;
- copy_value_node->next->prev = operand_node;
-
- // Patch up uses. Remove use of copy from operand_node uses.
- auto it =
- std::find_if(operand_node->uses.begin(), operand_node->uses.end(),
- [copy_value_node](const HloUse* use) {
- return use->instruction ==
- copy_value_node->value->defining_instruction();
- });
- CHECK(it != operand_node->uses.end());
- operand_node->uses.erase(it);
-
- // If the elided copy has any uses which are themselves kCopy instructions
- // then patch up the copy info to reflect that this kCopy instruction
- // has a different operand (the operand of the elided copy).
- for (const HloUse* copy_use : copy_value_node->uses) {
- operand_node->uses.push_back(copy_use);
- if (copy_use->instruction->opcode() == HloOpcode::kCopy) {
- copy_map_.at(copy_use->instruction).src = operand_node;
- }
- }
-
- // Delete the copy info and the value node.
- copy_map_.erase(copy_value_node->value->defining_instruction());
- delete copy_value_node;
- }
-
- // Returns true if the live range of given value 'a' is before the live
- // range of 'b'.
- //
- // We cannot use LiveRangeStrictlyBefore because HloValue::uses() is not
- // updated as copies are removed.
- bool LiveRangeBefore(const ValueNode& a, const ValueNode& b) {
- if (a.uses.empty()) {
- VLOG(2) << "Empty uses";
- return ordering_.IsDefinedBefore(*a.value, *b.value);
- }
- for (const HloUse* use : a.uses) {
- VLOG(2) << "use: " << *use;
- VLOG(2) << "is before:" << *b.value;
- if (!ordering_.UseIsBeforeValueDefinition(*use, *b.value, dataflow_)) {
- VLOG(2) << "Not before";
- return false;
+// This is called when 'instruction_' is a while body root, and 'parameter' is
+// the while body parameter. We record all users of all aliases of 'parameter'
+// as control predecessors, so that when we add a copy of 'instruction_', we can
+// mark the control dependencies. This is necessary because points-to and
+// liveness analysis doesn't know about the aliasing between the while body root
+// and param. Without these control dependencies, the copy might get scheduled
+// to run at a point that interferes with users of the buffer.
+Status InstructionCopier::RecordControlPredecessors(
+ const TuplePointsToAnalysis& points_to_analysis,
+ HloInstruction* parameter) {
+ return indices_to_copy_.ForEachElementWithStatus(
+ [this, &points_to_analysis, parameter](const ShapeIndex& index,
+ bool will_copy) {
+ if (will_copy) {
+ TF_ASSIGN_OR_RETURN(
+ const LogicalBuffer* buffer,
+ points_to_analysis.GetBufferDefinedAt(parameter, index));
+ for (const BufferAlias& alias :
+ points_to_analysis.GetBufferAliases(*buffer)) {
+ for (HloInstruction* user : alias.instruction()->users()) {
+ if (DoesNotUseOperandBuffer(alias.instruction(), alias.index(),
+ user, points_to_analysis)) {
+ continue;
+ }
+
+ if (user != instruction_) {
+ control_predecessors_.mutable_element(index)->push_back(user);
+ }
+ }
+ }
}
- }
- return true;
- }
-
- // Returns whether 'node' is the last node in its list.
- bool IsTail(const ValueNode& node) const {
- return ContainsKey(value_lists_, node.next);
- }
-
- // Returns whether 'node' is the first node in its list.
- bool IsHead(const ValueNode& node) const {
- return ContainsKey(value_lists_, &node);
- }
-
- // Returns the next node in the list after 'node'. If 'node' is the
- // tail, then nullptr is returned.
- ValueNode* Next(const ValueNode& node) const {
- if (IsTail(node)) {
- return nullptr;
- } else {
- return node.next;
- }
- }
-
- // Returns the previous node in the list before 'node'. If 'node'
- // is the head, then nullptr is returned.
- ValueNode* Prev(const ValueNode& node) const {
- if (IsHead(node)) {
- return nullptr;
- } else {
- return node.prev;
- }
- }
-
- // Splices the entire linked list with 'head' as its head right after the
- // node 'insert_after' in another linked list.
- void SpliceAfter(ValueNode* head, ValueNode* insert_after) {
- DCHECK(IsHead(*head));
- value_lists_.erase(head);
-
- ValueNode* tail = head->prev;
- tail->next = insert_after->next;
- insert_after->next->prev = tail;
-
- insert_after->next = head;
- head->prev = insert_after;
- }
-
- string ToString() const {
- string out = StrCat("BufferValueTracker:\n");
- StrAppend(&out, " Def-use chains in each buffer:\n");
- for (const ValueNode* head : value_lists_) {
- StrAppend(&out, " Buffer defined by ", head->value->ToShortString(),
- ":\n");
- const ValueNode* p = head;
- do {
- StrAppend(&out, " ", p->value->ToShortString(), ", uses: ",
- Join(p->uses, "; ",
- [](string* s, const HloUse* use) {
- StrAppend(s, use->ToString());
- }),
- "\n");
-
- p = p->next;
- } while (p != head);
- }
- StrAppend(&out, " Potentially removable copies:\n");
- for (const auto& pair : copy_map_) {
- const HloInstruction* copy = pair.first;
- const CopyNodes& copy_info = pair.second;
+ return Status::OK();
+ });
+}
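// For example (a sketch): if the while body parameter's buffer at index {0}
// is read by instructions 'reader0' and 'reader1', both are recorded as
// control predecessors at {0}; CopyTuple later adds control edges
// reader0 -> kCopy and reader1 -> kCopy, so the root copy cannot be
// scheduled before those reads complete.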
- StrAppend(&out, " ", copy->name(), " : ",
- copy_info.src->value->ToShortString(), " => ",
- copy_info.dest->value->ToShortString(), "\n");
- }
- return out;
+// Recursively inserts copies of 'instruction' tuple element buffers at
+// indices in 'indices_to_copy_', expanding tuples as needed.
+HloInstruction* InstructionCopier::CopyTuple(HloInstruction* instruction,
+ ShapeIndex* index) {
+ const int64 num_tuple_elements =
+ ShapeUtil::TupleElementCount(instruction->shape());
+ std::vector<HloInstruction*> elem_copies(num_tuple_elements);
+ for (int64 i = 0; i < num_tuple_elements; ++i) {
+ HloInstruction* elem;
+ if (instruction->opcode() == HloOpcode::kTuple) {
+ // If the instruction is already a Tuple instruction, we know that the
+ // element buffers are aliased, so we can just grab the operand directly.
+ elem = instruction->mutable_operand(i);
+ } else {
+ // Otherwise we need to add a GTE to unpack the element out of the tuple.
+ elem = instruction->parent()->AddInstruction(
+ HloInstruction::CreateGetTupleElement(
+ ShapeUtil::GetSubshape(instruction->shape(), {i}), instruction,
+ i));
}
-
- private:
- const HloDataflowAnalysis& dataflow_;
- const HloOrdering& ordering_;
-
- // The heads of all the value lists. Each value list represents the HLO
- // values contained in a particular HLO buffer. The values in the list are
- // in dependency order.
- tensorflow::gtl::FlatSet<const ValueNode*> value_lists_;
-
- // Copy removal requires fast access to the value list elements
- // corresponding to the source and destination values of the kCopy
- // instruction. This data structure holds pointers to these elements for
- // each kCopy instruction in the graph.
- struct CopyNodes {
- // The source and destinations values of the kCopy instruction.
- ValueNode* src = nullptr;
- ValueNode* dest = nullptr;
- };
- tensorflow::gtl::FlatMap<const HloInstruction*, CopyNodes> copy_map_;
- };
-
- HloModule* module_;
- const HloAliasAnalysis& alias_analysis_;
- const HloOrdering& ordering_;
-
- // Object tracking the HLO values contained in each HLO buffer.
- BufferValueTracker buffer_value_tracker_;
-};
-
-// Try to remove as many copies from the module as possible without introducing
-// live range interference. Copy instructions (identified by their unique id) in
-// the set copies_to_exclude are not considered for removal.
-Status RemoveUnnecessaryCopies(
- const HloOrdering& ordering,
- const tensorflow::gtl::FlatSet<HloInstruction::Id>& copies_to_exclude,
- HloModule* module) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
- HloAliasAnalysis::Run(module));
- CopyRemover copy_remover(*alias_analysis, ordering, module);
- XLA_VLOG_LINES(3, copy_remover.ToString());
-
- tensorflow::gtl::FlatSet<HloInstruction::Id> existing_copies;
- for (HloComputation* computation : module->computations()) {
- for (HloInstruction* instruction : computation->instructions()) {
- if (instruction->opcode() == HloOpcode::kCopy &&
- !ContainsKey(copies_to_exclude, instruction->unique_id())) {
- TF_RETURN_IF_ERROR(copy_remover.TryElideCopy(instruction).status());
+ index->push_back(i);
+ if (ShapeUtil::IsTuple(elem->shape())) {
+ elem_copies[i] = CopyTuple(elem, index);
+ } else if (!indices_to_copy_.element(*index)) {
+ elem_copies[i] = elem;
+ } else if (HloInstruction* copy_override = GetCopyOverride(*index)) {
+ elem_copies[i] = copy_override;
+ } else {
+ HloInstruction* elem_copy = elem->parent()->AddInstruction(
+ HloInstruction::CreateUnary(elem->shape(), HloOpcode::kCopy, elem));
+ for (HloInstruction* control_predecessor :
+ control_predecessors_.element(*index)) {
+ VLOG(2) << "Adding control dependency from "
+ << control_predecessor->ToString() << " to "
+ << elem_copy->ToString();
+ TF_CHECK_OK(control_predecessor->AddControlDependencyTo(elem_copy));
}
+ elem_copies[i] = elem_copy;
}
+ index->pop_back();
}
-
- return Status::OK();
+ return instruction->parent()->AddInstruction(
+ HloInstruction::CreateTuple(elem_copies));
}
-// Add copies to address special constraints on the roots of computations not
-// related to live range interference:
-//
-// (1) Entry computation root must be unambiguous and distinct.
-//
-// (2) Any computation called by a kCall instruction must have an
-// unambiguous root.
-//
-// (3) Constants and parameters cannot be live out of the entry computation
-//
-Status AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
- HloAliasAnalysis::Run(module));
-
- // Identify which shape indices of which instructions need to be copied. Store
- // these results in 'instructions_to_copy'.
- std::unordered_map<HloInstruction*, ShapeTree<bool>> instructions_to_copy;
- auto add_index_to_copy = [&instructions_to_copy](HloInstruction* instruction,
- const ShapeIndex& index) {
- auto it = instructions_to_copy.find(instruction);
- if (it == instructions_to_copy.end()) {
- auto it_added = instructions_to_copy.emplace(
- std::piecewise_construct, std::forward_as_tuple(instruction),
- std::forward_as_tuple(instruction->shape(), /*init_value=*/false));
- it = it_added.first;
- }
- *it->second.mutable_element(index) = true;
- };
-
- // Iterate through values of all constants and entry parameters. These values
- // are special because they are held in read-only buffers. If any of these
- // values share a buffer with other values (for example, the init value of a
- // while is a constant) then copy the value at its definition and replace all
- // its uses with the copy.
- for (const HloValue* value : alias_analysis->dataflow_analysis().values()) {
- if (ValueIsReadOnly(*value) &&
- alias_analysis->GetBufferContainingValue(*value).values().size() > 1) {
- VLOG(2) << "Value " << value->ToShortString()
- << " is read only, but its buffer contains more than one value. "
- "Copying.";
- add_index_to_copy(value->defining_instruction(), value->defining_index());
- }
+// Inserts copies of 'instruction_' buffers at indices in 'indices_to_copy_'.
+HloInstruction* InstructionCopier::Copy() {
+ ShapeIndex index;
+ HloInstruction* copy;
+ if (ShapeUtil::IsTuple(instruction_->shape())) {
+ copy = CopyTuple(instruction_, &index);
+ } else {
+ copy = instruction_->parent()->AddInstruction(HloInstruction::CreateUnary(
+ instruction_->shape(), HloOpcode::kCopy, instruction_));
}
-
- // Identify copies which must be added at root instructions
- for (HloComputation* computation : module->computations()) {
- const CallGraphNode& node = call_graph.GetNode(computation);
- if (node.context() == CallContext::kParallel) {
- continue;
- }
- TF_RET_CHECK(node.context() == CallContext::kSequential);
-
- const bool is_entry = computation == module->entry_computation();
- HloInstruction* root = computation->root_instruction();
-
- // Mark nondistinct/ambiguous indices.
- tensorflow::gtl::FlatSet<const HloBuffer*> seen;
- ShapeUtil::ForEachSubshape(
- root->shape(), [&](const Shape& /*subshape*/, const ShapeIndex& index) {
- std::vector<const HloBuffer*> buffers_at_index =
- alias_analysis->ComputeBuffersAt(root, index);
- bool buffer_seen_before = false;
- for (const HloBuffer* buffer : buffers_at_index) {
- buffer_seen_before |= !seen.insert(buffer).second;
- }
- if (buffers_at_index.size() > 1 || (buffer_seen_before && is_entry)) {
- VLOG(2) << "Index " << index << " of root of computation "
- << computation->name() << " (" << root->name()
- << ") has ambiguous or non-distinct buffer. Copying.";
- add_index_to_copy(root, index);
- }
- });
-
- // For entry instructions, mark any parameter or constant values.
- if (is_entry) {
- for (const auto& pair :
- alias_analysis->dataflow_analysis().GetInstructionValueSet(root)) {
- const ShapeIndex& index = pair.first;
- const HloValueSet& value_set = pair.second;
- for (const HloValue* value : value_set.values()) {
- if (ValueIsReadOnly(*value)) {
- VLOG(2) << "Root of entry computation (" << root->name()
- << ") has constant or entry parameter value at index "
- << index << ". Copying.";
- add_index_to_copy(root, index);
- }
- }
- }
- }
+ for (HloInstruction* user : copy_users_) {
+ VLOG(2) << "Adding copy between instruction: " << instruction_->name()
+ << " and user: " << user->name();
+ TF_CHECK_OK(instruction_->ReplaceUseWith(user, copy));
}
+ return copy;
+}
- // TODO(b/62548313): Buffer assignment uses TuplePointsToAnalysis which is
- // computation-scoped. This means the analysis doesn't have visibility to
- // constants and entry parameters that cross computation boundaries. This can
- // cause invalid buffer assignments so additional conservative copies are
- // added to handle these cases. Remove this whole loop when buffer assignment
- // uses alias analysis.
- for (HloComputation* computation : module->computations()) {
- const CallGraphNode& node = call_graph.GetNode(computation);
-
- bool is_while_body = false;
- if (node.context() == CallContext::kSequential &&
- !node.caller_callsites().empty()) {
- CHECK_EQ(node.caller_callsites().size(), 1);
- const HloInstruction* calling_instruction =
- node.caller_callsites()[0].instruction();
- is_while_body = calling_instruction->opcode() == HloOpcode::kWhile &&
- calling_instruction->while_body() == node.computation();
+// The 'read_only_indices' are initialized based on points-to analysis on the
+// while body corresponding to 'while_hlo'. If the init buffer corresponding to
+// a read-only index aliases with a constant, it cannot be considered read-only,
+// and must be copied. This is necessary because BufferAssignment does not
+// currently assign an allocation for constants (b/32248867).
+// This function performs this fix-up of 'read_only_indices'.
+//
+// Returns a ShapeTree of copy_overrides, which implements an optimization to
+// allow multiple while loops that share the same read-only constants to
+// share a single copy.
+StatusOr<ShapeTree<HloInstruction*>> RevertReadOnlyIndicesForConstants(
+ const HloInstruction* while_hlo,
+ const TuplePointsToAnalysis& points_to_analysis,
+ ShapeTree<bool>* read_only_indices,
+ FlatMap<const HloInstruction*, HloInstruction*>* shared_copies) {
+ const HloInstruction* init_hlo = while_hlo->operand(0);
+ const PointsToSet& points_to = points_to_analysis.GetPointsToSet(init_hlo);
+
+ // Set of LogicalBuffers seen so far (used to detect non-distinct buffers).
+ FlatSet<const LogicalBuffer*> buffer_set;
+
+ ShapeTree<HloInstruction*> copy_overrides(init_hlo->shape());
+ points_to.ForEachElement([init_hlo, read_only_indices, shared_copies,
+ &buffer_set, &copy_overrides](
+ const ShapeIndex& index,
+ const PointsToSet::BufferList& buffers) {
+ // Look for read-only entry parameters.
+ if (!read_only_indices->element(index)) {
+ return;
}
- VLOG(2) << computation->name() << " is_while_body: " << is_while_body;
- HloInstruction* root = computation->root_instruction();
+ for (const LogicalBuffer* buffer : buffers) {
+ HloInstruction* pointee = buffer->instruction();
+ const bool is_constant = pointee->opcode() == HloOpcode::kConstant;
+ if (!is_constant) {
+ continue;
+ }
- for (const auto& pair :
- alias_analysis->dataflow_analysis().GetInstructionValueSet(root)) {
- const ShapeIndex& index = pair.first;
- const HloValueSet& value_set = pair.second;
- for (const HloValue* value : value_set.values()) {
- if (IsConstantValue(*value) && !is_while_body) {
- VLOG(2) << "Root of computation (" << root->name()
- << ") is constant at index " << index << ". Copying.";
- add_index_to_copy(root, index);
+ // We have found a constant that is read-only in
+ // the while body. These buffers are managed by the caller, and cannot
+ // be aliased with HLO buffers. Revert this read-only index,
+ // to allow it to be copied.
+ *read_only_indices->mutable_element(index) = false;
+
+ // Optimization to allow multiple while loops that share the same
+ // read-only entry constants to share a single copy.
+ // Only unambiguous and distinct array-shaped buffers are allowed, to
+ // reduce code complexity. The shape of the entry parameter must be
+ // identical to the shape of the init_hlo at this index, to ensure
+ // there were no intervening bitcast or GTE instructions, which are
+ // also hard to handle.
+ const Shape& pointee_shape = pointee->shape();
+ const Shape& init_shape =
+ ShapeUtil::GetSubshape(init_hlo->shape(), index);
+ if (buffers.size() == 1 && ShapeUtil::IsArray(pointee_shape) &&
+ ShapeUtil::Equal(pointee_shape, init_shape) &&
+ buffer_set.count(buffer) < 1) {
+ HloInstruction** copy = &(*shared_copies)[pointee];
+ if (*copy == nullptr) {
+ *copy = pointee->parent()->AddInstruction(HloInstruction::CreateUnary(
+ pointee_shape, HloOpcode::kCopy, pointee));
}
+ // Add the copy as an override.
+ *copy_overrides.mutable_element(index) = *copy;
}
- }
- }
- // Add copy instructions indicated in 'instructions_to_copy' to the module.
- for (const auto& pair : instructions_to_copy) {
- HloInstruction* instruction = pair.first;
- const ShapeTree<bool>& indices_to_copy = pair.second;
+ // Record this buffer so that later occurrences are detected as non-distinct.
+ buffer_set.insert(buffer);
- std::vector<HloInstruction*> users = instruction->users();
- TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy,
- instruction->parent()->DeepCopyInstruction(
- instruction, &indices_to_copy));
- for (HloInstruction* user : users) {
- TF_RETURN_IF_ERROR(instruction->ReplaceUseWith(user, deep_copy));
+ // We've already reverted the read-only index and handled the
+ // single-copy optimization above, so there's nothing more to do.
+ break;
}
- if (instruction == instruction->parent()->root_instruction()) {
- instruction->parent()->set_root_instruction(deep_copy);
- }
- }
-
- return Status::OK();
-}
-
-Status VerifyNoLiveRangeInterference(HloModule* module) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
- HloAliasAnalysis::Run(module));
- DependencyHloOrdering ordering(module);
- TF_RET_CHECK(!alias_analysis->HasLiveRangeInterference(ordering));
- return Status::OK();
+ });
+ return copy_overrides;
}
-void MaybeDumpModule(const string& message, const HloModule& module) {
- if (VLOG_IS_ON(3)) {
- VLOG(3) << message;
- XLA_VLOG_LINES(3, module.ToString());
- hlo_graph_dumper::MaybeDumpHloModule(module, message);
+} // anonymous namespace
+
+// NOTE: This is only called by gpu::CopyInsertion. It's not called here in the
+// base class, since the regular CopyInsertion logic above selectively copies
+// tuple elements, while this method assumes all buffers need to be deep copied.
+StatusOr<HloInstruction*> CopyInsertion::FindOrInsertCopy(HloInstruction* hlo) {
+ auto copy_it = inserted_copies_.find(hlo);
+ if (copy_it == inserted_copies_.end()) {
+ HloInstruction* copy = hlo->parent()->DeepCopyInstruction(hlo).ValueOrDie();
+ inserted_copies_.insert({hlo, copy});
+ return copy;
+ } else {
+ return copy_it->second;
}
}
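// FindOrInsertCopy returns StatusOr<HloInstruction*> yet calls ValueOrDie(),
// which CHECK-fails rather than propagating a failed deep copy. Below is a
// sketch of an error-propagating equivalent, using only TF_ASSIGN_OR_RETURN
// from the already-included status_macros.h; the free-function form and the
// FlatMap-typed memo table are assumptions for illustration, not the method
// as committed.
StatusOr<HloInstruction*> FindOrInsertCopySketch(
    tensorflow::gtl::FlatMap<HloInstruction*, HloInstruction*>*
        inserted_copies,
    HloInstruction* hlo) {
  auto copy_it = inserted_copies->find(hlo);
  if (copy_it != inserted_copies->end()) {
    return copy_it->second;  // Reuse the memoized deep copy.
  }
  TF_ASSIGN_OR_RETURN(HloInstruction * copy,
                      hlo->parent()->DeepCopyInstruction(hlo));
  inserted_copies->insert({hlo, copy});
  return copy;
}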
-} // namespace
-
StatusOr<bool> CopyInsertion::Run(HloModule* module) {
- // Copy insertion is performed in three steps:
- //
- // (1) Add copies conservatively to guarantee that there is no live-range
- // interference. This is done simplistically and usually results in more
- // copies than is strictly necessary.
- //
- // (2) Using a more fine-grained analysis, remove as many of the copies
- // added in (1) as possible while ensuring no live-range interference.
- //
- // (3) Add copies to resolve issues not related to live range interference
- // such as parameters and constants live out of the entry computation.
- //
- // We add copies then remove them (step (1) then (2)) rather than simply
- // adding only the copies that are necessary because, in general, it is
- // difficult to figure out the minimal set of copies to add once there is
- // interference. On the other hand, it is easy to determine if removing a copy
- // will introduce interference.
- //
- // The final copy insertion in (3) is done separately to simplify the
- // implementation of copy removal in (2) which is the most complicated part of
- // the pass. As is, copy removal only has to reason about live range
- // interference. If all copies were added in step (1) then copy removal would
- // also have to reason about things like constants and parameters live out of
- // the computation.
- MaybeDumpModule("before copy insertion", *module);
-
- std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
- if (!call_graph->IsFlattened()) {
- return FailedPrecondition(
- "Call graph must be flattened before copy insertion.");
- }
+ bool changed = false;
+ VLOG(2) << "CopyInsertion for module " << module->name();
- // Gather Ids of existing kCopy instructions in the module. We avoid removing
- // these copies (except via DCE in TupleSimplifier) because they may have been
- // added for reasons not considered by copy insertion (eg, layout assignment).
- // Instruction id is used instead of HloInstruction* because the pointer
- // values may be recycled.
- tensorflow::gtl::FlatSet<HloInstruction::Id> existing_copies;
- for (HloComputation* computation : module->computations()) {
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<BufferLiveness> liveness,
+ BufferLiveness::Run(module, MakeUnique<DependencyHloOrdering>(module)));
+ const auto& points_to_analysis = liveness->points_to_analysis();
+ XLA_VLOG_LINES(2, points_to_analysis.ToString());
+ XLA_VLOG_LINES(2, module->ToString());
+
+ // Gather all while body computations and while instructions.
+ FlatSet<const HloComputation*> while_body_computations;
+ std::vector<HloInstruction*> while_instructions;
+ for (auto* computation : module->computations()) {
for (HloInstruction* instruction : computation->instructions()) {
- if (instruction->opcode() == HloOpcode::kCopy) {
- existing_copies.insert(instruction->unique_id());
+ if (instruction->opcode() == HloOpcode::kWhile) {
+ while_body_computations.insert(instruction->while_body());
+ while_instructions.push_back(instruction);
}
}
}
- TF_RETURN_IF_ERROR(AddCopiesToResolveInterference(module));
-
- // Simplify the tuple structures introduced by the deep copies. This should be
- // done before removing copies (RemoveUnnecessaryCopies) because tuple
- // simplification changes dependencies in the graph which changes live range
- // interference in the graph. Also run DCE to remove the dead Tuple/GTE
- // instructions introduced by tuple simplification.
- TupleSimplifier tuple_simplifier;
- HloDCE dce;
- TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
- TF_RETURN_IF_ERROR(dce.Run(module).status());
-
- TF_DCHECK_OK(VerifyNoLiveRangeInterference(module));
-
- MaybeDumpModule("after adding copies to resolve interference", *module);
-
- DependencyHloOrdering ordering(module);
- TF_RETURN_IF_ERROR(
- RemoveUnnecessaryCopies(ordering, existing_copies, module));
-
- MaybeDumpModule("after removing unnecessary copies", *module);
-
- TF_RETURN_IF_ERROR(AddSpecialCaseCopies(*call_graph, module));
-
- MaybeDumpModule("after adding special-case copies", *module);
+ // Collect instruction buffer indices to copy in 'instructions_to_copy'.
+ std::vector<InstructionCopier> instructions_to_copy;
+
+ // Add copies of computation root instructions, if needed.
+ FlatMap<const HloComputation*, ShapeTree<bool>> while_body_read_only_indices;
+ for (auto* computation : module->MakeNonfusionComputations()) {
+ VLOG(2) << "computation " << computation->name();
+ InstructionCopier root_copier(computation->root_instruction(),
+ /*copy_users=*/{});
+ if (while_body_computations.count(computation) > 0) {
+ // Record root indices to copy for while body sub-computations. We do
+ // not need to call RecordIndicesWhichPointToParamOrConstant for the
+ // while body root instruction here, because any copies needed to keep
+ // constants or parameters out of the output are handled by the
+ // while.init operand copy insertion below (which will share an
+ // allocation).
+ HloInstruction* while_body_param = computation->parameter_instruction(0);
+ ShapeTree<bool> read_only_indices(while_body_param->shape());
+ TF_RETURN_IF_ERROR(root_copier.RecordIndicesToCopyForColocatingBuffers(
+ *liveness, while_body_param, &read_only_indices));
+ while_body_read_only_indices[computation] = read_only_indices;
+
+ // Mark control predecessors, based on the body param, for any copies
+ // we'll be inserting. This ensures the copy doesn't run too early.
+ TF_RETURN_IF_ERROR(root_copier.RecordControlPredecessors(
+ points_to_analysis, while_body_param));
+ } else {
+ // Record root indices to copy for general computations.
+ TF_RETURN_IF_ERROR(root_copier.RecordIndicesWhichPointToParamOrConstant(
+ points_to_analysis));
+ }
+ instructions_to_copy.push_back(root_copier);
+ }
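// A small sketch of the ShapeTree<bool> bookkeeping used for
// `read_only_indices` above, assuming the usual shape_util.h/shape_tree.h
// helpers. The two-element tuple shape is hypothetical; elements
// default-construct to false and are addressed by ShapeIndex.
void ReadOnlyIndicesSketch() {
  const Shape param_shape = ShapeUtil::MakeTupleShape(
      {ShapeUtil::MakeShape(F32, {8}), ShapeUtil::MakeShape(S32, {})});
  ShapeTree<bool> read_only(param_shape);  // Every index starts out false.
  *read_only.mutable_element({0}) = true;  // Tuple element 0 is read-only.
  CHECK(!read_only.element({1}));          // Element 1 may still be written.
}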
- TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
- TF_RETURN_IF_ERROR(dce.Run(module).status());
- TF_DCHECK_OK(VerifyNoLiveRangeInterference(module));
+ // Add copies of while 'init' operand instructions, if needed. 'shared_copies'
+ // is used to ensure that multiple while loops can share a single copy of the
+ // same entry parameter or constant, if all loops use it read-only.
+ //
+ // TODO(b/33301720) Remove redundant while instruction copies.
+ FlatMap<const HloInstruction*, HloInstruction*> shared_copies;
+ for (HloInstruction* while_hlo : while_instructions) {
+ // Fix read_only_indices to account for entry constants. Also
+ // initialize copy_overrides, which ensures a single copy for each read-only
+ // constant that is used in multiple while loops.
+ ShapeTree<bool>* read_only_indices =
+ &while_body_read_only_indices[while_hlo->while_body()];
+ TF_ASSIGN_OR_RETURN(
+ const ShapeTree<HloInstruction*> copy_overrides,
+ RevertReadOnlyIndicesForConstants(while_hlo, points_to_analysis,
+ read_only_indices, &shared_copies));
+ // Create InstructionCopier for init operand of while instruction.
+ HloInstruction* init_hlo = while_hlo->mutable_operand(0);
+ InstructionCopier init_copier(init_hlo, {while_hlo});
+ init_copier.SetReadOnlyIndices(*read_only_indices);
+ init_copier.SetCopyOverrides(copy_overrides);
+ // Record 'init' buffer indices which point-to a Constant or Parameter.
+ TF_RETURN_IF_ERROR(init_copier.RecordIndicesWhichPointToParamOrConstant(
+ points_to_analysis));
+ // Record indices necessary to colocate while and init operand buffers.
+ TF_RETURN_IF_ERROR(init_copier.RecordIndicesToCopyForColocatingBuffers(
+ *liveness, while_hlo, /*read_only_indices_out=*/nullptr));
+ instructions_to_copy.push_back(init_copier);
+ }
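// Roughly, once the init copier has run, a while whose init tuple aliased a
// parameter or constant consumes fresh copies instead. Hypothetical HLO text
// with made-up value names, shown only to illustrate the shape of the result:
//
//   %copy.0 = f32[8]{0} copy(%constant.1)
//   %copy.1 = f32[8]{0} copy(%param.0)
//   %tuple.2 = (f32[8]{0}, f32[8]{0}) tuple(%copy.0, %copy.1)
//   %while.3 = (f32[8]{0}, f32[8]{0}) while(%tuple.2), ...
//
// When a read-only index has a copy override, the corresponding %copy is the
// single instruction shared across while loops rather than a fresh one.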
- MaybeDumpModule("after copy insertion", *module);
+ for (InstructionCopier& to_copy : instructions_to_copy) {
+ if (to_copy.HasAllIndicesFalse()) {
+ continue;
+ }
+ changed = true;
- if (VLOG_IS_ON(1)) {
- int64 num_total_copies = 0;
- for (HloComputation* computation : module->computations()) {
- for (HloInstruction* instruction : computation->instructions()) {
- if (instruction->opcode() == HloOpcode::kCopy) {
- num_total_copies++;
- }
- }
+ // Copy instruction at recorded buffer indices.
+ HloComputation* computation = to_copy.instruction()->parent();
+ HloInstruction* copy = to_copy.Copy();
+ if (to_copy.instruction() == computation->root_instruction()) {
+ computation->set_root_instruction(copy);
}
- VLOG(1) << "Num copies before copy-insertion: " << existing_copies.size();
- VLOG(1) << "Num copies after copy-insertion: " << num_total_copies;
}
- return true;
+ VLOG(3) << "After copy insertion for module " << module->name();
+ XLA_VLOG_LINES(3, module->ToString());
+
+ return changed;
}
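// Hedged usage sketch: how a caller might drive this pass, assuming
// CopyInsertion is default-constructible and `module` is a valid, flattened
// HloModule; the wrapper function name is illustrative only.
Status RunCopyInsertionSketch(HloModule* module) {
  CopyInsertion copy_insertion;
  TF_ASSIGN_OR_RETURN(bool changed, copy_insertion.Run(module));
  VLOG(1) << "copy insertion changed module: " << changed;
  return Status::OK();
}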
} // namespace xla