aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/compiler/xla/client/xla_client/xla_builder.cc8
-rw-r--r--tensorflow/compiler/xla/service/hlo.proto4
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc5
3 files changed, 16 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
index c3c824a231..7ccdc2ded2 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc
@@ -1074,7 +1074,13 @@ XlaOp XlaBuilder::CustomCall(const string& call_target_name,
XlaOp XlaBuilder::HostCompute(tensorflow::gtl::ArraySlice<XlaOp> operands,
const string& channel_name,
int64 cost_estimate_ns, const Shape& shape) {
- return UnimplementedOp();
+ return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
+ HloInstructionProto instr;
+ *instr.mutable_shape() = shape;
+ instr.set_channel_name(channel_name);
+ instr.set_cost_estimate_ns(cost_estimate_ns);
+ return AddInstruction(std::move(instr), HloOpcode::kHostCompute, operands);
+ });
}
XlaOp XlaBuilder::Complex(
diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto
index 0b446c6547..8fd7f8945c 100644
--- a/tensorflow/compiler/xla/service/hlo.proto
+++ b/tensorflow/compiler/xla/service/hlo.proto
@@ -135,6 +135,10 @@ message HloInstructionProto {
xla.GatherDimensionNumbers gather_dimension_numbers = 33;
repeated int64 gather_window_bounds = 34;
+ // Compute Host.
+ string channel_name = 41;
+ int64 cost_estimate_ns = 42;
+
// The id of this instruction.
int64 id = 35;
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 3629106a25..a986bbd511 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -167,6 +167,9 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
instruction->gather_window_bounds_.push_back(bound);
}
+ instruction->channel_name_ = proto.channel_name();
+ instruction->cost_estimate_ns_ = proto.cost_estimate_ns();
+
return std::move(instruction);
}
@@ -2430,6 +2433,8 @@ HloInstructionProto HloInstruction::ToProto() const {
for (int64 bound : gather_window_bounds_) {
proto.add_gather_window_bounds(bound);
}
+ proto.set_channel_name(channel_name_);
+ proto.set_cost_estimate_ns(cost_estimate_ns_);
return proto;
}