diff options
author | Mark Heffernan <meheff@google.com> | 2018-09-27 10:53:36 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-09-27 10:56:27 -0700 |
commit | 6d41787c32483b28f8c93973f28d4d078ea0b37e (patch) | |
tree | 1b310e402a71a8b79b24f33080b034b75c4df32b /tensorflow/compiler | |
parent | 334244be6864dd1dbec9bc8bb4996cc286a8e3e3 (diff) |
Add opaque field to custom call.
The intent of this field is to enable more information to be encoded in the custom call and passed through to the backend.
PiperOrigin-RevId: 214800539
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r-- | tensorflow/compiler/xla/client/xla_builder.cc | 8 | ||||
-rw-r--r-- | tensorflow/compiler/xla/client/xla_builder.h | 24 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo.proto | 8 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.cc | 9 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.h | 5 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instructions.cc | 14 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instructions.h | 8 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_parser.cc | 7 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_parser_test.cc | 12 |
9 files changed, 67 insertions, 28 deletions
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 95ff6432a5..5277de6a85 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -1278,7 +1278,7 @@ XlaOp XlaBuilder::AfterAll(absl::Span<const XlaOp> tokens) { XlaOp XlaBuilder::CustomCall(const string& call_target_name, absl::Span<const XlaOp> operands, - const Shape& shape) { + const Shape& shape, const string& opaque) { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; if (absl::StartsWith(call_target_name, "$")) { @@ -1289,6 +1289,7 @@ XlaOp XlaBuilder::CustomCall(const string& call_target_name, } *instr.mutable_shape() = shape; instr.set_custom_call_target(call_target_name); + instr.set_custom_call_opaque(opaque); return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands); }); } @@ -2681,8 +2682,9 @@ XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, } XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name, - absl::Span<const XlaOp> operands, const Shape& shape) { - return builder->CustomCall(call_target_name, operands, shape); + absl::Span<const XlaOp> operands, const Shape& shape, + const string& opaque) { + return builder->CustomCall(call_target_name, operands, shape, opaque); } XlaOp Complex(const XlaOp& real, const XlaOp& imag, diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h index d0c59fa6f2..1da6ddd318 100644 --- a/tensorflow/compiler/xla/client/xla_builder.h +++ b/tensorflow/compiler/xla/client/xla_builder.h @@ -577,11 +577,9 @@ class XlaBuilder { absl::Span<const XlaOp> operands); // Enqueues a custom call instruction onto the computation. - // During code generation, a call instruction is emitted which targets a - // symbol with the name |call_target_name|. The |operands| are passed to the - // call instruction. |shape| is the resultant shape. XlaOp CustomCall(const string& call_target_name, - absl::Span<const XlaOp> operands, const Shape& shape); + absl::Span<const XlaOp> operands, const Shape& shape, + const string& opaque); // The following methods enqueue element-wise binary arithmetic operations // onto the computation. The shapes of the operands have to match unless one @@ -1195,7 +1193,8 @@ class XlaBuilder { friend XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, absl::Span<const XlaOp> operands); friend XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name, - absl::Span<const XlaOp> operands, const Shape& shape); + absl::Span<const XlaOp> operands, const Shape& shape, + const string& opaque); friend XlaOp Complex(const XlaOp& real, const XlaOp& imag, absl::Span<const int64> broadcast_dimensions); friend XlaOp Conj(const XlaOp& operand); @@ -1717,12 +1716,17 @@ XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token, XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, absl::Span<const XlaOp> operands); -// Enqueues a custom call instruction onto the computation. -// During code generation, a call instruction is emitted which targets a -// symbol with the name |call_target_name|. The |operands| are passed to the -// call instruction. |shape| is the resultant shape. +// Enqueues a custom call instruction onto the computation. A custom call +// invokes code external to XLA. The |operands| are passed to the external code, +// and the external code is expected to produce a result of the given +// |shape|. The exact mechanism is backend-specific. For example, in the CPU +// backend, a call instruction is emitted which targets a symbol with the name +// |call_target_name|. |call_target_name| and |opaque| can arbitrary strings, +// but |call_target_name| should be short as it may be used in labels. |opaque| +// can encode arbitrarily large amounts of information. XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name, - absl::Span<const XlaOp> operands, const Shape& shape); + absl::Span<const XlaOp> operands, const Shape& shape, + const string& opaque = ""); // The following methods enqueue element-wise binary arithmetic operations // onto the computation. The shapes of the operands have to match unless one diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index b19ec12638..caaca16f71 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -34,7 +34,7 @@ import "tensorflow/compiler/xla/xla_data.proto"; option cc_enable_arenas = true; // Serialization of HloInstruction. -// Next ID: 53 +// Next ID: 54 message HloInstructionProto { reserved 10; reserved "parameter_name"; @@ -124,9 +124,13 @@ message HloInstructionProto { // The string representation of the infeed configuration. bytes infeed_config = 27; - // Name of a global symbol to call, only present for kCustomCall. + // Name of a external target (eg, global symbol) to call, only present for + // kCustomCall. string custom_call_target = 28; + // Opaque string, only present for kCustomCall. + string custom_call_opaque = 53; + // Shape of outfeed request. xla.Shape outfeed_shape = 29; diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index f7ec854d80..23787dbc8a 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -379,7 +379,8 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto( break; case HloOpcode::kCustomCall: instruction = CreateCustomCall(proto.shape(), all_operands(), - proto.custom_call_target()); + proto.custom_call_target(), + proto.custom_call_opaque()); if (proto.has_window()) { static_cast<HloCustomCallInstruction*>(instruction.get()) ->set_window(proto.window()); @@ -1108,9 +1109,9 @@ bool HloInstruction::HasSideEffect() const { /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall( const Shape& shape, absl::Span<HloInstruction* const> operands, - absl::string_view custom_call_target) { - return absl::make_unique<HloCustomCallInstruction>(shape, operands, - custom_call_target); + absl::string_view custom_call_target, absl::string_view opaque) { + return absl::make_unique<HloCustomCallInstruction>( + shape, operands, custom_call_target, opaque); } /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTuple( diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index d615df0831..009bd3bab3 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -718,10 +718,11 @@ class HloInstruction { HloComputation* computation); // Creates a custom call instruction that applies the given custom call target - // to the given operands. "shape" is the resultant shape. + // to the given operands. "opaque" can be an arbitrary string with a + // backend-specific interpretation. "shape" is the resultant shape. static std::unique_ptr<HloInstruction> CreateCustomCall( const Shape& shape, absl::Span<HloInstruction* const> operands, - absl::string_view custom_call_target); + absl::string_view custom_call_target, absl::string_view opaque = ""); // Creates a tuple instruction with the given elements. This is a convenience // wrapper around CreateVariadic. diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index e92882c22a..cd71bc3323 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -1830,9 +1830,10 @@ HloSelectAndScatterInstruction::CloneWithNewOperandsImpl( HloCustomCallInstruction::HloCustomCallInstruction( const Shape& shape, absl::Span<HloInstruction* const> operands, - absl::string_view custom_call_target) + absl::string_view custom_call_target, absl::string_view opaque) : HloInstruction(HloOpcode::kCustomCall, shape), custom_call_target_(custom_call_target.begin(), custom_call_target.end()), + opaque_(opaque.begin(), opaque.end()), feature_group_count_(1) { for (auto operand : operands) { AppendOperand(operand); @@ -1849,6 +1850,7 @@ HloInstructionProto HloCustomCallInstruction::ToProto() const { *convolution_dimension_numbers_; } proto.set_custom_call_target(custom_call_target_); + proto.set_custom_call_opaque(opaque_); proto.set_feature_group_count(feature_group_count_); return proto; } @@ -1872,6 +1874,11 @@ std::vector<string> HloCustomCallInstruction::ExtraAttributesToStringImpl( // an HloComputation. extra.push_back( StrCat("custom_call_target=\"", CEscape(custom_call_target_), "\"")); + // If the opaque string becomes enormous we may want to reconsider printing + // this inline and consider other options. + if (!opaque_.empty()) { + extra.push_back(StrCat("opaque=\"", CEscape(opaque_), "\"")); + } return extra; } @@ -1897,7 +1904,8 @@ bool HloCustomCallInstruction::IdenticalSlowPath( if (feature_group_count_ != casted_other.feature_group_count_) { return false; } - return custom_call_target_ == casted_other.custom_call_target_; + return custom_call_target_ == casted_other.custom_call_target_ && + opaque_ == casted_other.opaque_; } std::unique_ptr<HloInstruction> @@ -1905,7 +1913,7 @@ HloCustomCallInstruction::CloneWithNewOperandsImpl( const Shape& shape, absl::Span<HloInstruction* const> new_operands, HloCloneContext* context) const { auto cloned = absl::make_unique<HloCustomCallInstruction>( - shape, new_operands, custom_call_target()); + shape, new_operands, custom_call_target(), opaque()); if (window_ != nullptr) { cloned->set_window(*window_); } diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index 2d7bc83855..9c22f5db7e 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -1070,7 +1070,8 @@ class HloCustomCallInstruction : public HloInstruction { public: explicit HloCustomCallInstruction(const Shape& shape, absl::Span<HloInstruction* const> operands, - absl::string_view custom_call_target); + absl::string_view custom_call_target, + absl::string_view opaque); const Window& window() const override { CHECK(window_ != nullptr); return *window_; @@ -1090,6 +1091,7 @@ class HloCustomCallInstruction : public HloInstruction { convolution_dimension_numbers_ = absl::make_unique<ConvolutionDimensionNumbers>(dnums); } + const string& opaque() const { return opaque_; } const string& custom_call_target() const { return custom_call_target_; } void set_feature_group_count(int64 feature_group_count) { feature_group_count_ = feature_group_count; @@ -1109,8 +1111,10 @@ class HloCustomCallInstruction : public HloInstruction { std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl( const Shape& shape, absl::Span<HloInstruction* const> new_operands, HloCloneContext* context) const override; - // Name of a global symbol to call, only present for kCustomCall. + // Name of a global symbol to call. string custom_call_target_; + // Opaque string interpreted by the backend. + string opaque_; // Describes the window in a windowed operation such as convolution. std::unique_ptr<Window> window_; // Describes the dimension numbers used for a convolution. diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 37197b273b..25b70740e3 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -1266,11 +1266,13 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, } case HloOpcode::kCustomCall: { optional<string> custom_call_target; + optional<string> opaque; optional<Window> window; optional<ConvolutionDimensionNumbers> dnums; optional<int64> feature_group_count; attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString, &custom_call_target}; + attrs["opaque"] = {/*required=*/false, AttrTy::kString, &opaque}; attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window}; attrs["dim_labels"] = {/*required=*/false, AttrTy::kConvolutionDimensionNumbers, &dnums}; @@ -1279,8 +1281,9 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } - instruction = builder->AddInstruction(HloInstruction::CreateCustomCall( - shape, operands, *custom_call_target)); + instruction = builder->AddInstruction( + HloInstruction::CreateCustomCall(shape, operands, *custom_call_target, + opaque.has_value() ? *opaque : "")); if (window.has_value()) { instruction->set_window(*window); } diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index cca50fab54..96db96bdb9 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -1004,6 +1004,18 @@ ENTRY CustomCall { )" }, +// CustomCall with opaque value. +{ +"CustomCallWithOpaque", +R"(HloModule custom_call + +ENTRY CustomCall { + constant = f32[1]{0} constant({12345}) + ROOT custom-call = f32[1,2,3]{0,2,1} custom-call(constant), custom_call_target="foo\"bar", opaque="this string is opaque" +} + +)" +}, // Variables with non-default names { "NonDefaultNames", |