aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/client/xla_builder.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/client/xla_builder.cc')
-rw-r--r--tensorflow/compiler/xla/client/xla_builder.cc43
1 files changed, 39 insertions, 4 deletions
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 6b31831010..e7cf9ae363 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -1279,9 +1279,10 @@ XlaOp XlaBuilder::AfterAll(absl::Span<const XlaOp> tokens) {
});
}
-XlaOp XlaBuilder::CustomCall(const string& call_target_name,
- absl::Span<const XlaOp> operands,
- const Shape& shape, const string& opaque) {
+XlaOp XlaBuilder::CustomCall(
+ const string& call_target_name, absl::Span<const XlaOp> operands,
+ const Shape& shape, const string& opaque,
+ absl::optional<absl::Span<const Shape>> operand_shapes_with_layout) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
if (absl::StartsWith(call_target_name, "$")) {
@@ -1293,6 +1294,31 @@ XlaOp XlaBuilder::CustomCall(const string& call_target_name,
*instr.mutable_shape() = shape;
instr.set_custom_call_target(call_target_name);
instr.set_custom_call_opaque(opaque);
+ if (operand_shapes_with_layout.has_value()) {
+ if (!LayoutUtil::HasLayout(shape)) {
+ return InvalidArgument(
+ "Result shape must have layout for custom call with constrained "
+ "layout.");
+ }
+ if (operands.size() != operand_shapes_with_layout->size()) {
+ return InvalidArgument(
+ "Must specify a shape with layout for each operand for custom call "
+ "with constrained layout; given %d shapes, expected %d",
+ operand_shapes_with_layout->size(), operands.size());
+ }
+ instr.set_constrain_layout(true);
+ int64 operand_num = 0;
+ for (const Shape& operand_shape : *operand_shapes_with_layout) {
+ if (!LayoutUtil::HasLayout(operand_shape)) {
+ return InvalidArgument(
+ "No layout specified for operand %d for custom call with "
+ "constrained layout.",
+ operand_num);
+ }
+ *instr.add_operand_shapes_with_layout() = operand_shape;
+ ++operand_num;
+ }
+ }
return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands);
});
}
@@ -2690,7 +2716,16 @@ XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
absl::Span<const XlaOp> operands, const Shape& shape,
const string& opaque) {
- return builder->CustomCall(call_target_name, operands, shape, opaque);
+ return builder->CustomCall(call_target_name, operands, shape, opaque,
+ /*operand_shapes_with_layout=*/absl::nullopt);
+}
+
+XlaOp CustomCallWithLayout(XlaBuilder* builder, const string& call_target_name,
+ absl::Span<const XlaOp> operands, const Shape& shape,
+ absl::Span<const Shape> operand_shapes_with_layout,
+ const string& opaque) {
+ return builder->CustomCall(call_target_name, operands, shape, opaque,
+ operand_shapes_with_layout);
}
XlaOp Complex(const XlaOp& real, const XlaOp& imag,