aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-04-02 15:27:24 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-04-02 15:29:48 -0700
commit11c0faed23ec32c3f1532f5154dd3c7bb38847d5 (patch)
treed22b155e8d8ba565ddcac8e458b386f318fd18a6
parent5bb819f64deaa9a641abd95b17c00a843dcb3ce8 (diff)
[XLA] Set trace for the operand of a trace instruction when creating the instruction directly or creating from proto. Also implement XlaBuidler::Trace.
PiperOrigin-RevId: 191357376
-rw-r--r--tensorflow/compiler/xla/client/xla_client/xla_builder.cc7
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc8
-rw-r--r--tensorflow/compiler/xla/service/user_computation.cc1
3 files changed, 14 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
index 04091ecb11..ec2362179e 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
@@ -543,7 +543,12 @@ XlaOp XlaBuilder::Collapse(const XlaOp& operand,
}
void XlaBuilder::Trace(const string& tag, const XlaOp& operand) {
- UnimplementedOp();
+ NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+ *instr.mutable_shape() = ShapeUtil::MakeNil();
+ *instr.mutable_literal() = Literal::CreateR1U8(tag)->ToProto();
+ return AddInstruction(std::move(instr), HloOpcode::kTrace, {operand});
+ });
}
XlaOp XlaBuilder::Select(const XlaOp& pred, const XlaOp& on_true,
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index a2a2c1e615..fcf9ebf5f7 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -98,6 +98,13 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
}
}
+ if (instruction->opcode() == HloOpcode::kTrace) {
+ TF_RET_CHECK(instruction->operands().size() == 1)
+ << "Trace instruction should have 1 operand but sees "
+ << instruction->operands().size();
+ instruction->mutable_operand(0)->set_tracing(instruction.get());
+ }
+
TF_RET_CHECK(!proto.name().empty());
instruction->name_ = proto.name();
@@ -170,6 +177,7 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
WrapUnique(new HloInstruction(HloOpcode::kTrace, ShapeUtil::MakeNil()));
instruction->operands_.push_back(operand);
instruction->literal_ = Literal::CreateR1U8(tag);
+ operand->set_tracing(instruction.get());
return instruction;
}
diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc
index fcdb2e01fb..532f7fd5bf 100644
--- a/tensorflow/compiler/xla/service/user_computation.cc
+++ b/tensorflow/compiler/xla/service/user_computation.cc
@@ -3491,7 +3491,6 @@ void ComputationLowerer::Visit(
HloInstruction* operand = lookup_instruction(trace_request.operand());
hlo_instruction = add_instruction(
HloInstruction::CreateTrace(trace_request.tag(), operand));
- operand->set_tracing(hlo_instruction);
break;
}