aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_computation.cc
diff options
context:
space:
mode:
authorGravatar Yuanzhong Xu <yuanzx@google.com>2018-06-25 17:27:30 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-06-25 17:30:25 -0700
commit89013c6f76568736cd6d8395f73db53045303412 (patch)
treeaa16b75243df34b120eb0799d8c7534b494fcfad /tensorflow/compiler/xla/service/hlo_computation.cc
parentee8703f342269dca881c17c6db3177355fcd18c7 (diff)
[XLA] Avoid fusion nodes to have duplicate operands during replacing uses.
PiperOrigin-RevId: 202049336
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_computation.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.cc69
1 files changed, 54 insertions, 15 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index c057be8201..34b18b0e21 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -120,6 +120,30 @@ HloInstruction* HloComputation::AddParameter(
return instructions_.back().get();
}
+namespace {
+
+// Returns the new name for a fusion parameter when we change its number.
+//
+// Fusion parameters are named foo.param_1, bar.param_2, etc. We are
+// renumbering the parameters, so replace the final number in the name with
+// the updated value.
+string RenameFusionParameter(const string& original_name, int64 new_param_no) {
+ const string param_underscore = ".param_";
+ size_t index = original_name.rfind(param_underscore);
+ if (index == string::npos) {
+ return original_name;
+ }
+ string after_param = original_name.substr(index + param_underscore.size());
+ int64 numeric_suffix;
+ if (tensorflow::strings::safe_strto64(after_param, &numeric_suffix)) {
+ return StrCat(original_name.substr(0, index + param_underscore.size()),
+ new_param_no);
+ }
+ return original_name;
+}
+
+} // namespace
+
Status HloComputation::RemoveParameter(int64 param_no) {
CHECK_GE(param_no, 0);
CHECK_LT(param_no, param_instructions_.size());
@@ -132,21 +156,8 @@ Status HloComputation::RemoveParameter(int64 param_no) {
while (param_no < param_instructions_.size()) {
param_instruction = param_instructions_[param_no];
- string param_name = param_instruction->name();
- // Fusion parameters are named foo.param_1, bar.param_2, etc. We are
- // renumbering the parameters, so replace the final number in the name with
- // the updated value.
- const string param_underscore = ".param_";
- size_t index = param_name.rfind(param_underscore);
- if (index == string::npos) {
- string after_param = name().substr(index + param_underscore.size());
- int64 numeric_suffix;
- if (tensorflow::strings::safe_strto64(after_param, &numeric_suffix)) {
- param_name =
- StrCat(param_name.substr(0, index), param_underscore, param_no);
- }
- }
-
+ string param_name =
+ RenameFusionParameter(param_instruction->name(), param_no);
HloInstruction* new_instr =
AddInstructionInternal(HloInstruction::CreateParameter(
param_no, param_instruction->shape(), param_name));
@@ -159,6 +170,34 @@ Status HloComputation::RemoveParameter(int64 param_no) {
return Status::OK();
}
+Status HloComputation::RemoveUnusedParameters() {
+ CHECK(IsFusionComputation());
+ int64 removed = 0;
+ for (int64 i = 0; i < param_instructions_.size(); ++i) {
+ HloInstruction* param_instruction = param_instructions_[i];
+ if (param_instruction->user_count() == 0 &&
+ param_instruction != root_instruction()) {
+ TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction));
+ ++removed;
+ continue;
+ }
+
+ if (removed > 0) {
+ const int64 param_no = i - removed;
+ string param_name =
+ RenameFusionParameter(param_instruction->name(), param_no);
+ HloInstruction* new_instr =
+ AddInstructionInternal(HloInstruction::CreateParameter(
+ param_no, param_instruction->shape(), param_name));
+ TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr));
+ param_instructions_[param_no] = new_instr;
+ TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction));
+ }
+ }
+ param_instructions_.resize(param_instructions_.size() - removed);
+ return Status::OK();
+}
+
bool HloComputation::IsRemovable(const HloInstruction* instruction) {
// If the instruction has control predecessors or successors then we cannot
// remove the instruction without violating ordering constraints (added, for