diff options
18 files changed, 165 insertions(+), 1 deletion(-)
diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc index b1dcad6a49..e6dfe0aefb 100644 --- a/tensorflow/compiler/xla/client/computation_builder.cc +++ b/tensorflow/compiler/xla/client/computation_builder.cc @@ -789,6 +789,20 @@ ComputationDataHandle ComputationBuilder::CustomCall( return RunOpAndParseResponse(&op_request); } +ComputationDataHandle ComputationBuilder::HostCompute( + tensorflow::gtl::ArraySlice<ComputationDataHandle> operands, + const string& channel_name, int64 cost_estimate_ns, const Shape& shape) { + OpRequest op_request; + HostComputeRequest* request = op_request.mutable_host_compute_request(); + for (const ComputationDataHandle& operand : operands) { + *request->add_operands() = operand; + } + *request->mutable_shape() = shape; + request->set_channel_name(channel_name); + request->set_cost_estimate_ns(cost_estimate_ns); + return RunOpAndParseResponse(&op_request); +} + ComputationDataHandle ComputationBuilder::Complex( const ComputationDataHandle& real, const ComputationDataHandle& imag, tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) { diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h index 7cae91e9e0..aa2622174d 100644 --- a/tensorflow/compiler/xla/client/computation_builder.h +++ b/tensorflow/compiler/xla/client/computation_builder.h @@ -446,6 +446,16 @@ class ComputationBuilder { tensorflow::gtl::ArraySlice<ComputationDataHandle> operands, const Shape& shape); + // Enqueues a pseudo-op to represent host-side computation data-dependencies. + // During code generation, host send and receive operations will be generated + // to transfer |operands| to the host and a single result of |shape| back to + // the device. Host send/recv operations are emitted using |channel_name|. + // Dataflow dependencies and the |cost_estimate_ns| field may be used in HLO + // instruction scheduling. 
+ ComputationDataHandle HostCompute( + tensorflow::gtl::ArraySlice<ComputationDataHandle> operands, + const string& channel_name, int64 cost_estimate_ns, const Shape& shape); + // The following methods enqueue element-wise binary arithmetic operations // onto the computation. The shapes of the operands have to match unless one // of the operands is a scalar, or an explicit broadcast dimension is given diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h index a803b3171f..5b09e4931e 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h @@ -190,6 +190,7 @@ class DfsHloVisitorBase { virtual Status HandleInfeed(HloInstructionPtr hlo) = 0; virtual Status HandleOutfeed(HloInstructionPtr hlo) = 0; + virtual Status HandleHostCompute(HloInstructionPtr hlo) = 0; virtual Status HandleRng(HloInstructionPtr hlo) = 0; virtual Status HandleReverse(HloInstructionPtr hlo) = 0; virtual Status HandleSort(HloInstructionPtr hlo) = 0; diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h index 170adb3d24..ffc4f3bb79 100644 --- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h +++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h @@ -103,6 +103,9 @@ class DfsHloVisitorWithDefaultBase Status HandleOutfeed(HloInstructionPtr outfeed) override { return DefaultAction(outfeed); } + Status HandleHostCompute(HloInstructionPtr host_compute) override { + return DefaultAction(host_compute); + } Status HandleReverse(HloInstructionPtr reverse) override { return DefaultAction(reverse); } diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc index 9cd5a1e2b7..6a4651d83f 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc @@ 
-229,6 +229,10 @@ Status HloCostAnalysis::HandleOutfeed(const HloInstruction*) { return Status::OK(); } +Status HloCostAnalysis::HandleHostCompute(const HloInstruction*) { + return Status::OK(); +} + Status HloCostAnalysis::HandleMap(const HloInstruction* map) { // Compute properties of the mapped function. TF_ASSIGN_OR_RETURN(const Properties sub_properties, diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h index e5783539e5..af52ea06ca 100644 --- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h +++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h @@ -71,6 +71,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor { Status HandleCrossReplicaSum(const HloInstruction* crs) override; Status HandleInfeed(const HloInstruction* infeed) override; Status HandleOutfeed(const HloInstruction* outfeed) override; + Status HandleHostCompute(const HloInstruction* host_compute) override; Status HandleRng(const HloInstruction* random) override; Status HandleReverse(const HloInstruction* reverse) override; Status HandleSort(const HloInstruction* sort) override; diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc index 44fcd36370..9b0e2fd7d6 100644 --- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc +++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc @@ -988,6 +988,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) { case HloOpcode::kCall: case HloOpcode::kConditional: case HloOpcode::kCustomCall: + case HloOpcode::kHostCompute: case HloOpcode::kWhile: return kDarkGreen; case HloOpcode::kConstant: diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index d719ff857d..0d925ad00d 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -1094,6 +1094,7 @@ bool 
HloInstruction::HasSideEffect() const { case HloOpcode::kInfeed: case HloOpcode::kOutfeed: case HloOpcode::kTrace: + case HloOpcode::kHostCompute: return true; default: { // Check if any of the called computations has a side effect. @@ -1131,6 +1132,19 @@ bool HloInstruction::HasSideEffect() const { return instruction; } +/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateHostCompute( + const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, + tensorflow::StringPiece channel_name, const int64 cost_estimate_ns) { + std::unique_ptr<HloInstruction> instruction = + WrapUnique(new HloInstruction(HloOpcode::kHostCompute, shape)); + for (auto operand : operands) { + instruction->AppendOperand(operand); + } + instruction->channel_name_ = channel_name.ToString(); + instruction->cost_estimate_ns_ = cost_estimate_ns; + return instruction; +} + /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTuple( tensorflow::gtl::ArraySlice<HloInstruction*> elements) { std::vector<Shape> element_shapes; @@ -1222,6 +1236,10 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands( case HloOpcode::kCustomCall: clone = CreateCustomCall(shape, new_operands, custom_call_target_); break; + case HloOpcode::kHostCompute: + clone = CreateHostCompute(shape, new_operands, channel_name_, + cost_estimate_ns_); + break; case HloOpcode::kConcatenate: clone = CreateConcatenate(shape, new_operands, dimensions(0)); break; @@ -1792,6 +1810,7 @@ bool HloInstruction::IdenticalSlowPath( case HloOpcode::kRecvDone: case HloOpcode::kSend: case HloOpcode::kSendDone: + case HloOpcode::kHostCompute: return false; } } @@ -2577,6 +2596,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) { return visitor->HandleInfeed(this); case HloOpcode::kOutfeed: return visitor->HandleOutfeed(this); + case HloOpcode::kHostCompute: + return visitor->HandleHostCompute(this); case HloOpcode::kRng: return visitor->HandleRng(this); case 
HloOpcode::kWhile: diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 3cf43f0adf..e898a83739 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -475,6 +475,12 @@ class HloInstruction { const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, tensorflow::StringPiece custom_call_target); + // Creates a HostCompute instruction, which records host-side control and + // data dependencies for use in instruction scheduling. + static std::unique_ptr<HloInstruction> CreateHostCompute( + const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands, + tensorflow::StringPiece channel_name, const int64 cost_estimate_ns); + // Creates a tuple instruction with the given elements. This is a convenience // wrapper around CreateVariadic. static std::unique_ptr<HloInstruction> CreateTuple( @@ -1398,6 +1404,12 @@ class HloInstruction { // Name of a global symbol to call, only present for kCustomCall. string custom_call_target_; + // Name to use for host send/recv channels, only present for kHostCompute. + string channel_name_; + + // Estimate of the duration of a host computation in nanoseconds. + int64 cost_estimate_ns_; + // Computations called by this instruction. 
std::vector<HloComputation*> called_computations_; diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h index 3d64523a79..088dd15dbf 100644 --- a/tensorflow/compiler/xla/service/hlo_opcode.h +++ b/tensorflow/compiler/xla/service/hlo_opcode.h @@ -79,6 +79,7 @@ namespace xla { V(kGe, "greater-than-or-equal-to", kHloOpcodeIsComparison) \ V(kGetTupleElement, "get-tuple-element") \ V(kGt, "greater-than", kHloOpcodeIsComparison) \ + V(kHostCompute, "host-compute") \ V(kImag, "imag") \ V(kInfeed, "infeed") \ V(kIsFinite, "is-finite") \ diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index e2b3bb9d71..f3378309c2 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -125,6 +125,10 @@ Status ShapeVerifier::HandleOutfeed(HloInstruction* outfeed) { return CheckShape(outfeed, ShapeUtil::MakeNil()); } +Status ShapeVerifier::HandleHostCompute(HloInstruction*) { + return tensorflow::Status::OK(); +} + Status ShapeVerifier::HandleRng(HloInstruction*) { return tensorflow::Status::OK(); } diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h index 7eccf834bb..f9f898c236 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.h +++ b/tensorflow/compiler/xla/service/hlo_verifier.h @@ -60,6 +60,7 @@ class ShapeVerifier : public DfsHloVisitor { Status HandleFusion(HloInstruction*) override; Status HandleCall(HloInstruction* call) override; Status HandleCustomCall(HloInstruction*) override; + Status HandleHostCompute(HloInstruction*) override; Status HandleSlice(HloInstruction* slice) override; Status HandleDynamicSlice(HloInstruction* dynamic_slice) override; Status HandleDynamicUpdateSlice( diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc index 90e1f0acdc..f08d809d79 100644 --- 
a/tensorflow/compiler/xla/service/instruction_fusion.cc +++ b/tensorflow/compiler/xla/service/instruction_fusion.cc @@ -102,6 +102,7 @@ namespace xla { case HloOpcode::kExp: case HloOpcode::kFft: case HloOpcode::kFusion: + case HloOpcode::kHostCompute: case HloOpcode::kLog: case HloOpcode::kMap: case HloOpcode::kParameter: diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc index 98dfc89867..95c853b5c4 100644 --- a/tensorflow/compiler/xla/service/service.cc +++ b/tensorflow/compiler/xla/service/service.cc @@ -44,6 +44,7 @@ limitations under the License. #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/util.h" +#include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/cleanup.h" #include "tensorflow/core/lib/strings/strcat.h" #include "tensorflow/core/lib/strings/stringprintf.h" @@ -1456,6 +1457,10 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) { handle_status = computation->AddOutfeedInstruction(arg->outfeed_request()); break; + case OpRequest::kHostComputeRequest: + handle_status = + computation->AddHostComputeInstruction(arg->host_compute_request()); + break; case OpRequest::kMapRequest: { TF_ASSIGN_OR_RETURN( UserComputation * to_apply, diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc index fead9b9236..d42cb6cdf3 100644 --- a/tensorflow/compiler/xla/service/user_computation.cc +++ b/tensorflow/compiler/xla/service/user_computation.cc @@ -1276,6 +1276,28 @@ StatusOr<ComputationDataHandle> UserComputation::AddCustomCallInstruction( return handle; } +StatusOr<ComputationDataHandle> UserComputation::AddHostComputeInstruction( + const HostComputeRequest& host_compute_request) { + tensorflow::mutex_lock lock(mutex_); + + for (const ComputationDataHandle& handle : host_compute_request.operands()) { + 
TF_RETURN_IF_ERROR(LookUpRequest(handle).status()); + } + + ComputationDataHandle handle = CreateComputationDataHandle(); + OperationRequest& request = + (*session_computation_.mutable_requests())[handle.handle()]; + *request.mutable_output_handle() = handle; + *request.mutable_output_shape() = host_compute_request.shape(); + *request.mutable_request()->mutable_host_compute_request() = + host_compute_request; + + VLOG(1) << "AddHostComputeInstruction (" << GetVersionedHandleInternal() + << "), data handle " << handle.handle() << ": " + << host_compute_request.ShortDebugString(); + return handle; +} + StatusOr<ComputationDataHandle> UserComputation::AddDotInstruction( const DotRequest& dot_request) { tensorflow::mutex_lock lock(mutex_); @@ -1713,6 +1735,11 @@ void PureFunctionalVisitor(const SessionComputation& session_computation, break; } + case OpRequest::kHostComputeRequest: { + *is_functional = false; + break; + } + case OpRequest::kCallRequest: { const CallRequest& call_request = request.request().call_request(); for (const ComputationDataHandle& handle : call_request.operands()) { @@ -2643,6 +2670,15 @@ static void ForEachOperand( break; } + case OpRequest::kHostComputeRequest: { + const HostComputeRequest& hc_request = + request.request().host_compute_request(); + for (const ComputationDataHandle& operand : hc_request.operands()) { + apply(operand); + } + break; + } + case OpRequest::kDotRequest: { const DotRequest& dot_request = request.request().dot_request(); apply(dot_request.rhs()); @@ -3299,6 +3335,22 @@ void ComputationLowerer::Visit( break; } + case OpRequest::kHostComputeRequest: { + const HostComputeRequest& host_compute_request = + request.request().host_compute_request(); + std::vector<HloInstruction*> operands; + for (const ComputationDataHandle& operand : + host_compute_request.operands()) { + operands.push_back(lookup_instruction(operand)); + } + auto output_shape = host_compute_request.shape(); + auto channel_name = 
host_compute_request.channel_name(); + auto cost_estimate_ns = host_compute_request.cost_estimate_ns(); + hlo_instruction = add_instruction(HloInstruction::CreateHostCompute( + output_shape, operands, channel_name, cost_estimate_ns)); + break; + } + case OpRequest::kUnaryOpRequest: { const UnaryOpRequest& unary_op_request = request.request().unary_op_request(); diff --git a/tensorflow/compiler/xla/service/user_computation.h b/tensorflow/compiler/xla/service/user_computation.h index 54bb24d6d7..81a72583f7 100644 --- a/tensorflow/compiler/xla/service/user_computation.h +++ b/tensorflow/compiler/xla/service/user_computation.h @@ -149,6 +149,10 @@ class UserComputation { StatusOr<ComputationDataHandle> AddOutfeedInstruction( const OutfeedRequest& outfeed_request); + // Enqueues a host compute instruction onto this user computation. + StatusOr<ComputationDataHandle> AddHostComputeInstruction( + const HostComputeRequest& host_compute_request); + // Enqueues a call instruction onto this user computation. 
StatusOr<ComputationDataHandle> AddCallInstruction( const CallRequest& call_request, diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc index 89def5d561..5dd5780835 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc @@ -994,6 +994,20 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, shape, operands, *custom_call_target)); break; } + case HloOpcode::kHostCompute: { + optional<string> channel_name; + optional<int64> cost_estimate_ns; + attrs["channel_name"] = {/*required=*/true, AttrTy::kString, + &channel_name}; + attrs["cost_estimate_ns"] = {/*required=*/true, AttrTy::kInt64, + &cost_estimate_ns}; + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction(HloInstruction::CreateHostCompute( + shape, operands, *channel_name, *cost_estimate_ns)); + break; + } case HloOpcode::kDot: { optional<std::vector<int64>> lhs_contracting_dims; attrs["lhs_contracting_dims"] = { diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto index 3aea021753..4fa5d28211 100644 --- a/tensorflow/compiler/xla/xla_data.proto +++ b/tensorflow/compiler/xla/xla_data.proto @@ -519,6 +519,20 @@ message CustomCallRequest { Shape shape = 4; } +message HostComputeRequest { + // Operand to the HostCompute. Supports tuple. + repeated ComputationDataHandle operands = 1; + + // Name used to identify HostSend/Recv channels. + string channel_name = 2; + + // Cost estimate in nanoseconds. + int64 cost_estimate_ns = 3; + + // The shape of any data returned by host. + Shape shape = 4; +} + message DotDimensionNumbers { // The dimension numbers that represent the 'lhs' contracting dimensions. 
repeated int64 lhs_contracting_dimensions = 1; @@ -957,7 +971,8 @@ message OpRequest { FftRequest fft_request = 41; ConvertRequest bitcast_convert_request = 42; ConditionalRequest conditional_request = 44; - // Next: 45 + HostComputeRequest host_compute_request = 45; + // Next: 46 } }