-rw-r--r--  tensorflow/compiler/xla/service/hlo_computation.cc       69
-rw-r--r--  tensorflow/compiler/xla/service/hlo_computation.h         5
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.cc       32
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.h        13
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction_test.cc  34
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instructions.cc      21
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instructions.h        3
7 files changed, 162 insertions, 15 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index c057be8201..34b18b0e21 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -120,6 +120,30 @@ HloInstruction* HloComputation::AddParameter(
return instructions_.back().get();
}
+namespace {
+
+// Returns the new name for a fusion parameter when we change its number.
+//
+// Fusion parameters are named foo.param_1, bar.param_2, etc. We are
+// renumbering the parameters, so replace the final number in the name with
+// the updated value.
+string RenameFusionParameter(const string& original_name, int64 new_param_no) {
+ const string param_underscore = ".param_";
+ size_t index = original_name.rfind(param_underscore);
+ if (index == string::npos) {
+ return original_name;
+ }
+ string after_param = original_name.substr(index + param_underscore.size());
+ int64 numeric_suffix;
+ if (tensorflow::strings::safe_strto64(after_param, &numeric_suffix)) {
+ return StrCat(original_name.substr(0, index + param_underscore.size()),
+ new_param_no);
+ }
+ return original_name;
+}
+
+} // namespace
+
Status HloComputation::RemoveParameter(int64 param_no) {
CHECK_GE(param_no, 0);
CHECK_LT(param_no, param_instructions_.size());
@@ -132,21 +156,8 @@ Status HloComputation::RemoveParameter(int64 param_no) {
while (param_no < param_instructions_.size()) {
param_instruction = param_instructions_[param_no];
- string param_name = param_instruction->name();
- // Fusion parameters are named foo.param_1, bar.param_2, etc. We are
- // renumbering the parameters, so replace the final number in the name with
- // the updated value.
- const string param_underscore = ".param_";
- size_t index = param_name.rfind(param_underscore);
- if (index == string::npos) {
- string after_param = name().substr(index + param_underscore.size());
- int64 numeric_suffix;
- if (tensorflow::strings::safe_strto64(after_param, &numeric_suffix)) {
- param_name =
- StrCat(param_name.substr(0, index), param_underscore, param_no);
- }
- }
-
+ string param_name =
+ RenameFusionParameter(param_instruction->name(), param_no);
HloInstruction* new_instr =
AddInstructionInternal(HloInstruction::CreateParameter(
param_no, param_instruction->shape(), param_name));
@@ -159,6 +170,34 @@ Status HloComputation::RemoveParameter(int64 param_no) {
return Status::OK();
}
+Status HloComputation::RemoveUnusedParameters() {
+ CHECK(IsFusionComputation());
+ int64 removed = 0;
+ for (int64 i = 0; i < param_instructions_.size(); ++i) {
+ HloInstruction* param_instruction = param_instructions_[i];
+ if (param_instruction->user_count() == 0 &&
+ param_instruction != root_instruction()) {
+ TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction));
+ ++removed;
+ continue;
+ }
+
+ if (removed > 0) {
+ const int64 param_no = i - removed;
+ string param_name =
+ RenameFusionParameter(param_instruction->name(), param_no);
+ HloInstruction* new_instr =
+ AddInstructionInternal(HloInstruction::CreateParameter(
+ param_no, param_instruction->shape(), param_name));
+ TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr));
+ param_instructions_[param_no] = new_instr;
+ TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction));
+ }
+ }
+ param_instructions_.resize(param_instructions_.size() - removed);
+ return Status::OK();
+}
+
bool HloComputation::IsRemovable(const HloInstruction* instruction) {
// If the instruction has control predecessors or successors then we cannot
// remove the instruction without violating ordering constraints (added, for
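The hlo_computation.cc changes factor the parameter-renaming logic into RenameFusionParameter and add a RemoveUnusedParameters pass that drops dead fusion parameters and renumbers the survivors. The sketch below is a minimal standalone model of how the two compose, not the XLA code itself: parameters are reduced to (name, user_count) pairs, and std::stoll with an end-position check stands in for tensorflow::strings::safe_strto64 (an assumption about its strictness).

#include <cassert>
#include <string>
#include <vector>

// Mirrors RenameFusionParameter: replace the numeric suffix after the last
// ".param_" with the new parameter number; names without a fully numeric
// ".param_" suffix are returned unchanged.
std::string RenameFusionParameter(const std::string& original_name,
                                  long long new_param_no) {
  const std::string kParamUnderscore = ".param_";
  const size_t index = original_name.rfind(kParamUnderscore);
  if (index == std::string::npos) {
    return original_name;
  }
  const std::string after =
      original_name.substr(index + kParamUnderscore.size());
  size_t parsed = 0;
  try {
    (void)std::stoll(after, &parsed);
  } catch (...) {
    return original_name;  // Empty or non-numeric suffix.
  }
  if (parsed != after.size()) {
    return original_name;  // Trailing garbage, e.g. "foo.param_1x".
  }
  return original_name.substr(0, index + kParamUnderscore.size()) +
         std::to_string(new_param_no);
}

// Hypothetical stand-in for a fusion parameter: its name and how many uses
// it has inside the fused computation.
struct Param {
  std::string name;
  int user_count;
};

// Mirrors the shape of RemoveUnusedParameters: drop parameters with no
// users, then renumber (and rename) the survivors so numbers stay dense.
// (The real pass additionally keeps an unused parameter if it is the root
// of the fused computation.)
void RemoveUnusedParameters(std::vector<Param>* params) {
  long long removed = 0;
  for (size_t i = 0; i < params->size(); ++i) {
    if ((*params)[i].user_count == 0) {
      ++removed;
      continue;
    }
    if (removed > 0) {
      const long long param_no = static_cast<long long>(i) - removed;
      (*params)[param_no] = {
          RenameFusionParameter((*params)[i].name, param_no),
          (*params)[i].user_count};
    }
  }
  params->resize(params->size() - removed);
}

int main() {
  assert(RenameFusionParameter("foo.param_1", 0) == "foo.param_0");
  assert(RenameFusionParameter("x", 3) == "x");  // No suffix: unchanged.

  std::vector<Param> params = {
      {"a.param_0", 1}, {"b.param_1", 0}, {"c.param_2", 2}};
  RemoveUnusedParameters(&params);
  assert(params.size() == 2);
  assert(params[0].name == "a.param_0");
  assert(params[1].name == "c.param_1");  // Renumbered from 2 to 1.
  return 0;
}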
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index 0f111a1a76..c1c3e79ebc 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -113,6 +113,11 @@ class HloComputation {
// instruction.
Status RemoveParameter(int64 param_no);
+ // Remove unused parameters from the computation. Note that this is only
+ // applicable to the computation of a fusion instruction.
+ Status RemoveUnusedParameters();
+
// Add new parameter instruction to the computation.
// This should be a new parameter. Instruction will be appended to parameters
// and inserted to the instruction list.
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index dfc2fbe87f..8be64a6881 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -1399,6 +1399,30 @@ void HloInstruction::AppendOperand(HloInstruction* operand) {
operand->AddUser(this);
}
+void HloInstruction::RemoveOperandsAtAscendingIndices(
+ tensorflow::gtl::ArraySlice<int> ascending_indices) {
+ if (ascending_indices.empty()) {
+ return;
+ }
+ int next_index = 0;
+ int removed_count = 0;
+ for (int to_remove : ascending_indices) {
+ while (next_index < to_remove) {
+ operands_[next_index - removed_count] = operands_[next_index];
+ ++next_index;
+ }
+ CHECK_LT(to_remove, operands_.size());
+ ++removed_count;
+ ++next_index;
+ }
+ while (next_index < operands_.size()) {
+ operands_[next_index - removed_count] = operands_[next_index];
+ ++next_index;
+ }
+ CHECK_EQ(removed_count, ascending_indices.size());
+ operands_.resize(operands_.size() - removed_count);
+}
+
void HloInstruction::AddUser(HloInstruction* user) {
if (!ContainsKey(user_set_, user)) {
user_set_.insert(user);
@@ -1568,6 +1592,10 @@ Status HloInstruction::ReplaceUseWith(HloInstruction* user,
std::replace(user->operands_.begin(), user->operands_.end(), this,
new_producer);
new_producer->AddUser(user);
+ if (user->opcode() == HloOpcode::kFusion) {
+ TF_RETURN_IF_ERROR(
+ Cast<HloFusionInstruction>(user)->DeduplicateFusionOperands());
+ }
return Status::OK();
}
@@ -1606,6 +1634,10 @@ Status HloInstruction::ReplaceAllUsesWith(HloInstruction* new_producer) {
std::replace(user->operands_.begin(), user->operands_.end(), this,
new_producer);
new_producer->AddUser(user);
+ if (user->opcode() == HloOpcode::kFusion) {
+ TF_RETURN_IF_ERROR(
+ Cast<HloFusionInstruction>(user)->DeduplicateFusionOperands());
+ }
}
}
users_.clear();
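The new RemoveOperandsAtAscendingIndices compacts the operand vector in a single left-shifting pass rather than calling erase once per index. Below is a standalone sketch of the same loop over a plain std::vector<int>, runnable outside XLA:

#include <cassert>
#include <vector>

// Same shifting logic as RemoveOperandsAtAscendingIndices, applied to a
// plain vector: elements between removal points slide left by the number
// of removals seen so far.
void RemoveAtAscendingIndices(std::vector<int>* v,
                              const std::vector<int>& ascending_indices) {
  if (ascending_indices.empty()) {
    return;
  }
  int next_index = 0;
  int removed_count = 0;
  for (int to_remove : ascending_indices) {
    while (next_index < to_remove) {
      (*v)[next_index - removed_count] = (*v)[next_index];
      ++next_index;
    }
    ++removed_count;
    ++next_index;  // Skip over the removed element.
  }
  while (next_index < static_cast<int>(v->size())) {
    (*v)[next_index - removed_count] = (*v)[next_index];
    ++next_index;
  }
  v->resize(v->size() - removed_count);
}

int main() {
  std::vector<int> v = {10, 20, 30, 40, 50};
  RemoveAtAscendingIndices(&v, {1, 3});  // Drop the 20 and the 40.
  assert((v == std::vector<int>{10, 30, 50}));
  return 0;
}

Each surviving element slides left by the number of removals seen so far, so the cost is one O(n) pass however many indices are dropped; repeated vector::erase calls would be O(n) each.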
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 4a0772159e..55668cf1a2 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -826,9 +826,15 @@ class HloInstruction {
// Replaces the use of this instruction in "user" with "new_producer". Note
// that there might be multiple uses of this instruction in "user"; all will
// be replaced.
+ //
+ // If user is a fusion instruction, this function will remove any
+ // duplicate operands of the fusion that the replacement creates.
Status ReplaceUseWith(HloInstruction* user, HloInstruction* new_producer);
// Replaces the specified operand with new_operand.
+ //
+ // Even if this instruction is a fusion, this function does NOT deduplicate
+ // its operands, so that existing operand numbers remain valid.
Status ReplaceOperandWith(int64 operand_no, HloInstruction* new_operand);
// Replaces all uses of this instruction with the new producer. If
@@ -837,6 +843,9 @@ class HloInstruction {
//
// If this instruction is the root of its computation, sets the computation's
// root to new_producer.
+ //
+ // If a user is a fusion instruction, this function will remove any
+ // duplicate operands of the fusion that the replacement creates.
Status ReplaceAllUsesWith(HloInstruction* new_producer);
// Performs a postorder DFS visit using this node as the root. If
@@ -1455,6 +1464,10 @@ class HloInstruction {
operands_.erase(operands_.begin() + index);
}
+ // Removes the operands at the given indices; the indices must be sorted
+ // in ascending order.
+ void RemoveOperandsAtAscendingIndices(
+ tensorflow::gtl::ArraySlice<int> ascending_indices);
+
void AppendComputation(HloComputation* computation) {
called_computations_.push_back(computation);
}
diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
index 120162a956..3847d68efa 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
@@ -1137,6 +1137,40 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) {
EXPECT_TRUE(StructuralEqual(*fusion, *fusion2));
}
+TEST_F(HloInstructionTest, NoRedundantFusionOperandsAfterReplacingUse) {
+ // Fused expression:
+ //
+ // x y
+ // | |
+ // | transpose
+ // \ /
+ // dot
+ const Shape s = ShapeUtil::MakeShape(F32, {10, 10});
+
+ HloComputation::Builder builder("TransposeDot");
+ HloInstruction* x =
+ builder.AddInstruction(HloInstruction::CreateParameter(0, s, "x"));
+ HloInstruction* y =
+ builder.AddInstruction(HloInstruction::CreateParameter(1, s, "y"));
+ HloInstruction* transpose =
+ builder.AddInstruction(HloInstruction::CreateTranspose(s, y, {1, 0}));
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(1);
+ dot_dnums.add_rhs_contracting_dimensions(0);
+ HloInstruction* dot = builder.AddInstruction(
+ HloInstruction::CreateDot(s, x, transpose, dot_dnums));
+
+ auto module = CreateNewModule();
+ auto* computation = module->AddEntryComputation(builder.Build());
+ HloInstruction* fusion = computation->CreateFusionInstruction(
+ {dot, transpose}, HloInstruction::FusionKind::kLoop);
+
+ EXPECT_TRUE(x->ReplaceAllUsesWith(y).ok());
+
+ EXPECT_THAT(fusion->operands(), UnorderedElementsAre(y));
+ EXPECT_EQ(fusion->fused_instructions_computation()->num_parameters(), 1);
+}
+
TEST_F(HloInstructionTest, FusionEquality) {
auto module = CreateNewModule();
HloComputation::Builder builder(TestName());
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index a015d791ce..e2f43f5810 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/window_util.h"
+#include "tensorflow/core/lib/gtl/flatmap.h"
namespace xla {
namespace {
@@ -1208,6 +1209,26 @@ std::unique_ptr<HloInstruction> HloFusionInstruction::CloneWithNewOperandsImpl(
new_fused_computation);
}
+Status HloFusionInstruction::DeduplicateFusionOperands() {
+ tensorflow::gtl::FlatMap<const HloInstruction*, int> operand_indices;
+ std::vector<int> operands_to_remove;
+ for (int i = 0; i < operand_count(); ++i) {
+ auto emplace_result = operand_indices.emplace(operand(i), i);
+ if (!emplace_result.second) {
+ TF_RETURN_IF_ERROR(fused_parameter(i)->ReplaceAllUsesWith(
+ fused_parameter(emplace_result.first->second)));
+ operands_to_remove.push_back(i);
+ }
+ }
+ if (operands_to_remove.empty()) {
+ return Status::OK();
+ }
+ TF_RETURN_IF_ERROR(
+ fused_instructions_computation()->RemoveUnusedParameters());
+ RemoveOperandsAtAscendingIndices(operands_to_remove);
+ return Status::OK();
+}
+
HloRngInstruction::HloRngInstruction(
const Shape& shape, RandomDistribution distribution,
tensorflow::gtl::ArraySlice<HloInstruction*> parameters)
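DeduplicateFusionOperands keeps the first occurrence of each operand and collects later duplicates in ascending index order, which is exactly the precondition RemoveOperandsAtAscendingIndices requires. A standalone sketch of the index-collection step, with std::unordered_map standing in for tensorflow::gtl::FlatMap (only the emplace returns (iterator, inserted) contract is relied on, which both provide) and opaque pointers standing in for HloInstruction*:

#include <cassert>
#include <unordered_map>
#include <vector>

// Returns the ascending indices of duplicate operands, keeping the first
// occurrence of each.
std::vector<int> DuplicateOperandIndices(
    const std::vector<const void*>& operands) {
  std::unordered_map<const void*, int> operand_indices;
  std::vector<int> to_remove;
  for (int i = 0; i < static_cast<int>(operands.size()); ++i) {
    auto emplace_result = operand_indices.emplace(operands[i], i);
    if (!emplace_result.second) {
      // operands[i] first appeared at emplace_result.first->second. In the
      // real pass, fused_parameter(i)'s uses are redirected to that first
      // parameter before the duplicate operand is removed.
      to_remove.push_back(i);
    }
  }
  return to_remove;  // Ascending, as RemoveOperandsAtAscendingIndices needs.
}

int main() {
  int a = 0, b = 0;
  // Operand list (a, b, a, b): indices 2 and 3 duplicate 0 and 1.
  std::vector<const void*> ops = {&a, &b, &a, &b};
  assert((DuplicateOperandIndices(ops) == std::vector<int>{2, 3}));
  return 0;
}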
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index 875860a8cc..ec8a42bd3b 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -635,6 +635,9 @@ class HloFusionInstruction : public HloInstruction {
void set_fusion_kind(FusionKind kind) { fusion_kind_ = kind; }
+ // If multiple operands are the same instruction, keeps only the first one
+ // and rewires the corresponding fused parameters to it.
+ Status DeduplicateFusionOperands();
+
private:
// Fuses the given instruction into this fusion instruction. When add_output
// is false (which is the default), instruction_to_fuse is cloned and the