aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler
diff options
context:
space:
mode:
authorGravatar Mark Heffernan <meheff@google.com>2018-09-27 10:53:36 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-27 10:56:27 -0700
commit6d41787c32483b28f8c93973f28d4d078ea0b37e (patch)
tree1b310e402a71a8b79b24f33080b034b75c4df32b /tensorflow/compiler
parent334244be6864dd1dbec9bc8bb4996cc286a8e3e3 (diff)
Add opaque field to custom call.
The intent of this field is to enable more information to be encoded in the custom call and passed through to the backend. PiperOrigin-RevId: 214800539
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r--tensorflow/compiler/xla/client/xla_builder.cc8
-rw-r--r--tensorflow/compiler/xla/client/xla_builder.h24
-rw-r--r--tensorflow/compiler/xla/service/hlo.proto8
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc9
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.h5
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.cc14
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.h8
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser.cc7
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser_test.cc12
9 files changed, 67 insertions, 28 deletions
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 95ff6432a5..5277de6a85 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -1278,7 +1278,7 @@ XlaOp XlaBuilder::AfterAll(absl::Span<const XlaOp> tokens) {
XlaOp XlaBuilder::CustomCall(const string& call_target_name,
absl::Span<const XlaOp> operands,
- const Shape& shape) {
+ const Shape& shape, const string& opaque) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
if (absl::StartsWith(call_target_name, "$")) {
@@ -1289,6 +1289,7 @@ XlaOp XlaBuilder::CustomCall(const string& call_target_name,
}
*instr.mutable_shape() = shape;
instr.set_custom_call_target(call_target_name);
+ instr.set_custom_call_opaque(opaque);
return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands);
});
}
@@ -2681,8 +2682,9 @@ XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
}
XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
- absl::Span<const XlaOp> operands, const Shape& shape) {
- return builder->CustomCall(call_target_name, operands, shape);
+ absl::Span<const XlaOp> operands, const Shape& shape,
+ const string& opaque) {
+ return builder->CustomCall(call_target_name, operands, shape, opaque);
}
XlaOp Complex(const XlaOp& real, const XlaOp& imag,
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index d0c59fa6f2..1da6ddd318 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -577,11 +577,9 @@ class XlaBuilder {
absl::Span<const XlaOp> operands);
// Enqueues a custom call instruction onto the computation.
- // During code generation, a call instruction is emitted which targets a
- // symbol with the name |call_target_name|. The |operands| are passed to the
- // call instruction. |shape| is the resultant shape.
XlaOp CustomCall(const string& call_target_name,
- absl::Span<const XlaOp> operands, const Shape& shape);
+ absl::Span<const XlaOp> operands, const Shape& shape,
+ const string& opaque);
// The following methods enqueue element-wise binary arithmetic operations
// onto the computation. The shapes of the operands have to match unless one
@@ -1195,7 +1193,8 @@ class XlaBuilder {
friend XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
absl::Span<const XlaOp> operands);
friend XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
- absl::Span<const XlaOp> operands, const Shape& shape);
+ absl::Span<const XlaOp> operands, const Shape& shape,
+ const string& opaque);
friend XlaOp Complex(const XlaOp& real, const XlaOp& imag,
absl::Span<const int64> broadcast_dimensions);
friend XlaOp Conj(const XlaOp& operand);
@@ -1717,12 +1716,17 @@ XlaOp OutfeedWithToken(const XlaOp& operand, const XlaOp& token,
XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
absl::Span<const XlaOp> operands);
-// Enqueues a custom call instruction onto the computation.
-// During code generation, a call instruction is emitted which targets a
-// symbol with the name |call_target_name|. The |operands| are passed to the
-// call instruction. |shape| is the resultant shape.
+// Enqueues a custom call instruction onto the computation. A custom call
+// invokes code external to XLA. The |operands| are passed to the external code,
+// and the external code is expected to produce a result of the given
+// |shape|. The exact mechanism is backend-specific. For example, in the CPU
+// backend, a call instruction is emitted which targets a symbol with the name
+// |call_target_name|. |call_target_name| and |opaque| can arbitrary strings,
+// but |call_target_name| should be short as it may be used in labels. |opaque|
+// can encode arbitrarily large amounts of information.
XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
- absl::Span<const XlaOp> operands, const Shape& shape);
+ absl::Span<const XlaOp> operands, const Shape& shape,
+ const string& opaque = "");
// The following methods enqueue element-wise binary arithmetic operations
// onto the computation. The shapes of the operands have to match unless one
diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto
index b19ec12638..caaca16f71 100644
--- a/tensorflow/compiler/xla/service/hlo.proto
+++ b/tensorflow/compiler/xla/service/hlo.proto
@@ -34,7 +34,7 @@ import "tensorflow/compiler/xla/xla_data.proto";
option cc_enable_arenas = true;
// Serialization of HloInstruction.
-// Next ID: 53
+// Next ID: 54
message HloInstructionProto {
reserved 10;
reserved "parameter_name";
@@ -124,9 +124,13 @@ message HloInstructionProto {
// The string representation of the infeed configuration.
bytes infeed_config = 27;
- // Name of a global symbol to call, only present for kCustomCall.
+ // Name of a external target (eg, global symbol) to call, only present for
+ // kCustomCall.
string custom_call_target = 28;
+ // Opaque string, only present for kCustomCall.
+ string custom_call_opaque = 53;
+
// Shape of outfeed request.
xla.Shape outfeed_shape = 29;
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index f7ec854d80..23787dbc8a 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -379,7 +379,8 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
break;
case HloOpcode::kCustomCall:
instruction = CreateCustomCall(proto.shape(), all_operands(),
- proto.custom_call_target());
+ proto.custom_call_target(),
+ proto.custom_call_opaque());
if (proto.has_window()) {
static_cast<HloCustomCallInstruction*>(instruction.get())
->set_window(proto.window());
@@ -1108,9 +1109,9 @@ bool HloInstruction::HasSideEffect() const {
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
const Shape& shape, absl::Span<HloInstruction* const> operands,
- absl::string_view custom_call_target) {
- return absl::make_unique<HloCustomCallInstruction>(shape, operands,
- custom_call_target);
+ absl::string_view custom_call_target, absl::string_view opaque) {
+ return absl::make_unique<HloCustomCallInstruction>(
+ shape, operands, custom_call_target, opaque);
}
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTuple(
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index d615df0831..009bd3bab3 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -718,10 +718,11 @@ class HloInstruction {
HloComputation* computation);
// Creates a custom call instruction that applies the given custom call target
- // to the given operands. "shape" is the resultant shape.
+ // to the given operands. "opaque" can be an arbitrary string with a
+ // backend-specific interpretation. "shape" is the resultant shape.
static std::unique_ptr<HloInstruction> CreateCustomCall(
const Shape& shape, absl::Span<HloInstruction* const> operands,
- absl::string_view custom_call_target);
+ absl::string_view custom_call_target, absl::string_view opaque = "");
// Creates a tuple instruction with the given elements. This is a convenience
// wrapper around CreateVariadic.
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index e92882c22a..cd71bc3323 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -1830,9 +1830,10 @@ HloSelectAndScatterInstruction::CloneWithNewOperandsImpl(
HloCustomCallInstruction::HloCustomCallInstruction(
const Shape& shape, absl::Span<HloInstruction* const> operands,
- absl::string_view custom_call_target)
+ absl::string_view custom_call_target, absl::string_view opaque)
: HloInstruction(HloOpcode::kCustomCall, shape),
custom_call_target_(custom_call_target.begin(), custom_call_target.end()),
+ opaque_(opaque.begin(), opaque.end()),
feature_group_count_(1) {
for (auto operand : operands) {
AppendOperand(operand);
@@ -1849,6 +1850,7 @@ HloInstructionProto HloCustomCallInstruction::ToProto() const {
*convolution_dimension_numbers_;
}
proto.set_custom_call_target(custom_call_target_);
+ proto.set_custom_call_opaque(opaque_);
proto.set_feature_group_count(feature_group_count_);
return proto;
}
@@ -1872,6 +1874,11 @@ std::vector<string> HloCustomCallInstruction::ExtraAttributesToStringImpl(
// an HloComputation.
extra.push_back(
StrCat("custom_call_target=\"", CEscape(custom_call_target_), "\""));
+ // If the opaque string becomes enormous we may want to reconsider printing
+ // this inline and consider other options.
+ if (!opaque_.empty()) {
+ extra.push_back(StrCat("opaque=\"", CEscape(opaque_), "\""));
+ }
return extra;
}
@@ -1897,7 +1904,8 @@ bool HloCustomCallInstruction::IdenticalSlowPath(
if (feature_group_count_ != casted_other.feature_group_count_) {
return false;
}
- return custom_call_target_ == casted_other.custom_call_target_;
+ return custom_call_target_ == casted_other.custom_call_target_ &&
+ opaque_ == casted_other.opaque_;
}
std::unique_ptr<HloInstruction>
@@ -1905,7 +1913,7 @@ HloCustomCallInstruction::CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
auto cloned = absl::make_unique<HloCustomCallInstruction>(
- shape, new_operands, custom_call_target());
+ shape, new_operands, custom_call_target(), opaque());
if (window_ != nullptr) {
cloned->set_window(*window_);
}
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index 2d7bc83855..9c22f5db7e 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -1070,7 +1070,8 @@ class HloCustomCallInstruction : public HloInstruction {
public:
explicit HloCustomCallInstruction(const Shape& shape,
absl::Span<HloInstruction* const> operands,
- absl::string_view custom_call_target);
+ absl::string_view custom_call_target,
+ absl::string_view opaque);
const Window& window() const override {
CHECK(window_ != nullptr);
return *window_;
@@ -1090,6 +1091,7 @@ class HloCustomCallInstruction : public HloInstruction {
convolution_dimension_numbers_ =
absl::make_unique<ConvolutionDimensionNumbers>(dnums);
}
+ const string& opaque() const { return opaque_; }
const string& custom_call_target() const { return custom_call_target_; }
void set_feature_group_count(int64 feature_group_count) {
feature_group_count_ = feature_group_count;
@@ -1109,8 +1111,10 @@ class HloCustomCallInstruction : public HloInstruction {
std::unique_ptr<HloInstruction> CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const override;
- // Name of a global symbol to call, only present for kCustomCall.
+ // Name of a global symbol to call.
string custom_call_target_;
+ // Opaque string interpreted by the backend.
+ string opaque_;
// Describes the window in a windowed operation such as convolution.
std::unique_ptr<Window> window_;
// Describes the dimension numbers used for a convolution.
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index 37197b273b..25b70740e3 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -1266,11 +1266,13 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
}
case HloOpcode::kCustomCall: {
optional<string> custom_call_target;
+ optional<string> opaque;
optional<Window> window;
optional<ConvolutionDimensionNumbers> dnums;
optional<int64> feature_group_count;
attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString,
&custom_call_target};
+ attrs["opaque"] = {/*required=*/false, AttrTy::kString, &opaque};
attrs["window"] = {/*required=*/false, AttrTy::kWindow, &window};
attrs["dim_labels"] = {/*required=*/false,
AttrTy::kConvolutionDimensionNumbers, &dnums};
@@ -1279,8 +1281,9 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
- instruction = builder->AddInstruction(HloInstruction::CreateCustomCall(
- shape, operands, *custom_call_target));
+ instruction = builder->AddInstruction(
+ HloInstruction::CreateCustomCall(shape, operands, *custom_call_target,
+ opaque.has_value() ? *opaque : ""));
if (window.has_value()) {
instruction->set_window(*window);
}
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index cca50fab54..96db96bdb9 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -1004,6 +1004,18 @@ ENTRY CustomCall {
)"
},
+// CustomCall with opaque value.
+{
+"CustomCallWithOpaque",
+R"(HloModule custom_call
+
+ENTRY CustomCall {
+ constant = f32[1]{0} constant({12345})
+ ROOT custom-call = f32[1,2,3]{0,2,1} custom-call(constant), custom_call_target="foo\"bar", opaque="this string is opaque"
+}
+
+)"
+},
// Variables with non-default names
{
"NonDefaultNames",