path: root/tensorflow/compiler/xla
author     Tong Shen <endlessroad@google.com>               2018-08-21 15:31:34 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-08-21 15:35:20 -0700
commit     b1b2cb38f2fcf089f8fc238d3c72cf9507887ed3 (patch)
tree       d93dab35a3831cfb6e9d5ea5804b58672562e874 /tensorflow/compiler/xla
parent     0f02f05913e03889bbcb85e71a6d005a8519bfb9 (diff)
Remove HostCompute HLO.
Now, for host compute, we just emit SendToHost & RecvFromHost pairs and use a token to enforce the dependency between them.

PiperOrigin-RevId: 209671416
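For readers following the change: below is a minimal client-level sketch of the replacement pattern, assuming the CreateToken / SendToHost / RecvFromHost builder API present in xla_builder.h at this revision. The helper name, channel handles, and shapes are illustrative, not part of this commit.

#include "tensorflow/compiler/xla/client/xla_builder.h"

// Sketch: where HostCompute(operands, channel_name, cost_estimate_ns, shape)
// was a single pseudo-op, the client now threads an explicit token through a
// SendToHost / RecvFromHost pair. Helper and parameters are hypothetical.
xla::XlaOp HostComputeViaSendRecv(xla::XlaBuilder* b, const xla::XlaOp& input,
                                  const xla::Shape& operand_shape_with_layout,
                                  const xla::Shape& result_shape,
                                  const xla::ChannelHandle& send_channel,
                                  const xla::ChannelHandle& recv_channel) {
  // Token that starts the host-transfer dependency chain.
  xla::XlaOp token = xla::CreateToken(b);
  // Device -> host transfer; yields a token ordered after the send.
  xla::XlaOp send_token =
      xla::SendToHost(input, token, operand_shape_with_layout, send_channel);
  // Host -> device transfer, ordered after the send by the token. At this
  // revision the result is the recv-done {data, token} tuple.
  xla::XlaOp recv = xla::RecvFromHost(send_token, result_shape, recv_channel);
  return xla::GetTupleElement(recv, 0);
}

The two channel handles are assumed to be created host-side ahead of time (the distinct send/recv channels stand in for the single channel_name the removed op carried).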
Diffstat (limited to 'tensorflow/compiler/xla')
-rw-r--r--  tensorflow/compiler/xla/client/xla_builder.cc                  | 20 -
-rw-r--r--  tensorflow/compiler/xla/client/xla_builder.h                   | 25 -
-rw-r--r--  tensorflow/compiler/xla/service/dfs_hlo_visitor.h              |  1 -
-rw-r--r--  tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h |  3 -
-rw-r--r--  tensorflow/compiler/xla/service/hlo_cost_analysis.cc           |  4 -
-rw-r--r--  tensorflow/compiler/xla/service/hlo_cost_analysis.h            |  1 -
-rw-r--r--  tensorflow/compiler/xla/service/hlo_graph_dumper.cc            |  1 -
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.cc             | 21 -
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.h              |  9 -
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instructions.cc            | 35 -
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instructions.h             | 27 -
-rw-r--r--  tensorflow/compiler/xla/service/hlo_opcode.h                   |  1 -
-rw-r--r--  tensorflow/compiler/xla/service/hlo_parser.cc                  | 14 -
-rw-r--r--  tensorflow/compiler/xla/service/hlo_verifier.cc                |  4 -
-rw-r--r--  tensorflow/compiler/xla/service/hlo_verifier.h                 |  1 -
-rw-r--r--  tensorflow/compiler/xla/service/instruction_fusion.cc          |  1 -
16 files changed, 0 insertions(+), 168 deletions(-)
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 54fe87a7a8..428ab9d23a 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -196,7 +196,6 @@ void XlaBuilder::IsConstantVisitor(const int64 op_handle,
// TODO(b/33009255): Implement constant folding for cross replica sum.
case HloOpcode::kInfeed:
case HloOpcode::kOutfeed:
- case HloOpcode::kHostCompute:
case HloOpcode::kCall:
// TODO(b/32495713): We aren't checking the to_apply computation itself,
// so we conservatively say that computations containing the Call op
@@ -1278,18 +1277,6 @@ XlaOp XlaBuilder::CustomCall(const string& call_target_name,
});
}
-XlaOp XlaBuilder::HostCompute(tensorflow::gtl::ArraySlice<XlaOp> operands,
- const string& channel_name,
- int64 cost_estimate_ns, const Shape& shape) {
- return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
- HloInstructionProto instr;
- *instr.mutable_shape() = shape;
- instr.set_channel_name(channel_name);
- instr.set_cost_estimate_ns(cost_estimate_ns);
- return AddInstruction(std::move(instr), HloOpcode::kHostCompute, operands);
- });
-}
-
XlaOp XlaBuilder::Complex(
const XlaOp& real, const XlaOp& imag,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
@@ -2643,13 +2630,6 @@ XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
return builder->CustomCall(call_target_name, operands, shape);
}
-XlaOp HostCompute(XlaBuilder* builder,
- tensorflow::gtl::ArraySlice<XlaOp> operands,
- const string& channel_name, int64 cost_estimate_ns,
- const Shape& shape) {
- return builder->HostCompute(operands, channel_name, cost_estimate_ns, shape);
-}
-
XlaOp Complex(const XlaOp& real, const XlaOp& imag,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
return real.builder()->Complex(real, imag, broadcast_dimensions);
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index 469d5048b2..313635ae63 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -586,16 +586,6 @@ class XlaBuilder {
tensorflow::gtl::ArraySlice<XlaOp> operands,
const Shape& shape);
- // Enqueues a pseudo-op to represent host-side computation data-dependencies.
- // During code generation, host send and receive operations will be generated
- // to transfer |operands| to the host and a single result of |shape| back to
- // the device. Host send/recv operations are emitted using |channel_name|.
- // Dataflow dependencies and the |cost_estimate_ns| field may be used in HLO
- // instruction scheduling.
- XlaOp HostCompute(tensorflow::gtl::ArraySlice<XlaOp> operands,
- const string& channel_name, int64 cost_estimate_ns,
- const Shape& shape);
-
// The following methods enqueue element-wise binary arithmetic operations
// onto the computation. The shapes of the operands have to match unless one
// of the operands is a scalar, or an explicit broadcast dimension is given
@@ -1201,10 +1191,6 @@ class XlaBuilder {
friend XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
tensorflow::gtl::ArraySlice<XlaOp> operands,
const Shape& shape);
- friend XlaOp HostCompute(XlaBuilder* builder,
- tensorflow::gtl::ArraySlice<XlaOp> operands,
- const string& channel_name, int64 cost_estimate_ns,
- const Shape& shape);
friend XlaOp Complex(const XlaOp& real, const XlaOp& imag,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
friend XlaOp Conj(const XlaOp& operand);
@@ -1737,17 +1723,6 @@ XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
tensorflow::gtl::ArraySlice<XlaOp> operands,
const Shape& shape);
-// Enqueues a pseudo-op to represent host-side computation data-dependencies.
-// During code generation, host send and receive operations will be generated
-// to transfer |operands| to the host and a single result of |shape| back to
-// the device. Host send/recv operations are emitted using |channel_name|.
-// Dataflow dependencies and the |cost_estimate_ns| field may be used in HLO
-// instruction scheduling.
-XlaOp HostCompute(XlaBuilder* builder,
- tensorflow::gtl::ArraySlice<XlaOp> operands,
- const string& channel_name, int64 cost_estimate_ns,
- const Shape& shape);
-
// The following methods enqueue element-wise binary arithmetic operations
// onto the computation. The shapes of the operands have to match unless one
// of the operands is a scalar, or an explicit broadcast dimension is given
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
index 86d57581f8..690b5df514 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
@@ -208,7 +208,6 @@ class DfsHloVisitorBase {
virtual Status HandleInfeed(HloInstructionPtr hlo) = 0;
virtual Status HandleOutfeed(HloInstructionPtr hlo) = 0;
- virtual Status HandleHostCompute(HloInstructionPtr hlo) = 0;
virtual Status HandleRng(HloInstructionPtr hlo) = 0;
virtual Status HandleReverse(HloInstructionPtr hlo) = 0;
virtual Status HandleSort(HloInstructionPtr hlo) = 0;
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
index 617a5a2eb4..20c6bafe7c 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h
@@ -106,9 +106,6 @@ class DfsHloVisitorWithDefaultBase
Status HandleOutfeed(HloInstructionPtr outfeed) override {
return DefaultAction(outfeed);
}
- Status HandleHostCompute(HloInstructionPtr host_compute) override {
- return DefaultAction(host_compute);
- }
Status HandleReverse(HloInstructionPtr reverse) override {
return DefaultAction(reverse);
}
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
index 1bbb0ff08e..3e68f59bd9 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
@@ -258,10 +258,6 @@ Status HloCostAnalysis::HandleOutfeed(const HloInstruction*) {
return Status::OK();
}
-Status HloCostAnalysis::HandleHostCompute(const HloInstruction*) {
- return Status::OK();
-}
-
Status HloCostAnalysis::HandleMap(const HloInstruction* map) {
// Compute properties of the mapped function.
TF_ASSIGN_OR_RETURN(const Properties sub_properties,
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.h b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
index 193a04bea0..1bf1c4a315 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.h
@@ -74,7 +74,6 @@ class HloCostAnalysis : public ConstDfsHloVisitor {
Status HandleAllToAll(const HloInstruction* hlo) override;
Status HandleInfeed(const HloInstruction* infeed) override;
Status HandleOutfeed(const HloInstruction* outfeed) override;
- Status HandleHostCompute(const HloInstruction* host_compute) override;
Status HandleRng(const HloInstruction* random) override;
Status HandleReverse(const HloInstruction* reverse) override;
Status HandleSort(const HloInstruction* sort) override;
diff --git a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
index 1efa6eb5bd..a4ea21c692 100644
--- a/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
+++ b/tensorflow/compiler/xla/service/hlo_graph_dumper.cc
@@ -1059,7 +1059,6 @@ ColorScheme HloDotDumper::GetInstructionColor(const HloInstruction* instr) {
case HloOpcode::kCall:
case HloOpcode::kConditional:
case HloOpcode::kCustomCall:
- case HloOpcode::kHostCompute:
case HloOpcode::kWhile:
return kDarkGreen;
case HloOpcode::kConstant:
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index a050459adf..cb2264d08d 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -363,11 +363,6 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
proto.convolution_dimension_numbers());
}
break;
- case HloOpcode::kHostCompute:
- instruction =
- CreateHostCompute(proto.shape(), all_operands(), proto.channel_name(),
- proto.cost_estimate_ns());
- break;
case HloOpcode::kPad:
TF_RET_CHECK(proto.operand_ids_size() == 2)
<< "Pad instruction should have 2 operands but sees "
@@ -1036,7 +1031,6 @@ bool HloInstruction::HasSideEffectNoRecurse() const {
case HloOpcode::kInfeed:
case HloOpcode::kOutfeed:
case HloOpcode::kTrace:
- case HloOpcode::kHostCompute:
return true;
case HloOpcode::kCrossReplicaSum:
return all_reduce_id().has_value();
@@ -1077,13 +1071,6 @@ bool HloInstruction::HasSideEffect() const {
custom_call_target);
}
-/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateHostCompute(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- tensorflow::StringPiece channel_name, const int64 cost_estimate_ns) {
- return absl::make_unique<HloHostComputeInstruction>(
- shape, operands, channel_name, cost_estimate_ns);
-}
-
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTuple(
tensorflow::gtl::ArraySlice<HloInstruction*> elements) {
std::vector<Shape> element_shapes;
@@ -1171,7 +1158,6 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
case HloOpcode::kCustomCall:
case HloOpcode::kReduceWindow:
case HloOpcode::kSelectAndScatter:
- case HloOpcode::kHostCompute:
case HloOpcode::kPad:
case HloOpcode::kDynamicSlice:
case HloOpcode::kSort:
@@ -1637,7 +1623,6 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kCustomCall:
case HloOpcode::kReduceWindow:
case HloOpcode::kSelectAndScatter:
- case HloOpcode::kHostCompute:
case HloOpcode::kPad:
case HloOpcode::kDynamicSlice:
case HloOpcode::kGather:
@@ -2348,8 +2333,6 @@ Status HloInstruction::Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor) {
return visitor->HandleInfeed(this);
case HloOpcode::kOutfeed:
return visitor->HandleOutfeed(this);
- case HloOpcode::kHostCompute:
- return visitor->HandleHostCompute(this);
case HloOpcode::kRng:
return visitor->HandleRng(this);
case HloOpcode::kWhile:
@@ -3223,10 +3206,6 @@ const string& HloInstruction::custom_call_target() const {
return Cast<HloCustomCallInstruction>(this)->custom_call_target();
}
-const string& HloInstruction::channel_name() const {
- return Cast<HloHostComputeInstruction>(this)->channel_name();
-}
-
const PaddingConfig& HloInstruction::padding_config() const {
return Cast<HloPadInstruction>(this)->padding_config();
}
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index a392252774..41bb40b7bd 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -709,12 +709,6 @@ class HloInstruction {
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
tensorflow::StringPiece custom_call_target);
- // Creates a HostCompute instruction, which records host-side control and
- // data dependencies for use in instruction scheduling.
- static std::unique_ptr<HloInstruction> CreateHostCompute(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- tensorflow::StringPiece channel_name, const int64 cost_estimate_ns);
-
// Creates a tuple instruction with the given elements. This is a convenience
// wrapper around CreateVariadic.
static std::unique_ptr<HloInstruction> CreateTuple(
@@ -1476,9 +1470,6 @@ class HloInstruction {
// Delegates to HloCustomCallInstruction::custom_call_target.
const string& custom_call_target() const;
- // Delegates to HloHostComputeInstruction::channel_name.
- const string& channel_name() const;
-
// Delegates to HloPadInstruction::padding_config.
const PaddingConfig& padding_config() const;
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index 79a5e7481d..e91cabbb72 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -1856,41 +1856,6 @@ HloCustomCallInstruction::CloneWithNewOperandsImpl(
return std::move(cloned);
}
-HloHostComputeInstruction::HloHostComputeInstruction(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- tensorflow::StringPiece channel_name, const int64 cost_estimate_ns)
- : HloInstruction(HloOpcode::kHostCompute, shape),
- channel_name_(channel_name.begin(), channel_name.end()),
- cost_estimate_ns_(cost_estimate_ns) {
- for (auto operand : operands) {
- AppendOperand(operand);
- }
-}
-
-HloInstructionProto HloHostComputeInstruction::ToProto() const {
- HloInstructionProto proto = HloInstruction::ToProto();
- proto.set_channel_name(channel_name_);
- proto.set_cost_estimate_ns(cost_estimate_ns_);
- return proto;
-}
-
-bool HloHostComputeInstruction::IdenticalSlowPath(
- const HloInstruction& other,
- const std::function<bool(const HloComputation*, const HloComputation*)>&
- eq_computations) const {
- // Not yet supported.
- return false;
-}
-
-std::unique_ptr<HloInstruction>
-HloHostComputeInstruction::CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
- HloCloneContext* context) const {
- return absl::make_unique<HloHostComputeInstruction>(
- shape, new_operands, channel_name_, cost_estimate_ns_);
-}
-
HloPadInstruction::HloPadInstruction(const Shape& shape,
HloInstruction* operand,
HloInstruction* padding_value,
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index 19b69c2171..1152fa83ed 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -1118,33 +1118,6 @@ class HloCustomCallInstruction : public HloInstruction {
std::unique_ptr<ConvolutionDimensionNumbers> convolution_dimension_numbers_;
};
-class HloHostComputeInstruction : public HloInstruction {
- public:
- explicit HloHostComputeInstruction(
- const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
- tensorflow::StringPiece channel_name, const int64 cost_estimate_ns);
- // Returns the channel name associated with the instruction. The name is
- // used to identify host Send/Recv operations.
- const string& channel_name() const { return channel_name_; }
- // Returns a serialized representation of this instruction.
- HloInstructionProto ToProto() const override;
-
- private:
- bool IdenticalSlowPath(
- const HloInstruction& other,
- const std::function<bool(const HloComputation*, const HloComputation*)>&
- eq_computations) const override;
- // Implementation for non-common logic of CloneWithNewOperands.
- std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
- const Shape& shape,
- tensorflow::gtl::ArraySlice<HloInstruction*> new_operands,
- HloCloneContext* context) const override;
- // Name to use for host send/recv channels.
- string channel_name_;
- // Estimate of the duration of a host computation in nanoseconds.
- int64 cost_estimate_ns_ = 0;
-};
-
class HloPadInstruction : public HloInstruction {
public:
explicit HloPadInstruction(const Shape& shape, HloInstruction* operand,
diff --git a/tensorflow/compiler/xla/service/hlo_opcode.h b/tensorflow/compiler/xla/service/hlo_opcode.h
index 0e0d96ab09..b8f2a21ff9 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode.h
+++ b/tensorflow/compiler/xla/service/hlo_opcode.h
@@ -85,7 +85,6 @@ namespace xla {
V(kAfterAll, "after-all", kHloOpcodeIsVariadic) \
V(kGetTupleElement, "get-tuple-element") \
V(kGt, "greater-than", kHloOpcodeIsComparison) \
- V(kHostCompute, "host-compute") \
V(kImag, "imag") \
V(kInfeed, "infeed") \
V(kIota, "iota") \
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index 3768da8a73..aafd0e4efd 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -1180,20 +1180,6 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
}
break;
}
- case HloOpcode::kHostCompute: {
- optional<string> channel_name;
- optional<tensorflow::int64> cost_estimate_ns;
- attrs["channel_name"] = {/*required=*/true, AttrTy::kString,
- &channel_name};
- attrs["cost_estimate_ns"] = {/*required=*/true, AttrTy::kInt64,
- &cost_estimate_ns};
- if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
- return false;
- }
- instruction = builder->AddInstruction(HloInstruction::CreateHostCompute(
- shape, operands, *channel_name, *cost_estimate_ns));
- break;
- }
case HloOpcode::kDot: {
optional<std::vector<tensorflow::int64>> lhs_contracting_dims;
attrs["lhs_contracting_dims"] = {
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index ac1a663633..7acf58e252 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -183,10 +183,6 @@ Status ShapeVerifier::HandleOutfeed(HloInstruction* instruction) {
return CheckShape(outfeed, ShapeUtil::MakeTokenShape());
}
-Status ShapeVerifier::HandleHostCompute(HloInstruction*) {
- return Status::OK();
-}
-
bool ShapeVerifier::HasCompatibleElementTypes(const Shape& shape_0,
const Shape& shape_1,
const Shape& result_shape) {
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.h b/tensorflow/compiler/xla/service/hlo_verifier.h
index 9e54b54b26..523bf4d70c 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.h
+++ b/tensorflow/compiler/xla/service/hlo_verifier.h
@@ -64,7 +64,6 @@ class ShapeVerifier : public DfsHloVisitor {
Status HandleFusion(HloInstruction*) override;
Status HandleCall(HloInstruction* call) override;
Status HandleCustomCall(HloInstruction*) override;
- Status HandleHostCompute(HloInstruction*) override;
Status HandleSlice(HloInstruction* slice) override;
Status HandleDynamicSlice(HloInstruction* dynamic_slice) override;
Status HandleDynamicUpdateSlice(
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index 2fd2214806..be59ce8281 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -131,7 +131,6 @@ bool IsAlwaysDuplicable(const HloInstruction& instruction) {
case HloOpcode::kFft:
case HloOpcode::kFusion:
case HloOpcode::kGather:
- case HloOpcode::kHostCompute:
case HloOpcode::kLog:
case HloOpcode::kLog1p:
case HloOpcode::kMap:
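With kHostCompute gone from the opcode table, parser, and verifier above, a pass that previously created a HostCompute instruction would instead build the send/recv pair directly. A hedged HLO-level sketch, assuming the CreateSend/CreateRecv factory signatures (with is_host_transfer) at this revision; the helper name, channel ids, and shapes are illustrative:

#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Sketch: the send-to-host / recv-from-host pair that stands in for a removed
// kHostCompute instruction. The token produced by each step is an operand of
// the next, which is what enforces the ordering kHostCompute used to imply.
xla::HloInstruction* EmitHostTransferPair(xla::HloComputation* computation,
                                          xla::HloInstruction* operand,
                                          const xla::Shape& result_shape,
                                          xla::int64 send_channel_id,
                                          xla::int64 recv_channel_id) {
  using xla::HloInstruction;
  auto* token = computation->AddInstruction(HloInstruction::CreateToken());
  // Device -> host: kSend / kSendDone flagged as host transfers.
  auto* send = computation->AddInstruction(HloInstruction::CreateSend(
      operand, token, send_channel_id, /*is_host_transfer=*/true));
  auto* send_done = computation->AddInstruction(
      HloInstruction::CreateSendDone(send, /*is_host_transfer=*/true));
  // Host -> device: kRecv / kRecvDone, threaded on the send-done token.
  auto* recv = computation->AddInstruction(HloInstruction::CreateRecv(
      result_shape, send_done, recv_channel_id, /*is_host_transfer=*/true));
  auto* recv_done = computation->AddInstruction(
      HloInstruction::CreateRecvDone(recv, /*is_host_transfer=*/true));
  // kRecvDone produces a {data, token} tuple; return the data element.
  return computation->AddInstruction(HloInstruction::CreateGetTupleElement(
      xla::ShapeUtil::GetTupleElementShape(recv_done->shape(), 0), recv_done,
      0));
}

Because the pair is plain kSend/kRecv, the existing side-effect and scheduling machinery applies unchanged, which is why the dedicated visitor handlers, verifier case, and cost-analysis case could all be deleted in this commit.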