diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-04-10 22:29:13 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-04-10 22:31:48 -0700 |
commit | 785c484288913ed7989881483aefa3bee0cec015 (patch) | |
tree | fc6032fd18556c3b51782bff845e384be6fca034 | |
parent | f22655d09820f83881b8a2170eb51407956864d6 (diff) |
[XLA] Redesign: implement ComputeHost.
Also support convert from/to proto for ComputeHost.
PiperOrigin-RevId: 192403660
-rw-r--r-- | tensorflow/compiler/xla/client/xla_client/xla_builder.cc | 8 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo.proto | 4 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.cc | 5 |
3 files changed, 16 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc index c3c824a231..7ccdc2ded2 100644 --- a/tensorflow/compiler/xla/client/xla_client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_client/xla_builder.cc @@ -1074,7 +1074,13 @@ XlaOp XlaBuilder::CustomCall(const string& call_target_name, XlaOp XlaBuilder::HostCompute(tensorflow::gtl::ArraySlice<XlaOp> operands, const string& channel_name, int64 cost_estimate_ns, const Shape& shape) { - return UnimplementedOp(); + return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> { + HloInstructionProto instr; + *instr.mutable_shape() = shape; + instr.set_channel_name(channel_name); + instr.set_cost_estimate_ns(cost_estimate_ns); + return AddInstruction(std::move(instr), HloOpcode::kHostCompute, operands); + }); } XlaOp XlaBuilder::Complex( diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 0b446c6547..8fd7f8945c 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -135,6 +135,10 @@ message HloInstructionProto { xla.GatherDimensionNumbers gather_dimension_numbers = 33; repeated int64 gather_window_bounds = 34; + // Compute Host. + string channel_name = 41; + int64 cost_estimate_ns = 42; + // The id of this instruction. int64 id = 35; diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 3629106a25..a986bbd511 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -167,6 +167,9 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( instruction->gather_window_bounds_.push_back(bound); } + instruction->channel_name_ = proto.channel_name(); + instruction->cost_estimate_ns_ = proto.cost_estimate_ns(); + return std::move(instruction); } @@ -2430,6 +2433,8 @@ HloInstructionProto HloInstruction::ToProto() const { for (int64 bound : gather_window_bounds_) { proto.add_gather_window_bounds(bound); } + proto.set_channel_name(channel_name_); + proto.set_cost_estimate_ns(cost_estimate_ns_); return proto; } |