aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-04-20 12:52:35 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-04-20 14:12:17 -0700
commitc0088ae3d2541d8e00fc238377dd802a811624f3 (patch)
treec3eb13ecd1bf4c244d6ca719c15c0bac8563b63b /tensorflow/compiler/xla/service
parent09097763dc5d578e972b2281c7863eefb0b44522 (diff)
[XLA] Fix the parameter instruction printing issue
Append the parameter number to the fusion parameter name, and use the parameter name rather than the instruction name in creating the new parameter. Show the parameter number when printing out parameter instructions. Change: 153752424
Diffstat (limited to 'tensorflow/compiler/xla/service')
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.cc10
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc8
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.h5
-rw-r--r--tensorflow/compiler/xla/service/hlo_matchers_test.cc2
4 files changed, 19 insertions, 6 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index 526c1c58b3..0fb8c06f88 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -93,6 +93,10 @@ HloInstruction* HloComputation::AddInstructionInternal(
// Generate a unique name for the instruction.
instruction->set_name(
instruction_name_uniquer_.GetUniqueName(instruction->name()));
+ if (instruction->opcode() == HloOpcode::kParameter) {
+ instruction->set_parameter_name(
+ instruction_name_uniquer_.GetUniqueName(instruction->parameter_name()));
+ }
Reparent(instruction.get());
HloInstruction* pinst = instruction.get();
instruction_iterators_[pinst] =
@@ -127,9 +131,9 @@ Status HloComputation::RemoveParameter(int64 param_no) {
while (param_no < param_instructions_.size()) {
param_instruction = param_instructions_[param_no];
- HloInstruction* new_instr =
- AddInstructionInternal(HloInstruction::CreateParameter(
- param_no, param_instruction->shape(), param_instruction->name()));
+ HloInstruction* new_instr = AddInstructionInternal(
+ HloInstruction::CreateParameter(param_no, param_instruction->shape(),
+ param_instruction->parameter_name()));
TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr));
new_instr->SetParentFusion(root_instruction_->fusion_instruction());
param_instructions_[param_no] = new_instr;
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 1ede4e963f..930a4fdf1b 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -577,8 +577,9 @@ HloInstruction* HloInstruction::CloneAndFuseInternal(
// instruction. Add it as an operand and add a corresponding fused
// parameter instruction.
int64 param_no = fused_parameters_.size();
- std::unique_ptr<HloInstruction> param_instruction =
- CreateParameter(param_no, operand->shape(), "fusion_param");
+ std::unique_ptr<HloInstruction> param_instruction = CreateParameter(
+ param_no, operand->shape(),
+ tensorflow::strings::StrCat("fusion_param.", param_no));
param_instruction->parent_fusion_instruction_ = this;
fused_param = fused_instructions_computation_->AddParameter(
@@ -1421,6 +1422,8 @@ string HloInstruction::ToString(bool compact_operands) const {
// Do not show large constants.
operands = "{...}";
}
+ } else if (opcode() == HloOpcode::kParameter) {
+ operands = Printf("%lld", parameter_number_);
} else {
tensorflow::gtl::ArraySlice<HloInstruction*> slice(operands_);
const int64 kMaxOperandsToShowIfCompact = 4;
@@ -1482,6 +1485,7 @@ string HloInstruction::ToString(bool compact_operands) const {
!metadata_.source_file().empty()) {
StrAppend(&extra, " # metadata=", metadata_.ShortDebugString());
}
+
return Printf("%s = %s %s(%s)%s", name().c_str(),
ShapeUtil::HumanStringWithLayout(shape()).c_str(),
ExtendedOpcodeStr().c_str(), operands.c_str(), extra.c_str());
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index f57b27f605..4647c2ca85 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -421,6 +421,11 @@ class HloInstruction {
return parameter_name_;
}
+ void set_parameter_name(const string& str) {
+ CHECK_EQ(HloOpcode::kParameter, opcode_);
+ parameter_name_ = str;
+ }
+
// Returns the dimension sizes or numbers associated with this instruction.
//
// Precondition: opcode() is one of: concatenate, reduce, broadcast, reshape,
diff --git a/tensorflow/compiler/xla/service/hlo_matchers_test.cc b/tensorflow/compiler/xla/service/hlo_matchers_test.cc
index 4ff83c2de7..1465d1cacd 100644
--- a/tensorflow/compiler/xla/service/hlo_matchers_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_matchers_test.cc
@@ -62,7 +62,7 @@ TEST(HloMatchersTest, Test) {
"doesn't match expected:\n\t"
"multiply(add, add), \n"
"operand 0:\n\t"
- "%param = f32[1]{0} parameter()\n"
+ "%param = f32[1]{0} parameter(0)\n"
"doesn't match expected:\n\t"
"add"));
}