aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_instructions.cc
diff options
context:
space:
mode:
authorGravatar Mark Heffernan <meheff@google.com>2018-09-27 10:53:36 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-09-27 10:56:27 -0700
commit6d41787c32483b28f8c93973f28d4d078ea0b37e (patch)
tree1b310e402a71a8b79b24f33080b034b75c4df32b /tensorflow/compiler/xla/service/hlo_instructions.cc
parent334244be6864dd1dbec9bc8bb4996cc286a8e3e3 (diff)
Add opaque field to custom call.
The intent of this field is to enable more information to be encoded in the custom call and passed through to the backend. PiperOrigin-RevId: 214800539
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instructions.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.cc14
1 files changed, 11 insertions, 3 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index e92882c22a..cd71bc3323 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -1830,9 +1830,10 @@ HloSelectAndScatterInstruction::CloneWithNewOperandsImpl(
HloCustomCallInstruction::HloCustomCallInstruction(
const Shape& shape, absl::Span<HloInstruction* const> operands,
- absl::string_view custom_call_target)
+ absl::string_view custom_call_target, absl::string_view opaque)
: HloInstruction(HloOpcode::kCustomCall, shape),
custom_call_target_(custom_call_target.begin(), custom_call_target.end()),
+ opaque_(opaque.begin(), opaque.end()),
feature_group_count_(1) {
for (auto operand : operands) {
AppendOperand(operand);
@@ -1849,6 +1850,7 @@ HloInstructionProto HloCustomCallInstruction::ToProto() const {
*convolution_dimension_numbers_;
}
proto.set_custom_call_target(custom_call_target_);
+ proto.set_custom_call_opaque(opaque_);
proto.set_feature_group_count(feature_group_count_);
return proto;
}
@@ -1872,6 +1874,11 @@ std::vector<string> HloCustomCallInstruction::ExtraAttributesToStringImpl(
// an HloComputation.
extra.push_back(
StrCat("custom_call_target=\"", CEscape(custom_call_target_), "\""));
+ // If the opaque string becomes enormous we may want to reconsider printing
+ // this inline and consider other options.
+ if (!opaque_.empty()) {
+ extra.push_back(StrCat("opaque=\"", CEscape(opaque_), "\""));
+ }
return extra;
}
@@ -1897,7 +1904,8 @@ bool HloCustomCallInstruction::IdenticalSlowPath(
if (feature_group_count_ != casted_other.feature_group_count_) {
return false;
}
- return custom_call_target_ == casted_other.custom_call_target_;
+ return custom_call_target_ == casted_other.custom_call_target_ &&
+ opaque_ == casted_other.opaque_;
}
std::unique_ptr<HloInstruction>
@@ -1905,7 +1913,7 @@ HloCustomCallInstruction::CloneWithNewOperandsImpl(
const Shape& shape, absl::Span<HloInstruction* const> new_operands,
HloCloneContext* context) const {
auto cloned = absl::make_unique<HloCustomCallInstruction>(
- shape, new_operands, custom_call_target());
+ shape, new_operands, custom_call_target(), opaque());
if (window_ != nullptr) {
cloned->set_window(*window_);
}