 tensorflow/compiler/xla/client/computation_builder.cc          | 14
 tensorflow/compiler/xla/client/computation_builder.h           | 10
 tensorflow/compiler/xla/service/dfs_hlo_visitor.h              |  1
 tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h |  3
 tensorflow/compiler/xla/service/hlo_cost_analysis.cc           |  4
 tensorflow/compiler/xla/service/hlo_cost_analysis.h            |  1
 tensorflow/compiler/xla/service/hlo_graph_dumper.cc            |  1
 tensorflow/compiler/xla/service/hlo_instruction.cc             | 21
 tensorflow/compiler/xla/service/hlo_instruction.h              | 12
 tensorflow/compiler/xla/service/hlo_opcode.h                   |  1
 tensorflow/compiler/xla/service/hlo_verifier.cc                |  4
 tensorflow/compiler/xla/service/hlo_verifier.h                 |  1
 tensorflow/compiler/xla/service/instruction_fusion.cc          |  1
 tensorflow/compiler/xla/service/service.cc                     |  5
 tensorflow/compiler/xla/service/user_computation.cc            | 52
 tensorflow/compiler/xla/service/user_computation.h             |  4
 tensorflow/compiler/xla/tools/parser/hlo_parser.cc             | 14
 tensorflow/compiler/xla/xla_data.proto                         | 17
 18 files changed, 165 insertions(+), 1 deletion(-)
diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc
index b1dcad6a49..e6dfe0aefb 100644
--- a/tensorflow/compiler/xla/client/computation_builder.cc
+++ b/tensorflow/compiler/xla/client/computation_builder.cc
@@ -789,6 +789,20 @@ ComputationDataHandle ComputationBuilder::CustomCall(
return RunOpAndParseResponse(&op_request);
}
+ComputationDataHandle ComputationBuilder::HostCompute(
+ tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
+ const string& channel_name, int64 cost_estimate_ns, const Shape& shape) {
+ OpRequest op_request;
+ HostComputeRequest* request = op_request.mutable_host_compute_request();
+ for (const ComputationDataHandle& operand : operands) {
+ *request->add_operands() = operand;
+ }
+ *request->mutable_shape() = shape;
+ request->set_channel_name(channel_name);
+ request->set_cost_estimate_ns(cost_estimate_ns);
+ return RunOpAndParseResponse(&op_request);
+}
+
ComputationDataHandle ComputationBuilder::Complex(
const ComputationDataHandle& real, const ComputationDataHandle& imag,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h
index 7cae91e9e0..aa2622174d 100644
--- a/tensorflow/compiler/xla/client/computation_builder.h
+++ b/tensorflow/compiler/xla/client/computation_builder.h
@@ -446,6 +446,16 @@ class ComputationBuilder {
tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
const Shape& shape);
+ // Enqueues a pseudo-op to represent host-side computation data-dependencies.
+ // During code generation, host send and receive operations will be generated
+ // to transfer |operands| to the host and a single result of |shape| back to
+ // the device. Host send/recv operations are emitted using |channel_name|.
+ // Dataflow dependencies and the |cost_estimate_ns| field may be used in HLO
+ // instruction scheduling.
+ ComputationDataHandle HostCompute(
+ tensorflow::gtl::ArraySlice<ComputationDataHandle> operands,
+ const string& channel_name, int64 cost_estimate_ns, const Shape& shape);
+
// The following methods enqueue element-wise binary arithmetic operations
// onto the computation. The shapes of the operands have to match unless one
// of the operands is a scalar, or an explicit broadcast dimension is given
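[Editor's note: a minimal usage sketch of the new builder method. The builder, operand handle, result shape, and channel name below are illustrative assumptions, not part of this patch.]

  // Sketch only: assumes an existing ComputationBuilder `builder` and a
  // previously enqueued ComputationDataHandle `input`.
  xla::Shape result_shape = xla::ShapeUtil::MakeShape(xla::F32, {16});
  xla::ComputationDataHandle result = builder.HostCompute(
      /*operands=*/{input},
      /*channel_name=*/"my_host_channel",  // illustrative; must match the host side
      /*cost_estimate_ns=*/100000,         // scheduling hint: ~100us of host work
      /*shape=*/result_shape);

Per the comment above, the cost estimate only informs HLO instruction scheduling; it does not bound the host computation.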
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
index a803b3171f..5b09e4931e 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
@@ -190,6 +190,7 @@ class DfsHloVisitorBase {
virtual Status HandleInfeed(HloInstructionPtr hlo) = 0;
virtual Status HandleOutfeed(HloInstructionPtr hlo) = 0;
+ virtual Status HandleHostCompute(HloInstructionPtr hlo) = 0;
virtual Status HandleRng(HloInstructionPtr hlo) = 0;
virtual Status HandleReverse(HloInstructionPtr hlo) = 0;
virtual Status HandleSort(HloInstructionPtr hlo) = 0;
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
index 170adb3d24..ffc4f3bb79 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
@@ -103,6 +103,9 @@ class DfsHloVisitorWithDefaultBase
Status HandleOutfeed(HloInstructionPtr outfeed) override {
return DefaultAction(outfeed);
}
+ Status HandleHostCompute(HloInstructionPtr host_compute) override {
+ return DefaultAction(host_compute);
+ }
Status HandleReverse(HloInstructionPtr reverse) override {
return DefaultAction(reverse);
}
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
index 9cd5a1e2b7..6a4651d83f 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
@@ -229,6 +229,10 @@ Status HloCostAnalysis::HandleOutfeed(const HloInstruction*) {
return Status::OK();
}
+Status HloCostAnalysis::HandleHostCompute(const HloInstruction*) {
+ return Status::OK();
+}
+
Status HloCostAnalysis::HandleMap(const HloInstruction* map) {
// Compute properties of the mapped function.
TF_ASSIGN_OR_RETURN(const Properties sub_properties,
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
index e5783539e5..af52ea06ca 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
@@ -71,6 +71,7 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
Status HandleCrossReplicaSum(const HloInstruction* crs) override;
Status HandleInfeed(const HloInstruction* infeed) override;
Status HandleOutfeed(const HloInstruction* outfeed) override;
+ Status HandleHostCompute(const HloInstruction* host_compute) override;
Status HandleRng(const HloInstruction* random) override;
Status HandleReverse(const HloInstruction* reverse) override;
Status HandleSort(const HloInstruction* sort) override;
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index 44fcd36370..9b0e2fd7d6 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -988,6 +988,7 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
case HloOpcode::kCall:
case HloOpcode::kConditional:
case HloOpcode::kCustomCall:
+ case HloOpcode::kHostCompute:
case HloOpcode::kWhile:
return kDarkGreen;
case HloOpcode::kConstant:
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index d719ff857d..0d925ad00d 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -1094,6 +1094,7 @@ bool HloInstruction::HasSideEffect() const {
case HloOpcode::kInfeed:
case HloOpcode::kOutfeed:
case HloOpcode::kTrace:
+ case HloOpcode::kHostCompute:
return true;
default: {
// Check if any of the called computations has a side effect.
@@ -1131,6 +1132,19 @@ bool HloInstruction::HasSideEffect() const {
return instruction;
}
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateHostCompute(
+ const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ tensorflow::StringPiece channel_name, const int64 cost_estimate_ns) {
+ std::unique_ptr<HloInstruction> instruction =
+ WrapUnique(new HloInstruction(HloOpcode::kHostCompute, shape));
+ for (auto operand : operands) {
+ instruction->AppendOperand(operand);
+ }
+ instruction->channel_name_ = channel_name.ToString();
+ instruction->cost_estimate_ns_ = cost_estimate_ns;
+ return instruction;
+}
+
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTuple(
tensorflow::gtl::ArraySlice<HloInstruction*> elements) {
std::vector<Shape> element_shapes;
@@ -1222,6 +1236,10 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
case HloOpcode::kCustomCall:
clone = CreateCustomCall(shape, new_operands, custom_call_target_);
break;
+ case HloOpcode::kHostCompute:
+ clone = CreateHostCompute(shape, new_operands, channel_name_,
+ cost_estimate_ns_);
+ break;
case HloOpcode::kConcatenate:
clone = CreateConcatenate(shape, new_operands, dimensions(0));
break;
@@ -1792,6 +1810,7 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kRecvDone:
case HloOpcode::kSend:
case HloOpcode::kSendDone:
+ case HloOpcode::kHostCompute:
return false;
}
}
@@ -2577,6 +2596,8 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
return visitor->HandleInfeed(this);
case HloOpcode::kOutfeed:
return visitor->HandleOutfeed(this);
+ case HloOpcode::kHostCompute:
+ return visitor->HandleHostCompute(this);
case HloOpcode::kRng:
return visitor->HandleRng(this);
case HloOpcode::kWhile:
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 3cf43f0adf..e898a83739 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -475,6 +475,12 @@ class HloInstruction {
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
tensorflow::StringPiece custom_call_target);
+ // Creates a HostCompute instruction, which records host-side control and
+ // data dependencies for use in instruction scheduling.
+ static std::unique_ptr<HloInstruction> CreateHostCompute(
+ const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
+ tensorflow::StringPiece channel_name, const int64 cost_estimate_ns);
+
// Creates a tuple instruction with the given elements. This is a convenience
// wrapper around CreateVariadic.
static std::unique_ptr<HloInstruction> CreateTuple(
@@ -1398,6 +1404,12 @@ class HloInstruction {
// Name of a global symbol to call, only present for kCustomCall.
string custom_call_target_;
+ // Name to use for host send/recv channels, only present for kHostCompute.
+ string channel_name_;
+
+ // Estimate of the duration of a host computation in nanoseconds.
+ int64 cost_estimate_ns_;
+
// Computations called by this instruction.
std::vector<HloComputation*> called_computations_;
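[Editor's note: inside the compiler, the new factory would be used roughly as follows. This is a hypothetical sketch; `computation` and `operand` are assumed to already exist.]

  // Sketch only: `computation` is an HloComputation* and `operand` is an
  // HloInstruction* already owned by it.
  HloInstruction* host_compute =
      computation->AddInstruction(HloInstruction::CreateHostCompute(
          ShapeUtil::MakeShape(F32, {16}), {operand},
          /*channel_name=*/"my_host_channel", /*cost_estimate_ns=*/100000));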
diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h
index 3d64523a79..088dd15dbf 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode.h
+++ b/tensorflow/compiler/xla/service/hlo_opcode.h
@@ -79,6 +79,7 @@ namespace xla {
V(kGe, "greater-than-or-equal-to", kHloOpcodeIsComparison) \
V(kGetTupleElement, "get-tuple-element") \
V(kGt, "greater-than", kHloOpcodeIsComparison) \
+ V(kHostCompute, "host-compute") \
V(kImag, "imag") \
V(kInfeed, "infeed") \
V(kIsFinite, "is-finite") \
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index e2b3bb9d71..f3378309c2 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -125,6 +125,10 @@ Status ShapeVerifier::HandleOutfeed(HloInstruction* outfeed) {
return CheckShape(outfeed, ShapeUtil::MakeNil());
}
+Status ShapeVerifier::HandleHostCompute(HloInstruction*) {
+ return tensorflow::Status::OK();
+}
+
Status ShapeVerifier::HandleRng(HloInstruction*) {
return tensorflow::Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h
index 7eccf834bb..f9f898c236 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.h
+++ b/tensorflow/compiler/xla/service/hlo_verifier.h
@@ -60,6 +60,7 @@ class ShapeVerifier : public DfsHloVisitor {
Status HandleFusion(HloInstruction*) override;
Status HandleCall(HloInstruction* call) override;
Status HandleCustomCall(HloInstruction*) override;
+ Status HandleHostCompute(HloInstruction*) override;
Status HandleSlice(HloInstruction* slice) override;
Status HandleDynamicSlice(HloInstruction* dynamic_slice) override;
Status HandleDynamicUpdateSlice(
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index 90e1f0acdc..f08d809d79 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -102,6 +102,7 @@ namespace xla {
case HloOpcode::kExp:
case HloOpcode::kFft:
case HloOpcode::kFusion:
+ case HloOpcode::kHostCompute:
case HloOpcode::kLog:
case HloOpcode::kMap:
case HloOpcode::kParameter:
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index 98dfc89867..95c853b5c4 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -44,6 +44,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
@@ -1456,6 +1457,10 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) {
handle_status =
computation->AddOutfeedInstruction(arg->outfeed_request());
break;
+ case OpRequest::kHostComputeRequest:
+ handle_status =
+ computation->AddHostComputeInstruction(arg->host_compute_request());
+ break;
case OpRequest::kMapRequest: {
TF_ASSIGN_OR_RETURN(
UserComputation * to_apply,
diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc
index fead9b9236..d42cb6cdf3 100644
--- a/tensorflow/compiler/xla/service/user_computation.cc
+++ b/tensorflow/compiler/xla/service/user_computation.cc
@@ -1276,6 +1276,28 @@ StatusOr<ComputationDataHandle> UserComputation::AddCustomCallInstruction(
return handle;
}
+StatusOr<ComputationDataHandle> UserComputation::AddHostComputeInstruction(
+ const HostComputeRequest& host_compute_request) {
+ tensorflow::mutex_lock lock(mutex_);
+
+ for (const ComputationDataHandle& handle : host_compute_request.operands()) {
+ TF_RETURN_IF_ERROR(LookUpRequest(handle).status());
+ }
+
+ ComputationDataHandle handle = CreateComputationDataHandle();
+ OperationRequest& request =
+ (*session_computation_.mutable_requests())[handle.handle()];
+ *request.mutable_output_handle() = handle;
+ *request.mutable_output_shape() = host_compute_request.shape();
+ *request.mutable_request()->mutable_host_compute_request() =
+ host_compute_request;
+
+ VLOG(1) << "AddHostComputeInstruction (" << GetVersionedHandleInternal()
+ << "), data handle " << handle.handle() << ": "
+ << host_compute_request.ShortDebugString();
+ return handle;
+}
+
StatusOr<ComputationDataHandle> UserComputation::AddDotInstruction(
const DotRequest& dot_request) {
tensorflow::mutex_lock lock(mutex_);
@@ -1713,6 +1735,11 @@ void PureFunctionalVisitor(const SessionComputation& session_computation,
break;
}
+ case OpRequest::kHostComputeRequest: {
+ *is_functional = false;
+ break;
+ }
+
case OpRequest::kCallRequest: {
const CallRequest& call_request = request.request().call_request();
for (const ComputationDataHandle& handle : call_request.operands()) {
@@ -2643,6 +2670,15 @@ static void ForEachOperand(
break;
}
+ case OpRequest::kHostComputeRequest: {
+ const HostComputeRequest& hc_request =
+ request.request().host_compute_request();
+ for (const ComputationDataHandle& operand : hc_request.operands()) {
+ apply(operand);
+ }
+ break;
+ }
+
case OpRequest::kDotRequest: {
const DotRequest& dot_request = request.request().dot_request();
apply(dot_request.rhs());
@@ -3299,6 +3335,22 @@ void ComputationLowerer::Visit(
break;
}
+ case OpRequest::kHostComputeRequest: {
+ const HostComputeRequest& host_compute_request =
+ request.request().host_compute_request();
+ std::vector<HloInstruction*> operands;
+ for (const ComputationDataHandle& operand :
+ host_compute_request.operands()) {
+ operands.push_back(lookup_instruction(operand));
+ }
+ auto output_shape = host_compute_request.shape();
+ auto channel_name = host_compute_request.channel_name();
+ auto cost_estimate_ns = host_compute_request.cost_estimate_ns();
+ hlo_instruction = add_instruction(HloInstruction::CreateHostCompute(
+ output_shape, operands, channel_name, cost_estimate_ns));
+ break;
+ }
+
case OpRequest::kUnaryOpRequest: {
const UnaryOpRequest& unary_op_request =
request.request().unary_op_request();
diff --git a/tensorflow/compiler/xla/service/user_computation.h b/tensorflow/compiler/xla/service/user_computation.h
index 54bb24d6d7..81a72583f7 100644
--- a/tensorflow/compiler/xla/service/user_computation.h
+++ b/tensorflow/compiler/xla/service/user_computation.h
@@ -149,6 +149,10 @@ class UserComputation {
StatusOr<ComputationDataHandle> AddOutfeedInstruction(
const OutfeedRequest& outfeed_request);
+ // Enqueues a host compute instruction onto this user computation.
+ StatusOr<ComputationDataHandle> AddHostComputeInstruction(
+ const HostComputeRequest& host_compute_request);
+
// Enqueues a call instruction onto this user computation.
StatusOr<ComputationDataHandle> AddCallInstruction(
const CallRequest& call_request,
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
index 89def5d561..5dd5780835 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
@@ -994,6 +994,20 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
shape, operands, *custom_call_target));
break;
}
+ case HloOpcode::kHostCompute: {
+ optional<string> channel_name;
+ optional<int64> cost_estimate_ns;
+ attrs["channel_name"] = {/*required=*/true, AttrTy::kString,
+ &channel_name};
+ attrs["cost_estimate_ns"] = {/*required=*/true, AttrTy::kInt64,
+ &cost_estimate_ns};
+ if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
+ return false;
+ }
+ instruction = builder->AddInstruction(HloInstruction::CreateHostCompute(
+ shape, operands, *channel_name, *cost_estimate_ns));
+ break;
+ }
case HloOpcode::kDot: {
optional<std::vector<int64>> lhs_contracting_dims;
attrs["lhs_contracting_dims"] = {
diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto
index 3aea021753..4fa5d28211 100644
--- a/tensorflow/compiler/xla/xla_data.proto
+++ b/tensorflow/compiler/xla/xla_data.proto
@@ -519,6 +519,20 @@ message CustomCallRequest {
Shape shape = 4;
}
+message HostComputeRequest {
+ // Operands to the HostCompute. Tuples are supported.
+ repeated ComputationDataHandle operands = 1;
+
+ // Name used to identify HostSend/Recv channels.
+ string channel_name = 2;
+
+ // Cost estimate in nanoseconds.
+ int64 cost_estimate_ns = 3;
+
+ // The shape of any data returned by the host.
+ Shape shape = 4;
+}
+
message DotDimensionNumbers {
// The dimension numbers that represent the 'lhs' contracting dimensions.
repeated int64 lhs_contracting_dimensions = 1;
@@ -957,7 +971,8 @@ message OpRequest {
FftRequest fft_request = 41;
ConvertRequest bitcast_convert_request = 42;
ConditionalRequest conditional_request = 44;
- // Next: 45
+ HostComputeRequest host_compute_request = 45;
+ // Next: 46
}
}