From 89013c6f76568736cd6d8395f73db53045303412 Mon Sep 17 00:00:00 2001
From: Yuanzhong Xu
Date: Mon, 25 Jun 2018 17:27:30 -0700
Subject: [XLA] Avoid giving fusion nodes duplicate operands when replacing
 uses.

PiperOrigin-RevId: 202049336
---
 tensorflow/compiler/xla/service/hlo_computation.cc | 69 +++++++++++++++++-----
 1 file changed, 54 insertions(+), 15 deletions(-)

(limited to 'tensorflow/compiler/xla/service/hlo_computation.cc')

diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index c057be8201..34b18b0e21 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -120,6 +120,30 @@ HloInstruction* HloComputation::AddParameter(
   return instructions_.back().get();
 }
 
+namespace {
+
+// Returns the new name for a fusion parameter when we change its number.
+//
+// Fusion parameters are named foo.param_1, bar.param_2, etc. We are
+// renumbering the parameters, so replace the final number in the name with
+// the updated value.
+string RenameFusionParameter(const string& original_name, int64 new_param_no) {
+  const string param_underscore = ".param_";
+  size_t index = original_name.rfind(param_underscore);
+  if (index == string::npos) {
+    return original_name;
+  }
+  string after_param = original_name.substr(index + param_underscore.size());
+  int64 numeric_suffix;
+  if (tensorflow::strings::safe_strto64(after_param, &numeric_suffix)) {
+    return StrCat(original_name.substr(0, index + param_underscore.size()),
+                  new_param_no);
+  }
+  return original_name;
+}
+
+}  // namespace
+
 Status HloComputation::RemoveParameter(int64 param_no) {
   CHECK_GE(param_no, 0);
   CHECK_LT(param_no, param_instructions_.size());
@@ -132,21 +156,8 @@ Status HloComputation::RemoveParameter(int64 param_no) {
 
   while (param_no < param_instructions_.size()) {
     param_instruction = param_instructions_[param_no];
-    string param_name = param_instruction->name();
-    // Fusion parameters are named foo.param_1, bar.param_2, etc. We are
-    // renumbering the parameters, so replace the final number in the name with
-    // the updated value.
-    const string param_underscore = ".param_";
-    size_t index = param_name.rfind(param_underscore);
-    if (index == string::npos) {
-      string after_param = name().substr(index + param_underscore.size());
-      int64 numeric_suffix;
-      if (tensorflow::strings::safe_strto64(after_param, &numeric_suffix)) {
-        param_name =
-            StrCat(param_name.substr(0, index), param_underscore, param_no);
-      }
-    }
-
+    string param_name =
+        RenameFusionParameter(param_instruction->name(), param_no);
     HloInstruction* new_instr =
         AddInstructionInternal(HloInstruction::CreateParameter(
             param_no, param_instruction->shape(), param_name));
@@ -159,6 +170,34 @@
   return Status::OK();
 }
 
+Status HloComputation::RemoveUnusedParameters() {
+  CHECK(IsFusionComputation());
+  int64 removed = 0;
+  for (int64 i = 0; i < param_instructions_.size(); ++i) {
+    HloInstruction* param_instruction = param_instructions_[i];
+    if (param_instruction->user_count() == 0 &&
+        param_instruction != root_instruction()) {
+      TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction));
+      ++removed;
+      continue;
+    }
+
+    if (removed > 0) {
+      const int64 param_no = i - removed;
+      string param_name =
+          RenameFusionParameter(param_instruction->name(), param_no);
+      HloInstruction* new_instr =
+          AddInstructionInternal(HloInstruction::CreateParameter(
+              param_no, param_instruction->shape(), param_name));
+      TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr));
+      param_instructions_[param_no] = new_instr;
+      TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction));
+    }
+  }
+  param_instructions_.resize(param_instructions_.size() - removed);
+  return Status::OK();
+}
+
 bool HloComputation::IsRemovable(const HloInstruction* instruction) {
   // If the instruction has control predecessors or successors then we cannot
   // remove the instruction without violating ordering constraints (added, for
-- 
cgit v1.2.3
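For readers who want to poke at the renaming rule outside the XLA tree, here is a minimal standalone sketch of what the new RenameFusionParameter helper does. It substitutes std::string, int64_t, std::to_string, and an all-digits check for TensorFlow's string, int64, StrCat, and tensorflow::strings::safe_strto64, so those substitutions (and the IsAllDigits name) are illustrative assumptions, not the library's API:

#include <cctype>
#include <cstdint>
#include <iostream>
#include <string>

// Stands in for tensorflow::strings::safe_strto64: accept only a non-empty,
// all-digit suffix (sign and overflow handling omitted for brevity).
bool IsAllDigits(const std::string& s) {
  if (s.empty()) return false;
  for (char c : s) {
    if (!std::isdigit(static_cast<unsigned char>(c))) return false;
  }
  return true;
}

// Same rule as the patch's RenameFusionParameter: replace the number after
// the final ".param_" with the new parameter number; leave any other name
// untouched.
std::string RenameFusionParameter(const std::string& original_name,
                                  int64_t new_param_no) {
  const std::string param_underscore = ".param_";
  size_t index = original_name.rfind(param_underscore);
  if (index == std::string::npos) {
    return original_name;
  }
  std::string after_param =
      original_name.substr(index + param_underscore.size());
  if (IsAllDigits(after_param)) {
    return original_name.substr(0, index + param_underscore.size()) +
           std::to_string(new_param_no);
  }
  return original_name;
}

int main() {
  std::cout << RenameFusionParameter("foo.param_7", 2) << "\n";  // foo.param_2
  std::cout << RenameFusionParameter("bar.param_x", 2) << "\n";  // bar.param_x
  std::cout << RenameFusionParameter("baz", 2) << "\n";          // baz
}

As a side effect, hoisting the logic into one helper fixes the old inline version visible in the second hunk, which tested index == string::npos before using index and read the suffix from the computation's name() rather than param_name, so in the common case (suffix present) it never renamed anything.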
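The new RemoveUnusedParameters pass applies the same renumbering in a single left-to-right sweep: drop each unused parameter (the patch also guards against removing the computation root), and shift every later survivor down by the running removal count, renaming it to match its new index. Below is a toy model of just that bookkeeping, with a hypothetical Param struct standing in for HloInstruction and the instruction rewiring (ReplaceAllUsesWith, RemoveInstruction) elided:

#include <cstdint>
#include <iostream>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for an HLO parameter instruction.
struct Param {
  std::string name;
  int64_t param_no;
  bool used;
};

// Same suffix rewrite as the sketch above, minus the digit check.
std::string Rename(const std::string& name, int64_t new_param_no) {
  const std::string tag = ".param_";
  size_t index = name.rfind(tag);
  if (index == std::string::npos) return name;
  return name.substr(0, index + tag.size()) + std::to_string(new_param_no);
}

// Mirrors the patch's loop: one pass, counting removals; every surviving
// parameter after at least one removal moves to index i - removed and is
// renamed to match.
void RemoveUnusedParameters(std::vector<Param>& params) {
  int64_t removed = 0;
  std::vector<Param> kept;
  for (size_t i = 0; i < params.size(); ++i) {
    if (!params[i].used) {
      ++removed;  // Real code: RemoveInstruction(param_instruction).
      continue;
    }
    Param p = params[i];
    if (removed > 0) {
      p.param_no = static_cast<int64_t>(i) - removed;
      p.name = Rename(p.name, p.param_no);
    }
    kept.push_back(p);
  }
  params = std::move(kept);  // Real code resizes param_instructions_ in place.
}

int main() {
  std::vector<Param> params = {{"a.param_0", 0, true},
                               {"b.param_1", 1, false},
                               {"c.param_2", 2, true}};
  RemoveUnusedParameters(params);
  for (const Param& p : params) {
    std::cout << p.param_no << ": " << p.name << "\n";  // 0: a.param_0
  }                                                     // 1: c.param_1
}

Building a fresh kept vector replaces the patch's in-place ReplaceAllUsesWith-and-resize dance but illustrates the same invariant: a parameter's new number is its old index minus the number of parameters removed to its left.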