aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/copy_insertion.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/copy_insertion.cc')
-rw-r--r--tensorflow/compiler/xla/service/copy_insertion.cc85
1 files changed, 82 insertions, 3 deletions
diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc
index f35324aa35..cfe025fdd1 100644
--- a/tensorflow/compiler/xla/service/copy_insertion.cc
+++ b/tensorflow/compiler/xla/service/copy_insertion.cc
@@ -40,10 +40,12 @@ namespace {
using absl::StrAppend;
-bool IsEntryParameterValue(const HloValue& value) {
+bool IsReadonlyEntryParameterValue(const HloValue& value) {
const HloComputation* computation = value.defining_instruction()->parent();
return value.defining_instruction()->opcode() == HloOpcode::kParameter &&
- computation == computation->parent()->entry_computation();
+ computation == computation->parent()->entry_computation() &&
+ !computation->parent()->input_output_alias_config().ParameterHasAlias(
+ value.defining_instruction()->parameter_number());
}
bool IsConstantValue(const HloValue& value) {
@@ -51,7 +53,7 @@ bool IsConstantValue(const HloValue& value) {
}
bool ValueIsReadOnly(const HloValue& value) {
- return IsConstantValue(value) || IsEntryParameterValue(value);
+ return IsConstantValue(value) || IsReadonlyEntryParameterValue(value);
}
// Data structure describing the action which should be taken on parts of a
@@ -332,6 +334,81 @@ Status AddCopiesForConditional(const HloAliasAnalysis& alias_analysis,
return Status::OK();
}
+// Conservatively adds copies before root instruction of entry computation and
+// each aliased parameter to resolve interference of aliased input and output
+// buffer. We later rely on the CopyRemover to drop the unnecessary ones.
+Status AddCopiesForAliasedInputOutputs(HloModule* module) {
+ HloComputation* entry = module->entry_computation();
+ HloInstruction* root = entry->root_instruction();
+
+ ShapeTree<bool> output_indices_to_copy(root->shape());
+ std::vector<ShapeTree<HloInstruction*>> copied_parameters;
+ bool has_alias = false;
+ for (auto* param : entry->parameter_instructions()) {
+ bool param_has_alias = false;
+ ShapeTree<bool> param_indices_to_copy(param->shape());
+
+ module->input_output_alias_config().ForEachAlias(
+ [&](const ShapeIndex& output_index, int64 param_number,
+ const ShapeIndex& param_index) {
+ if (param_number == param->parameter_number()) {
+ param_has_alias = true;
+ *(param_indices_to_copy.mutable_element(param_index)) = true;
+ *(output_indices_to_copy.mutable_element(output_index)) = true;
+ }
+ });
+
+ if (!param_has_alias) {
+ continue;
+ }
+
+ has_alias = true;
+ // Store a snapshot of users before DeepCopyInstruction, as
+ // DeepCopyInstruction introduces new users of the instruction.
+ std::vector<HloInstruction*> users = param->users();
+ ShapeTree<HloInstruction*> param_copy_tree(param->shape(),
+ /*init_value=*/nullptr);
+ TF_ASSIGN_OR_RETURN(HloInstruction * copied,
+ entry->DeepCopyInstruction(
+ param, &param_indices_to_copy, &param_copy_tree));
+ for (HloInstruction* user : users) {
+ TF_RETURN_IF_ERROR(param->ReplaceUseWith(user, copied));
+ }
+
+ copied_parameters.push_back(param_copy_tree);
+ }
+
+ if (!has_alias) {
+ return Status::OK();
+ }
+
+ // Add copies before root instruction.
+ ShapeTree<HloInstruction*> output_copy_tree(root->shape(),
+ /*init_value=*/nullptr);
+
+ TF_ASSIGN_OR_RETURN(HloInstruction * root_copied,
+ root->parent()->DeepCopyInstruction(
+ root, &output_indices_to_copy, &output_copy_tree));
+
+ // Add control dependencies between the input/output copies.
+ TF_RETURN_IF_ERROR(module->input_output_alias_config().ForEachAliasWithStatus(
+ [&](const ShapeIndex& output_index, int64 param_number,
+ const ShapeIndex& input_index) -> Status {
+ HloInstruction* from =
+ copied_parameters[param_number].element(input_index);
+ HloInstruction* to = output_copy_tree.element(output_index);
+
+ TF_RET_CHECK(from != nullptr);
+ TF_RET_CHECK(to != nullptr);
+ TF_RETURN_IF_ERROR(from->AddControlDependencyTo(to));
+ return Status::OK();
+ }));
+
+ entry->set_root_instruction(root_copied);
+
+ return Status::OK();
+}
+
// Removes any control dependencies to or from the given instruction.
Status StripControlDependenciesFrom(HloInstruction* instruction) {
while (!instruction->control_successors().empty()) {
@@ -953,6 +1030,8 @@ Status CopyInsertion::AddCopiesToResolveInterference(HloModule* module) {
}
}
}
+
+ TF_RETURN_IF_ERROR(AddCopiesForAliasedInputOutputs(module));
return Status::OK();
}