diff options
-rw-r--r-- | tensorflow/compiler/xla/client/xla_client/xla_builder.cc | 8 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo.proto | 4 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.cc | 5 |
3 files changed, 16 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc index c3c824a231..7ccdc2ded2 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc @@ -1074,7 +1074,13 @@ XlaOp XlaBuilder::CustomCall(const string& call_target_name, XlaOp XlaBuilder::HostCompute(tensorflow::gtl::ArraySlice<XlaOp> operands, const string& channel_name, int64 cost_estimate_ns, const Shape& shape) { - return UnimplementedOp(); + return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + HloInstructionProto instr; + *instr.mutable_shape() = shape; + instr.set_channel_name(channel_name); + instr.set_cost_estimate_ns(cost_estimate_ns); + return AddInstruction(std::move(instr), HloOpcode::kHostCompute, operands); + }); } XlaOp XlaBuilder::Complex( diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 0b446c6547..8fd7f8945c 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -135,6 +135,10 @@ message HloInstructionProto { xla.GatherDimensionNumbers gather_dimension_numbers = 33; repeated int64 gather_window_bounds = 34; + // Compute Host. + string channel_name = 41; + int64 cost_estimate_ns = 42; + // The id of this instruction. int64 id = 35; diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 3629106a25..a986bbd511 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -167,6 +167,9 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( instruction->gather_window_bounds_.push_back(bound); } + instruction->channel_name_ = proto.channel_name(); + instruction->cost_estimate_ns_ = proto.cost_estimate_ns(); + return std::move(instruction); } @@ -2430,6 +2433,8 @@ HloInstructionProto HloInstruction::ToProto() const { for (int64 bound : gather_window_bounds_) { proto.add_gather_window_bounds(bound); } + proto.set_channel_name(channel_name_); + proto.set_cost_estimate_ns(cost_estimate_ns_); return proto; } |