diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/copy_insertion.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/copy_insertion.cc | 85 |
1 files changed, 3 insertions, 82 deletions
diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc index cfe025fdd1..f35324aa35 100644 --- a/tensorflow/compiler/xla/service/copy_insertion.cc +++ b/tensorflow/compiler/xla/service/copy_insertion.cc @@ -40,12 +40,10 @@ namespace { using absl::StrAppend; -bool IsReadonlyEntryParameterValue(const HloValue& value) { +bool IsEntryParameterValue(const HloValue& value) { const HloComputation* computation = value.defining_instruction()->parent(); return value.defining_instruction()->opcode() == HloOpcode::kParameter && - computation == computation->parent()->entry_computation() && - !computation->parent()->input_output_alias_config().ParameterHasAlias( - value.defining_instruction()->parameter_number()); + computation == computation->parent()->entry_computation(); } bool IsConstantValue(const HloValue& value) { @@ -53,7 +51,7 @@ bool IsConstantValue(const HloValue& value) { } bool ValueIsReadOnly(const HloValue& value) { - return IsConstantValue(value) || IsReadonlyEntryParameterValue(value); + return IsConstantValue(value) || IsEntryParameterValue(value); } // Data structure describing the action which should be taken on parts of a @@ -334,81 +332,6 @@ Status AddCopiesForConditional(const HloAliasAnalysis& alias_analysis, return Status::OK(); } -// Conservatively adds copies before root instruction of entry computation and -// each aliased parameter to resolve interference of aliased input and output -// buffer. We later rely on the CopyRemover to drop the unnecessary ones. -Status AddCopiesForAliasedInputOutputs(HloModule* module) { - HloComputation* entry = module->entry_computation(); - HloInstruction* root = entry->root_instruction(); - - ShapeTree<bool> output_indices_to_copy(root->shape()); - std::vector<ShapeTree<HloInstruction*>> copied_parameters; - bool has_alias = false; - for (auto* param : entry->parameter_instructions()) { - bool param_has_alias = false; - ShapeTree<bool> param_indices_to_copy(param->shape()); - - module->input_output_alias_config().ForEachAlias( - [&](const ShapeIndex& output_index, int64 param_number, - const ShapeIndex& param_index) { - if (param_number == param->parameter_number()) { - param_has_alias = true; - *(param_indices_to_copy.mutable_element(param_index)) = true; - *(output_indices_to_copy.mutable_element(output_index)) = true; - } - }); - - if (!param_has_alias) { - continue; - } - - has_alias = true; - // Store a snapshot of users before DeepCopyInstruction, as - // DeepCopyInstruction introduces new users of the instruction. - std::vector<HloInstruction*> users = param->users(); - ShapeTree<HloInstruction*> param_copy_tree(param->shape(), - /*init_value=*/nullptr); - TF_ASSIGN_OR_RETURN(HloInstruction * copied, - entry->DeepCopyInstruction( - param, ¶m_indices_to_copy, ¶m_copy_tree)); - for (HloInstruction* user : users) { - TF_RETURN_IF_ERROR(param->ReplaceUseWith(user, copied)); - } - - copied_parameters.push_back(param_copy_tree); - } - - if (!has_alias) { - return Status::OK(); - } - - // Add copies before root instruction. - ShapeTree<HloInstruction*> output_copy_tree(root->shape(), - /*init_value=*/nullptr); - - TF_ASSIGN_OR_RETURN(HloInstruction * root_copied, - root->parent()->DeepCopyInstruction( - root, &output_indices_to_copy, &output_copy_tree)); - - // Add control dependencies between the input/output copies. - TF_RETURN_IF_ERROR(module->input_output_alias_config().ForEachAliasWithStatus( - [&](const ShapeIndex& output_index, int64 param_number, - const ShapeIndex& input_index) -> Status { - HloInstruction* from = - copied_parameters[param_number].element(input_index); - HloInstruction* to = output_copy_tree.element(output_index); - - TF_RET_CHECK(from != nullptr); - TF_RET_CHECK(to != nullptr); - TF_RETURN_IF_ERROR(from->AddControlDependencyTo(to)); - return Status::OK(); - })); - - entry->set_root_instruction(root_copied); - - return Status::OK(); -} - // Removes any control dependencies to or from the given instruction. Status StripControlDependenciesFrom(HloInstruction* instruction) { while (!instruction->control_successors().empty()) { @@ -1030,8 +953,6 @@ Status CopyInsertion::AddCopiesToResolveInterference(HloModule* module) { } } } - - TF_RETURN_IF_ERROR(AddCopiesForAliasedInputOutputs(module)); return Status::OK(); } |