diff options
author | 2018-04-02 15:27:24 -0700 | |
---|---|---|
committer | 2018-04-02 15:29:48 -0700 | |
commit | 11c0faed23ec32c3f1532f5154dd3c7bb38847d5 (patch) | |
tree | d22b155e8d8ba565ddcac8e458b386f318fd18a6 | |
parent | 5bb819f64deaa9a641abd95b17c00a843dcb3ce8 (diff) |
[XLA] Set trace for the operand of a trace instruction when creating the instruction directly or creating from proto. Also implement XlaBuidler::Trace.
PiperOrigin-RevId: 191357376
-rw-r--r-- | tensorflow/compiler/xla/client/xla_client/xla_builder.cc | 7 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.cc | 8 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/user_computation.cc | 1 |
3 files changed, 14 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc index 04091ecb11..ec2362179e 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc @@ -543,7 +543,12 @@ XlaOp XlaBuilder::Collapse(const XlaOp& operand, } void XlaBuilder::Trace(const string& tag, const XlaOp& operand) { - UnimplementedOp(); + NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + HloInstructionProto instr; + *instr.mutable_shape() = ShapeUtil::MakeNil(); + *instr.mutable_literal() = Literal::CreateR1U8(tag)->ToProto(); + return AddInstruction(std::move(instr), HloOpcode::kTrace, {operand}); + }); } XlaOp XlaBuilder::Select(const XlaOp& pred, const XlaOp& on_true, diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index a2a2c1e615..fcf9ebf5f7 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -98,6 +98,13 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( } } + if (instruction->opcode() == HloOpcode::kTrace) { + TF_RET_CHECK(instruction->operands().size() == 1) + << "Trace instruction should have 1 operand but sees " + << instruction->operands().size(); + instruction->mutable_operand(0)->set_tracing(instruction.get()); + } + TF_RET_CHECK(!proto.name().empty()); instruction->name_ = proto.name(); @@ -170,6 +177,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( WrapUnique(new HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil())); instruction->operands_.push_back(operand); instruction->literal_ = Literal::CreateR1U8(tag); + operand->set_tracing(instruction.get()); return instruction; } diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc index fcdb2e01fb..532f7fd5bf 100644 --- a/tensorflow/compiler/xla/service/user_computation.cc +++ b/tensorflow/compiler/xla/service/user_computation.cc @@ -3491,7 +3491,6 @@ void ComputationLowerer::Visit( HloInstruction* operand = lookup_instruction(trace_request.operand()); hlo_instruction = add_instruction( HloInstruction::CreateTrace(trace_request.tag(), operand)); - operand->set_tracing(hlo_instruction); break; } |