aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/copy_insertion.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/copy_insertion.cc')
-rw-r--r--tensorflow/compiler/xla/service/copy_insertion.cc72
1 files changed, 44 insertions, 28 deletions
diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc
index 33d8338809..e0ce2e3555 100644
--- a/tensorflow/compiler/xla/service/copy_insertion.cc
+++ b/tensorflow/compiler/xla/service/copy_insertion.cc
@@ -472,6 +472,10 @@ class CopyRemover {
// between copies added around aliased operations (kWhile) guarantees
// this strict order.
for (const HloValue* value_a : buffer.values()) {
+ if (ShapeUtil::IsToken(value_a->shape())) {
+ // Token values have no representation and cannot interfere.
+ continue;
+ }
for (const HloValue* value_b : buffer.values()) {
if (value_a != value_b) {
DCHECK(ordering_.LiveRangeStrictlyBefore(*value_a, *value_b,
@@ -613,7 +617,10 @@ class CopyRemover {
VLOG(2) << copy->name() << " is not removable";
return false;
}
-
+ if (!ShapeUtil::Equal(copy->shape(), copy->operand(0)->shape())) {
+ VLOG(2) << copy->name() << " is not removable (shape mismatch)";
+ return false;
+ }
const CopyNodes& copy_node = copy_map_.at(copy);
ValueNode* src = copy_node.src;
ValueNode* dest = copy_node.dest;
@@ -947,28 +954,6 @@ class CopyRemover {
BufferValueTracker buffer_value_tracker_;
};
-// Try to remove as many copies from the module as possible without introducing
-// live range interference. Copy instructions (identified by their unique id) in
-// the set copies_to_exclude are not considered for removal.
-Status RemoveUnnecessaryCopies(
- const HloOrdering& ordering,
- const tensorflow::gtl::FlatSet<int>& copies_to_exclude, HloModule* module) {
- TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
- HloAliasAnalysis::Run(module));
- CopyRemover copy_remover(*alias_analysis, ordering, module);
- XLA_VLOG_LINES(3, copy_remover.ToString());
-
- for (HloComputation* computation : module->computations()) {
- for (HloInstruction* instruction : computation->instructions()) {
- if (instruction->opcode() == HloOpcode::kCopy &&
- !ContainsKey(copies_to_exclude, instruction->unique_id())) {
- TF_RETURN_IF_ERROR(copy_remover.TryElideCopy(instruction).status());
- }
- }
- }
- return Status::OK();
-}
-
// Add copies to address special constraints on the roots of computations not
// related to live range interference:
//
@@ -1065,13 +1050,23 @@ Status AddSpecialCaseCopies(const CallGraph& call_graph, HloModule* module) {
HloInstruction* instruction = pair.first;
const ShapeTree<bool>& indices_to_copy = pair.second;
+ ShapeTree<HloInstruction*> copies_added(indices_to_copy.shape());
std::vector<HloInstruction*> users = instruction->users();
TF_ASSIGN_OR_RETURN(HloInstruction * deep_copy,
instruction->parent()->DeepCopyInstruction(
- instruction, &indices_to_copy));
+ instruction, &indices_to_copy, &copies_added));
for (HloInstruction* user : users) {
TF_RETURN_IF_ERROR(instruction->ReplaceUseWith(user, deep_copy));
}
+ // Special case copies are not eligible for later copy elision passes.
+ indices_to_copy.ForEachElement([&](const ShapeIndex& index, bool has_copy) {
+ if (has_copy) {
+ HloInstruction* copy = *copies_added.mutable_element(index);
+ if (copy != nullptr) {
+ copy->SetCopyElisionAllowed(false);
+ }
+ }
+ });
if (instruction == instruction->parent()->root_instruction()) {
instruction->parent()->set_root_instruction(deep_copy);
}
@@ -1097,6 +1092,31 @@ void MaybeDumpModule(const string& message, const HloModule& module) {
} // namespace
+Status RemoveUnnecessaryCopies(
+ const HloOrdering& ordering,
+ const tensorflow::gtl::FlatSet<int>& copies_to_exclude, HloModule* module) {
+ MaybeDumpModule("after adding copies to resolve interference", *module);
+
+ TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
+ HloAliasAnalysis::Run(module));
+ CopyRemover copy_remover(*alias_analysis, ordering, module);
+ XLA_VLOG_LINES(3, copy_remover.ToString());
+
+ std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
+ for (HloComputation* computation : module->computations()) {
+ for (HloInstruction* instruction : computation->instructions()) {
+ if (instruction->opcode() == HloOpcode::kCopy &&
+ !ContainsKey(copies_to_exclude, instruction->unique_id()) &&
+ instruction->CopyElisionAllowed()) {
+ TF_RETURN_IF_ERROR(copy_remover.TryElideCopy(instruction).status());
+ }
+ }
+ }
+ MaybeDumpModule("after removing unnecessary copies", *module);
+
+ return Status::OK();
+}
+
StatusOr<bool> CopyInsertion::Run(HloModule* module) {
// Copy insertion is performed in three steps:
//
@@ -1158,14 +1178,10 @@ StatusOr<bool> CopyInsertion::Run(HloModule* module) {
TF_DCHECK_OK(VerifyNoLiveRangeInterference(module));
- MaybeDumpModule("after adding copies to resolve interference", *module);
-
DependencyHloOrdering ordering(module);
TF_RETURN_IF_ERROR(
RemoveUnnecessaryCopies(ordering, existing_copies, module));
- MaybeDumpModule("after removing unnecessary copies", *module);
-
TF_RETURN_IF_ERROR(AddSpecialCaseCopies(*call_graph, module));
MaybeDumpModule("after adding special-case copies", *module);