Diffstat (limited to 'tensorflow/compiler/xla/service')
7 files changed, 162 insertions, 15 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index c057be8201..34b18b0e21 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -120,6 +120,30 @@ HloInstruction* HloComputation::AddParameter(
   return instructions_.back().get();
 }
 
+namespace {
+
+// Returns the new name for a fusion parameter when we change its number.
+//
+// Fusion parameters are named foo.param_1, bar.param_2, etc. We are
+// renumbering the parameters, so replace the final number in the name with
+// the updated value.
+string RenameFusionParameter(const string& original_name, int64 new_param_no) {
+  const string param_underscore = ".param_";
+  size_t index = original_name.rfind(param_underscore);
+  if (index == string::npos) {
+    return original_name;
+  }
+  string after_param = original_name.substr(index + param_underscore.size());
+  int64 numeric_suffix;
+  if (tensorflow::strings::safe_strto64(after_param, &numeric_suffix)) {
+    return StrCat(original_name.substr(0, index + param_underscore.size()),
+                  new_param_no);
+  }
+  return original_name;
+}
+
+}  // namespace
+
 Status HloComputation::RemoveParameter(int64 param_no) {
   CHECK_GE(param_no, 0);
   CHECK_LT(param_no, param_instructions_.size());
@@ -132,21 +156,8 @@ Status HloComputation::RemoveParameter(int64 param_no) {
 
   while (param_no < param_instructions_.size()) {
     param_instruction = param_instructions_[param_no];
-    string param_name = param_instruction->name();
-    // Fusion parameters are named foo.param_1, bar.param_2, etc. We are
-    // renumbering the parameters, so replace the final number in the name with
-    // the updated value.
-    const string param_underscore = ".param_";
-    size_t index = param_name.rfind(param_underscore);
-    if (index == string::npos) {
-      string after_param = name().substr(index + param_underscore.size());
-      int64 numeric_suffix;
-      if (tensorflow::strings::safe_strto64(after_param, &numeric_suffix)) {
-        param_name =
-            StrCat(param_name.substr(0, index), param_underscore, param_no);
-      }
-    }
-
+    string param_name =
+        RenameFusionParameter(param_instruction->name(), param_no);
     HloInstruction* new_instr =
         AddInstructionInternal(HloInstruction::CreateParameter(
             param_no, param_instruction->shape(), param_name));
@@ -159,6 +170,34 @@ Status HloComputation::RemoveParameter(int64 param_no) {
   return Status::OK();
 }
 
+Status HloComputation::RemoveUnusedParameters() {
+  CHECK(IsFusionComputation());
+  int64 removed = 0;
+  for (int64 i = 0; i < param_instructions_.size(); ++i) {
+    HloInstruction* param_instruction = param_instructions_[i];
+    if (param_instruction->user_count() == 0 &&
+        param_instruction != root_instruction()) {
+      TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction));
+      ++removed;
+      continue;
+    }
+
+    if (removed > 0) {
+      const int64 param_no = i - removed;
+      string param_name =
+          RenameFusionParameter(param_instruction->name(), param_no);
+      HloInstruction* new_instr =
+          AddInstructionInternal(HloInstruction::CreateParameter(
+              param_no, param_instruction->shape(), param_name));
+      TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr));
+      param_instructions_[param_no] = new_instr;
+      TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction));
+    }
+  }
+  param_instructions_.resize(param_instructions_.size() - removed);
+  return Status::OK();
+}
+
 bool HloComputation::IsRemovable(const HloInstruction* instruction) {
   // If the instruction has control predecessors or successors then we cannot
   // remove the instruction without violating ordering constraints (added, for
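The RenameFusionParameter helper above only rewrites names that end in a numeric ".param_<n>" suffix; any other name passes through unchanged. Below is a minimal standalone sketch of the same rule, assuming plain std::string, std::int64_t, and an explicit all-digits check in place of XLA's string, int64, StrCat, and tensorflow::strings::safe_strto64 helpers, so it can be compiled and tried outside the XLA tree:

    #include <cctype>
    #include <cstdint>
    #include <iostream>
    #include <string>

    // True iff the whole string is a decimal number, mirroring
    // safe_strto64's whole-string parsing requirement.
    bool IsDecimalNumber(const std::string& s) {
      std::size_t i = (!s.empty() && s[0] == '-') ? 1 : 0;
      if (i == s.size()) return false;
      for (; i < s.size(); ++i) {
        if (!std::isdigit(static_cast<unsigned char>(s[i]))) return false;
      }
      return true;
    }

    // Rewrites the trailing number of a "<name>.param_<n>" fusion parameter
    // name; names without a numeric ".param_" suffix are returned unchanged.
    std::string RenameFusionParameter(const std::string& original_name,
                                      std::int64_t new_param_no) {
      const std::string param_underscore = ".param_";
      const std::size_t index = original_name.rfind(param_underscore);
      if (index == std::string::npos) {
        return original_name;
      }
      const std::string after_param =
          original_name.substr(index + param_underscore.size());
      if (!IsDecimalNumber(after_param)) {
        return original_name;
      }
      return original_name.substr(0, index + param_underscore.size()) +
             std::to_string(new_param_no);
    }

    int main() {
      std::cout << RenameFusionParameter("foo.param_3", 1) << "\n";  // foo.param_1
      std::cout << RenameFusionParameter("bar", 1) << "\n";          // bar
    }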
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index 0f111a1a76..c1c3e79ebc 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -113,6 +113,11 @@ class HloComputation {
   // instruction.
   Status RemoveParameter(int64 param_no);
 
+  // Remove unused parameters from the computation.
+  // Note that this is only applicable to the computation of a fusion
+  // instruction.
+  Status RemoveUnusedParameters();
+
   // Add new parameter instruction to the computation.
   // This should be a new parameter. Instruction will be appended to parameters
   // and inserted to the instruction list.
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index dfc2fbe87f..8be64a6881 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -1399,6 +1399,30 @@ void HloInstruction::AppendOperand(HloInstruction* operand) {
   operand->AddUser(this);
 }
 
+void HloInstruction::RemoveOperandsAtAscendingIndices(
+    tensorflow::gtl::ArraySlice<int> ascending_indices) {
+  if (ascending_indices.empty()) {
+    return;
+  }
+  int next_index = 0;
+  int removed_count = 0;
+  for (int to_remove : ascending_indices) {
+    while (next_index < to_remove) {
+      operands_[next_index - removed_count] = operands_[next_index];
+      ++next_index;
+    }
+    CHECK_LT(to_remove, operands_.size());
+    ++removed_count;
+    ++next_index;
+  }
+  while (next_index < operands_.size()) {
+    operands_[next_index - removed_count] = operands_[next_index];
+    ++next_index;
+  }
+  CHECK_EQ(removed_count, ascending_indices.size());
+  operands_.resize(operands_.size() - removed_count);
+}
+
 void HloInstruction::AddUser(HloInstruction* user) {
   if (!ContainsKey(user_set_, user)) {
     user_set_.insert(user);
@@ -1568,6 +1592,10 @@ Status HloInstruction::ReplaceUseWith(HloInstruction* user,
   std::replace(user->operands_.begin(), user->operands_.end(), this,
                new_producer);
   new_producer->AddUser(user);
+  if (user->opcode() == HloOpcode::kFusion) {
+    TF_RETURN_IF_ERROR(
+        Cast<HloFusionInstruction>(user)->DeduplicateFusionOperands());
+  }
   return Status::OK();
 }
 
@@ -1606,6 +1634,10 @@ Status HloInstruction::ReplaceAllUsesWith(HloInstruction* new_producer) {
       std::replace(user->operands_.begin(), user->operands_.end(), this,
                    new_producer);
       new_producer->AddUser(user);
+      if (user->opcode() == HloOpcode::kFusion) {
+        TF_RETURN_IF_ERROR(
+            Cast<HloFusionInstruction>(user)->DeduplicateFusionOperands());
+      }
     }
   }
   users_.clear();
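RemoveOperandsAtAscendingIndices above compacts the operand vector in a single left-to-right pass: surviving elements slide left by the number of slots already freed, and the vector shrinks once at the end, so k removals from n operands cost O(n) rather than the O(n·k) of repeated erase calls. A self-contained sketch of the same two-counter scheme over std::vector<int> follows; the free-function name is a hypothetical stand-in for the member function:

    #include <cassert>
    #include <cstddef>
    #include <vector>

    // Removes the elements at the given ascending indices by sliding the
    // survivors left in one pass, then shrinking the vector once.
    void RemoveAtAscendingIndices(std::vector<int>* v,
                                  const std::vector<int>& ascending_indices) {
      if (ascending_indices.empty()) return;
      std::size_t next_index = 0;     // Next element to examine.
      std::size_t removed_count = 0;  // Elements dropped so far.
      for (int to_remove : ascending_indices) {
        // Keep everything in front of the next removal target, shifted left
        // by the number of elements already dropped.
        while (next_index < static_cast<std::size_t>(to_remove)) {
          (*v)[next_index - removed_count] = (*v)[next_index];
          ++next_index;
        }
        assert(static_cast<std::size_t>(to_remove) < v->size());
        ++removed_count;  // Skip the removal target itself.
        ++next_index;
      }
      // Shift the tail that follows the last removed index.
      while (next_index < v->size()) {
        (*v)[next_index - removed_count] = (*v)[next_index];
        ++next_index;
      }
      v->resize(v->size() - removed_count);
    }

    int main() {
      std::vector<int> v = {10, 11, 12, 13, 14};
      RemoveAtAscendingIndices(&v, {1, 3});
      assert((v == std::vector<int>{10, 12, 14}));  // Indices 1 and 3 dropped.
    }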
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 4a0772159e..55668cf1a2 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -826,9 +826,15 @@ class HloInstruction {
   // Replaces the use of this instruction in "user" with "new_producer". Note
   // that there might be multiple uses of this instruction in "user"; all will
   // be replaced.
+  //
+  // If user is a fusion instruction, this function will also remove any
+  // duplicate operands that the replacement creates.
   Status ReplaceUseWith(HloInstruction* user, HloInstruction* new_producer);
 
   // Replaces the specified operand with new_operand.
+  //
+  // This function does NOT remove duplicate operands even if this instruction
+  // is a fusion, so that existing operand numbers do not change.
   Status ReplaceOperandWith(int64 operand_no, HloInstruction* new_operand);
 
   // Replaces all uses of this instruction with the new producer. If
@@ -837,6 +843,9 @@ class HloInstruction {
   //
   // If this instruction is the root of its computation, sets the computation's
   // root to new_producer.
+  //
+  // If a user is a fusion instruction, this function will also remove any
+  // duplicate operands that the replacement creates.
   Status ReplaceAllUsesWith(HloInstruction* new_producer);
 
   // Performs a postorder DFS visit using this node as the root. If
@@ -1455,6 +1464,10 @@ class HloInstruction {
     operands_.erase(operands_.begin() + index);
   }
 
+  // Removes a list of operands with the given indices in ascending order.
+  void RemoveOperandsAtAscendingIndices(
+      tensorflow::gtl::ArraySlice<int> ascending_indices);
+
   void AppendComputation(HloComputation* computation) {
     called_computations_.push_back(computation);
   }
diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
index 120162a956..3847d68efa 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
@@ -1137,6 +1137,40 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) {
   EXPECT_TRUE(StructuralEqual(*fusion, *fusion2));
 }
 
+TEST_F(HloInstructionTest, NoRedundantFusionOperandsAfterReplacingUse) {
+  // Fused expression:
+  //
+  // x     y
+  // |     |
+  // |  transpose
+  //  \  /
+  //   dot
+  const Shape s = ShapeUtil::MakeShape(F32, {10, 10});
+
+  HloComputation::Builder builder("TransposeDot");
+  HloInstruction* x =
+      builder.AddInstruction(HloInstruction::CreateParameter(0, s, "x"));
+  HloInstruction* y =
+      builder.AddInstruction(HloInstruction::CreateParameter(1, s, "y"));
+  HloInstruction* reshape =
+      builder.AddInstruction(HloInstruction::CreateTranspose(s, y, {1, 0}));
+  DotDimensionNumbers dot_dnums;
+  dot_dnums.add_lhs_contracting_dimensions(1);
+  dot_dnums.add_rhs_contracting_dimensions(0);
+  HloInstruction* dot = builder.AddInstruction(
+      HloInstruction::CreateDot(s, x, reshape, dot_dnums));
+
+  auto module = CreateNewModule();
+  auto* computation = module->AddEntryComputation(builder.Build());
+  HloInstruction* fusion = computation->CreateFusionInstruction(
+      {dot, reshape}, HloInstruction::FusionKind::kLoop);
+
+  EXPECT_TRUE(x->ReplaceAllUsesWith(y).ok());
+
+  EXPECT_THAT(fusion->operands(), UnorderedElementsAre(y));
+  EXPECT_EQ(fusion->fused_instructions_computation()->num_parameters(), 1);
+}
+
 TEST_F(HloInstructionTest, FusionEquality) {
   auto module = CreateNewModule();
   HloComputation::Builder builder(TestName());
#include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_module.h" #include "tensorflow/compiler/xla/window_util.h" +#include "tensorflow/core/lib/gtl/flatmap.h" namespace xla { namespace { @@ -1208,6 +1209,26 @@ std::unique_ptr<HloInstruction> HloFusionInstruction::CloneWithNewOperandsImpl( new_fused_computation); } +Status HloFusionInstruction::DeduplicateFusionOperands() { + tensorflow::gtl::FlatMap<const HloInstruction*, int> operand_indices; + std::vector<int> operands_to_remove; + for (int i = 0; i < operand_count(); ++i) { + auto emplace_result = operand_indices.emplace(operand(i), i); + if (!emplace_result.second) { + TF_RETURN_IF_ERROR(fused_parameter(i)->ReplaceAllUsesWith( + fused_parameter(emplace_result.first->second))); + operands_to_remove.push_back(i); + } + } + if (operands_to_remove.empty()) { + return Status::OK(); + } + TF_RETURN_IF_ERROR( + fused_instructions_computation()->RemoveUnusedParameters()); + RemoveOperandsAtAscendingIndices(operands_to_remove); + return Status::OK(); +} + HloRngInstruction::HloRngInstruction( const Shape& shape, RandomDistribution distribution, tensorflow::gtl::ArraySlice<HloInstruction*> parameters) diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index 875860a8cc..ec8a42bd3b 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -635,6 +635,9 @@ class HloFusionInstruction : public HloInstruction { void set_fusion_kind(FusionKind kind) { fusion_kind_ = kind; } + // If multiple operands are the same instruction, keeps only one of them. + Status DeduplicateFusionOperands(); + private: // Fuses the given instruction into this fusion instruction. When add_output // is false (which is the default), instruction_to_fuse is cloned and the |