aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_instructions.cc
diff options
context:
space:
mode:
authorGravatar Mark Heffernan <meheff@google.com>2018-10-08 14:26:43 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-08 14:34:02 -0700
commit396a8a4105edd409d0821c4d5d0b920b315ffb72 (patch)
tree428350d427ffb29470e284077a2734b273b7cc4d /tensorflow/compiler/xla/service/hlo_instructions.cc
parentbc5635dc3ac78007caee88fabd81d23ad945b637 (diff)
Add custom call with layout constraints.
Add a variant of CustomCall which specifies arbitrary layout constraints on the operands and result. The existing non-layout-constrained CustomCall is changed to have no layout preference and can now be assigned arbitrary layouts by layout assignment. PiperOrigin-RevId: 216249615
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instructions.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.cc33
1 files changed, 32 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index 152d8eacdb..2ec233eaec 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -1825,7 +1825,24 @@ HloCustomCallInstruction::HloCustomCallInstruction(
: HloInstruction(HloOpcode::kCustomCall, shape),
custom_call_target_(custom_call_target.begin(), custom_call_target.end()),
opaque_(opaque.begin(), opaque.end()),
- feature_group_count_(1) {
+ feature_group_count_(1),
+ layout_constrained_(false) {
+ for (auto operand : operands) {
+ AppendOperand(operand);
+ }
+}
+
+HloCustomCallInstruction::HloCustomCallInstruction(
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
+ absl::string_view custom_call_target, absl::string_view opaque,
+ absl::Span<const Shape> operand_shapes_with_layout)
+ : HloInstruction(HloOpcode::kCustomCall, shape),
+ custom_call_target_(custom_call_target.begin(), custom_call_target.end()),
+ opaque_(opaque.begin(), opaque.end()),
+ feature_group_count_(1),
+ layout_constrained_(true),
+ operand_shapes_with_layout_(operand_shapes_with_layout.begin(),
+ operand_shapes_with_layout.end()) {
for (auto operand : operands) {
AppendOperand(operand);
}
@@ -1843,6 +1860,12 @@ HloInstructionProto HloCustomCallInstruction::ToProto() const {
proto.set_custom_call_target(custom_call_target_);
proto.set_custom_call_opaque(opaque_);
proto.set_feature_group_count(feature_group_count_);
+ if (layout_constrained()) {
+ proto.set_constrain_layout(true);
+ for (const Shape& shape : operand_shapes_with_layout_) {
+ *proto.add_operand_shapes_with_layout() = shape;
+ }
+ }
return proto;
}
@@ -1870,6 +1893,14 @@ std::vector<string> HloCustomCallInstruction::ExtraAttributesToStringImpl(
if (!opaque_.empty()) {
extra.push_back(StrCat("opaque=\"", CEscape(opaque_), "\""));
}
+ if (layout_constrained()) {
+ std::vector<string> shape_strings;
+ for (const Shape& shape : operand_shapes_with_layout_) {
+ shape_strings.push_back(ShapeUtil::HumanStringWithLayout(shape));
+ }
+ extra.push_back(StrCat("operand_layout_constraints={",
+ StrJoin(shape_strings, ", "), "}"));
+ }
return extra;
}