diff options
author | 2017-04-20 12:52:35 -0800 | |
---|---|---|
committer | 2017-04-20 14:12:17 -0700 | |
commit | c0088ae3d2541d8e00fc238377dd802a811624f3 (patch) | |
tree | c3eb13ecd1bf4c244d6ca719c15c0bac8563b63b /tensorflow/compiler/xla/service | |
parent | 09097763dc5d578e972b2281c7863eefb0b44522 (diff) |
[XLA] Fix the parameter instruction printing issue
Append the parameter number to the fusion parameter name, and use the parameter
name rather the instruction name in creating the new parameter.
Show the paramameter number when printing out parameter instructions.
Change: 153752424
Diffstat (limited to 'tensorflow/compiler/xla/service')
4 files changed, 19 insertions, 6 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc index 526c1c58b3..0fb8c06f88 100644 --- a/tensorflow/compiler/xla/service/hlo_computation.cc +++ b/tensorflow/compiler/xla/service/hlo_computation.cc @@ -93,6 +93,10 @@ HloInstruction* HloComputation::AddInstructionInternal( // Generate a unique name for the instruction. instruction->set_name( instruction_name_uniquer_.GetUniqueName(instruction->name())); + if (instruction->opcode() == HloOpcode::kParameter) { + instruction->set_parameter_name( + instruction_name_uniquer_.GetUniqueName(instruction->parameter_name())); + } Reparent(instruction.get()); HloInstruction* pinst = instruction.get(); instruction_iterators_[pinst] = @@ -127,9 +131,9 @@ Status HloComputation::RemoveParameter(int64 param_no) { while (param_no < param_instructions_.size()) { param_instruction = param_instructions_[param_no]; - HloInstruction* new_instr = - AddInstructionInternal(HloInstruction::CreateParameter( - param_no, param_instruction->shape(), param_instruction->name())); + HloInstruction* new_instr = AddInstructionInternal( + HloInstruction::CreateParameter(param_no, param_instruction->shape(), + param_instruction->parameter_name())); TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr)); new_instr->SetParentFusion(root_instruction_->fusion_instruction()); param_instructions_[param_no] = new_instr; diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 1ede4e963f..930a4fdf1b 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -577,8 +577,9 @@ HloInstruction* HloInstruction::CloneAndFuseInternal( // instruction. Add it as an operand and add a corresponding fused // parameter instruction. int64 param_no = fused_parameters_.size(); - std::unique_ptr<HloInstruction> param_instruction = - CreateParameter(param_no, operand->shape(), "fusion_param"); + std::unique_ptr<HloInstruction> param_instruction = CreateParameter( + param_no, operand->shape(), + tensorflow::strings::StrCat("fusion_param.", param_no)); param_instruction->parent_fusion_instruction_ = this; fused_param = fused_instructions_computation_->AddParameter( @@ -1421,6 +1422,8 @@ string HloInstruction::ToString(bool compact_operands) const { // Do not show large constants. operands = "{...}"; } + } else if (opcode() == HloOpcode::kParameter) { + operands = Printf("%lld", parameter_number_); } else { tensorflow::gtl::ArraySlice<HloInstruction*> slice(operands_); const int64 kMaxOperandsToShowIfCompact = 4; @@ -1482,6 +1485,7 @@ string HloInstruction::ToString(bool compact_operands) const { !metadata_.source_file().empty()) { StrAppend(&extra, " # metadata=", metadata_.ShortDebugString()); } + return Printf("%s = %s %s(%s)%s", name().c_str(), ShapeUtil::HumanStringWithLayout(shape()).c_str(), ExtendedOpcodeStr().c_str(), operands.c_str(), extra.c_str()); diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index f57b27f605..4647c2ca85 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -421,6 +421,11 @@ class HloInstruction { return parameter_name_; } + void set_parameter_name(const string& str) { + CHECK_EQ(HloOpcode::kParameter, opcode_); + parameter_name_ = str; + } + // Returns the dimension sizes or numbers associated with this instruction. // // Precondition: opcode() is one of: concatenate, reduce, broadcast, reshape, diff --git a/tensorflow/compiler/xla/service/hlo_matchers_test.cc b/tensorflow/compiler/xla/service/hlo_matchers_test.cc index 4ff83c2de7..1465d1cacd 100644 --- a/tensorflow/compiler/xla/service/hlo_matchers_test.cc +++ b/tensorflow/compiler/xla/service/hlo_matchers_test.cc @@ -62,7 +62,7 @@ TEST(HloMatchersTest, Test) { "doesn't match expected:\n\t" "multiply(add, add), \n" "operand 0:\n\t" - "%param = f32[1]{0} parameter()\n" + "%param = f32[1]{0} parameter(0)\n" "doesn't match expected:\n\t" "add")); } |