diff options
Diffstat (limited to 'tensorflow/compiler/xla/client/xla_builder.cc')
-rw-r--r-- | tensorflow/compiler/xla/client/xla_builder.cc | 43 |
1 files changed, 39 insertions, 4 deletions
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc index 6b31831010..e7cf9ae363 100644 --- a/tensorflow/compiler/xla/client/xla_builder.cc +++ b/tensorflow/compiler/xla/client/xla_builder.cc @@ -1279,9 +1279,10 @@ XlaOp XlaBuilder::AfterAll(absl::Span<const XlaOp> tokens) { }); } -XlaOp XlaBuilder::CustomCall(const string& call_target_name, - absl::Span<const XlaOp> operands, - const Shape& shape, const string& opaque) { +XlaOp XlaBuilder::CustomCall( + const string& call_target_name, absl::Span<const XlaOp> operands, + const Shape& shape, const string& opaque, + absl::optional<absl::Span<const Shape>> operand_shapes_with_layout) { return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { HloInstructionProto instr; if (absl::StartsWith(call_target_name, "$")) { @@ -1293,6 +1294,31 @@ XlaOp XlaBuilder::CustomCall(const string& call_target_name, *instr.mutable_shape() = shape; instr.set_custom_call_target(call_target_name); instr.set_custom_call_opaque(opaque); + if (operand_shapes_with_layout.has_value()) { + if (!LayoutUtil::HasLayout(shape)) { + return InvalidArgument( + "Result shape must have layout for custom call with constrained " + "layout."); + } + if (operands.size() != operand_shapes_with_layout->size()) { + return InvalidArgument( + "Must specify a shape with layout for each operand for custom call " + "with constrained layout; given %d shapes, expected %d", + operand_shapes_with_layout->size(), operands.size()); + } + instr.set_constrain_layout(true); + int64 operand_num = 0; + for (const Shape& operand_shape : *operand_shapes_with_layout) { + if (!LayoutUtil::HasLayout(operand_shape)) { + return InvalidArgument( + "No layout specified for operand %d for custom call with " + "constrained layout.", + operand_num); + } + *instr.add_operand_shapes_with_layout() = operand_shape; + ++operand_num; + } + } return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands); }); } @@ -2690,7 +2716,16 @@ XlaOp Call(XlaBuilder* builder, const XlaComputation& computation, XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name, absl::Span<const XlaOp> operands, const Shape& shape, const string& opaque) { - return builder->CustomCall(call_target_name, operands, shape, opaque); + return builder->CustomCall(call_target_name, operands, shape, opaque, + /*operand_shapes_with_layout=*/absl::nullopt); +} + +XlaOp CustomCallWithLayout(XlaBuilder* builder, const string& call_target_name, + absl::Span<const XlaOp> operands, const Shape& shape, + absl::Span<const Shape> operand_shapes_with_layout, + const string& opaque) { + return builder->CustomCall(call_target_name, operands, shape, opaque, + operand_shapes_with_layout); } XlaOp Complex(const XlaOp& real, const XlaOp& imag, |