author Mark Heffernan <meheff@google.com> 2018-10-08 14:26:43 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-10-08 14:34:02 -0700
commit 396a8a4105edd409d0821c4d5d0b920b315ffb72 (patch)
tree 428350d427ffb29470e284077a2734b273b7cc4d
parent bc5635dc3ac78007caee88fabd81d23ad945b637 (diff)
Add custom call with layout constraints.
Add a variant of CustomCall which specifies arbitrary layout constraints on the
operands and result. The existing non-layout-constrained CustomCall is changed
to have no layout preference and can now be assigned arbitrary layouts by
layout assignment.

PiperOrigin-RevId: 216249615
-rw-r--r-- tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc | 22
-rw-r--r-- tensorflow/compiler/xla/client/xla_builder.cc | 43
-rw-r--r-- tensorflow/compiler/xla/client/xla_builder.h | 22
-rw-r--r-- tensorflow/compiler/xla/layout_util.cc | 6
-rw-r--r-- tensorflow/compiler/xla/layout_util.h | 4
-rw-r--r-- tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc | 10
-rw-r--r-- tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h | 2
-rw-r--r-- tensorflow/compiler/xla/service/hlo.proto | 9
-rw-r--r-- tensorflow/compiler/xla/service/hlo_instruction.cc | 28
-rw-r--r-- tensorflow/compiler/xla/service/hlo_instruction.h | 10
-rw-r--r-- tensorflow/compiler/xla/service/hlo_instructions.cc | 33
-rw-r--r-- tensorflow/compiler/xla/service/hlo_instructions.h | 32
-rw-r--r-- tensorflow/compiler/xla/service/hlo_parser.cc | 101
-rw-r--r-- tensorflow/compiler/xla/service/hlo_parser_test.cc | 67
-rw-r--r-- tensorflow/compiler/xla/service/hlo_verifier.cc | 22
-rw-r--r-- tensorflow/compiler/xla/service/layout_assignment.cc | 108
-rw-r--r-- tensorflow/compiler/xla/service/layout_assignment.h | 13
-rw-r--r-- tensorflow/compiler/xla/service/layout_assignment_test.cc | 190
-rw-r--r-- tensorflow/compiler/xla/shape_util.cc | 2
-rw-r--r-- tensorflow/compiler/xla/tests/custom_call_test.cc | 50
20 files changed, 650 insertions, 124 deletions
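
For orientation, here is a minimal sketch of how the new client API might be
used. Only CustomCallWithLayout and the shape helpers come from the diff below;
the wrapper function, shapes, and the "my_custom_fn" target are illustrative
assumptions:

  #include "tensorflow/compiler/xla/client/xla_builder.h"
  #include "tensorflow/compiler/xla/shape_util.h"

  void BuildConstrainedCustomCall(xla::XlaBuilder* b) {
    // Hypothetical operand constrained to a column-major ({0,1}) layout.
    xla::Shape operand_shape =
        xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 2}, {0, 1});
    // Result constrained to a row-major ({1,0}) layout.
    xla::Shape result_shape =
        xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 2}, {1, 0});
    xla::XlaOp p = xla::Parameter(b, 0, operand_shape, "p");
    // "my_custom_fn" (a hypothetical target) sees its operand in the {0,1}
    // layout and must produce its result in the {1,0} layout; layout
    // assignment inserts copies around the call as needed.
    xla::CustomCallWithLayout(b, "my_custom_fn", {p}, result_shape,
                              {operand_shape});
  }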
diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
index 3d81ae9eb8..f210bfbd88 100644
--- a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
+++ b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
@@ -88,20 +88,30 @@ class ArgMaxCustomCallOp : public XlaOpKernel {
xla::ConstantLiteral(&b, xla::LiteralUtil::CreateR0<int32>(dim)));
}
- xla::Shape xla_shape =
- xla::ShapeUtil::MakeShape(xla::S64, output_shape.dim_sizes());
+ // The argmax function expects row-major layout.
+ xla::Shape xla_shape = xla::ShapeUtil::MakeShapeWithDescendingLayout(
+ xla::S64, output_shape.dim_sizes());
+ std::vector<xla::Shape> arg_shapes;
+ for (const xla::XlaOp& arg : args) {
+ auto shape_status = b.GetShape(arg);
+ OP_REQUIRES_OK(ctx, shape_status.status());
+ xla::Shape arg_shape = shape_status.ConsumeValueOrDie();
+ *arg_shape.mutable_layout() = xla::LayoutUtil::MakeDescendingLayout(
+ xla::ShapeUtil::Rank(arg_shape));
+ arg_shapes.push_back(std::move(arg_shape));
+ }
// Tell XLA to call the custom code, defined in
// index_ops_kernel_argmax_float_1d.cc.
xla::XlaOp output;
switch (input_shape.dims()) {
case 1:
- output =
- xla::CustomCall(&b, "argmax_float_1d_xla_impl", args, xla_shape);
+ output = xla::CustomCallWithLayout(&b, "argmax_float_1d_xla_impl", args,
+ xla_shape, arg_shapes);
break;
case 2:
- output =
- xla::CustomCall(&b, "argmax_float_2d_xla_impl", args, xla_shape);
+ output = xla::CustomCallWithLayout(&b, "argmax_float_2d_xla_impl", args,
+ xla_shape, arg_shapes);
break;
default:
OP_REQUIRES(ctx, false,
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 6b31831010..e7cf9ae363 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -1279,9 +1279,10 @@ XlaOp XlaBuilder::AfterAll(absl::Span<const XlaOp> tokens) {
});
}
-XlaOp XlaBuilder::CustomCall(const string& call_target_name,
- absl::Span<const XlaOp> operands,
- const Shape& shape, const string& opaque) {
+XlaOp XlaBuilder::CustomCall(
+ const string& call_target_name, absl::Span<const XlaOp> operands,
+ const Shape& shape, const string& opaque,
+ absl::optional<absl::Span<const Shape>> operand_shapes_with_layout) {
return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
if (absl::StartsWith(call_target_name, "$")) {
@@ -1293,6 +1294,31 @@ XlaOp XlaBuilder::CustomCall(const string& call_target_name,
*instr.mutable_shape() = shape;
instr.set_custom_call_target(call_target_name);
instr.set_custom_call_opaque(opaque);
+ if (operand_shapes_with_layout.has_value()) {
+ if (!LayoutUtil::HasLayout(shape)) {
+ return InvalidArgument(
+ "Result shape must have layout for custom call with constrained "
+ "layout.");
+ }
+ if (operands.size() != operand_shapes_with_layout->size()) {
+ return InvalidArgument(
+ "Must specify a shape with layout for each operand for custom call "
+ "with constrained layout; given %d shapes, expected %d",
+ operand_shapes_with_layout->size(), operands.size());
+ }
+ instr.set_constrain_layout(true);
+ int64 operand_num = 0;
+ for (const Shape& operand_shape : *operand_shapes_with_layout) {
+ if (!LayoutUtil::HasLayout(operand_shape)) {
+ return InvalidArgument(
+ "No layout specified for operand %d for custom call with "
+ "constrained layout.",
+ operand_num);
+ }
+ *instr.add_operand_shapes_with_layout() = operand_shape;
+ ++operand_num;
+ }
+ }
return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands);
});
}
@@ -2690,7 +2716,16 @@ XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
absl::Span<const XlaOp> operands, const Shape& shape,
const string& opaque) {
- return builder->CustomCall(call_target_name, operands, shape, opaque);
+ return builder->CustomCall(call_target_name, operands, shape, opaque,
+ /*operand_shapes_with_layout=*/absl::nullopt);
+}
+
+XlaOp CustomCallWithLayout(XlaBuilder* builder, const string& call_target_name,
+ absl::Span<const XlaOp> operands, const Shape& shape,
+ absl::Span<const Shape> operand_shapes_with_layout,
+ const string& opaque) {
+ return builder->CustomCall(call_target_name, operands, shape, opaque,
+ operand_shapes_with_layout);
}
XlaOp Complex(const XlaOp& real, const XlaOp& imag,
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index 2e14e47a35..9ceede7a79 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -577,9 +577,10 @@ class XlaBuilder {
absl::Span<const XlaOp> operands);
// Enqueues a custom call instruction onto the computation.
- XlaOp CustomCall(const string& call_target_name,
- absl::Span<const XlaOp> operands, const Shape& shape,
- const string& opaque);
+ XlaOp CustomCall(
+ const string& call_target_name, absl::Span<const XlaOp> operands,
+ const Shape& shape_with_layout, const string& opaque,
+ absl::optional<absl::Span<const Shape>> operand_shapes_with_layout);
// The following methods enqueue element-wise binary arithmetic operations
// onto the computation. The shapes of the operands have to match unless one
@@ -1197,6 +1198,10 @@ class XlaBuilder {
friend XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
absl::Span<const XlaOp> operands, const Shape& shape,
const string& opaque);
+ friend XlaOp CustomCallWithLayout(
+ XlaBuilder* builder, const string& call_target_name,
+ absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
+ absl::Span<const Shape> operand_shapes_with_layout, const string& opaque);
friend XlaOp Complex(const XlaOp& real, const XlaOp& imag,
absl::Span<const int64> broadcast_dimensions);
friend XlaOp Conj(const XlaOp& operand);
@@ -1732,6 +1737,17 @@ XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
absl::Span<const XlaOp> operands, const Shape& shape,
const string& opaque = "");
+// Overload which constructs a custom call with fixed layouts. The operands will
+// have the layouts specified by |operand_shapes_with_layout| when provided to
+// external code, and the external code is expected to produce a result with the
+// layout specified by |shape_with_layout|. All shapes in |shape_with_layout|
+// and |operand_shapes_with_layout| must have layouts.
+XlaOp CustomCallWithLayout(XlaBuilder* builder, const string& call_target_name,
+ absl::Span<const XlaOp> operands,
+ const Shape& shape_with_layout,
+ absl::Span<const Shape> operand_shapes_with_layout,
+ const string& opaque = "");
+
// The following methods enqueue element-wise binary arithmetic operations
// onto the computation. The shapes of the operands have to match unless one
// of the operands is a scalar, or an explicit broadcast dimension is given
diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc
index d310335618..3c8db9aa45 100644
--- a/tensorflow/compiler/xla/layout_util.cc
+++ b/tensorflow/compiler/xla/layout_util.cc
@@ -65,6 +65,12 @@ void SetDefaultLayoutToContainer(
return layout;
}
+/* static */ Layout LayoutUtil::MakeDescendingLayout(int64 rank) {
+ std::vector<int64> layout(rank);
+ std::iota(layout.rbegin(), layout.rend(), static_cast<int64>(0));
+ return MakeLayout(layout);
+}
+
/* static */ Layout LayoutUtil::MakeLayoutFromMajorToMinor(
absl::Span<const int64> major_to_minor) {
Layout layout;
diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h
index b78883c2d8..af032b1cae 100644
--- a/tensorflow/compiler/xla/layout_util.h
+++ b/tensorflow/compiler/xla/layout_util.h
@@ -40,6 +40,10 @@ class LayoutUtil {
static Layout MakeLayoutFromMajorToMinor(
absl::Span<const int64> major_to_minor);
+ // Returns a layout with descending (i.e. {n, n-1, ..., 0}) minor-to-major
+ // dimensions.
+ static Layout MakeDescendingLayout(int64 rank);
+
// Creates a sparse layout with the given maximum number of elements. (This is
// a convenience function for protobuf construction.)
static Layout MakeSparseLayout(int64 max_sparse_elements);
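
As a quick illustration of the new helper (a hedged sketch; the value follows
directly from the std::iota call in the implementation above):

  // std::iota fills the rank-3 vector back-to-front with 0, 1, 2, so the
  // resulting minor_to_major is {2, 1, 0}: dimension 0 is most major, which
  // is the row-major ("descending") layout.
  xla::Layout layout = xla::LayoutUtil::MakeDescendingLayout(3);
  // layout.minor_to_major() == {2, 1, 0}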
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
index 1ffe855750..8c9a8adc61 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
@@ -213,16 +213,6 @@ Status GpuLayoutAssignment::AddBackendConstraints(
return Status::OK();
}
-bool GpuLayoutAssignment::CustomCallRequiresMajorFirstLayout(
- const HloInstruction* instruction) {
- // - Inputs to cudnn batchnorm custom calls don't need the major-first layout
- // (i.e. {n, n-1, ...0}) -- we can handle any layout.
- // - Inputs to cudnn convolution require custom layouts handled in
- // AddBackendConstraints.
- return !IsCustomCallToDnnBatchNorm(*instruction) &&
- !IsCustomCallToDnnConvolution(*instruction);
-}
-
Status GpuLayoutAssignment::PropagateOperandConstraint(
const OperandLayoutConstraint& layout_constraint,
LayoutConstraints* constraints) {
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h
index 4ba7989e9c..6a48e55fd2 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h
@@ -46,8 +46,6 @@ class GpuLayoutAssignment : public LayoutAssignment {
Status PropagateBufferConstraint(
const BufferLayoutConstraint& buffer_constraint,
LayoutConstraints* constraints) override;
- bool CustomCallRequiresMajorFirstLayout(
- const HloInstruction* instruction) override;
private:
Status AddBackendConstraintsToDnnConvCustomCall(
diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto
index 1ea26ddd5b..a0eb9e6ddc 100644
--- a/tensorflow/compiler/xla/service/hlo.proto
+++ b/tensorflow/compiler/xla/service/hlo.proto
@@ -34,7 +34,7 @@ import "tensorflow/compiler/xla/xla_data.proto";
option cc_enable_arenas = true;
// Serialization of HloInstruction.
-// Next ID: 56
+// Next ID: 58
message HloInstructionProto {
reserved 10;
reserved "parameter_name";
@@ -184,6 +184,13 @@ message HloInstructionProto {
// Sharding for kDomain instructions.
xla.OpSharding domain_entry_sharding = 54;
xla.OpSharding domain_exit_sharding = 55;
+
+ // For custom call this indicates that the layouts are constrained. If
+ // constrain_layout is true then the 'shape' field must contain a layout, and
+ // 'operand_shapes_with_layout' must contain a shape with layout for each
+ // operand.
+ bool constrain_layout = 56;
+ repeated Shape operand_shapes_with_layout = 57;
}
// Serialization of HloComputation.
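
In the HLO text format (see the parser changes below), these proto fields
surface as an operand_layout_constraints attribute on the custom call; an
example taken from the parser tests in this diff:

  %custom-call = f32[1,2,3]{0,2,1} custom-call(%p0, %p1), custom_call_target="baz", operand_layout_constraints={f32[42,2,3]{0,1,2}, f32[123,4]{1,0}}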
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 2f6db7cd7c..5c3908a9a4 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -396,9 +396,22 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
operands(1), operands(2), computations(1));
break;
case HloOpcode::kCustomCall:
- instruction = CreateCustomCall(proto.shape(), all_operands(),
- proto.custom_call_target(),
- proto.custom_call_opaque());
+ if (proto.constrain_layout()) {
+ // A proto RepeatedPtrField cannot be converted to a Span (it is essentially
+ // a vector of pointers), so create a vector of shapes to pass in.
+ std::vector<Shape> operand_shapes;
+ for (const Shape& shape : proto.operand_shapes_with_layout()) {
+ operand_shapes.push_back(shape);
+ }
+ instruction = CreateCustomCall(
+ proto.shape(), all_operands(), proto.custom_call_target(),
+ operand_shapes, proto.custom_call_opaque());
+ } else {
+ instruction = CreateCustomCall(proto.shape(), all_operands(),
+ proto.custom_call_target(),
+ proto.custom_call_opaque());
+ }
if (proto.has_window()) {
static_cast<HloCustomCallInstruction*>(instruction.get())
->set_window(proto.window());
@@ -1142,6 +1155,15 @@ bool HloInstruction::HasSideEffect() const {
shape, operands, custom_call_target, opaque);
}
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
+ absl::string_view custom_call_target,
+ absl::Span<const Shape> operand_shapes_with_layout,
+ absl::string_view opaque) {
+ return absl::make_unique<HloCustomCallInstruction>(
+ shape, operands, custom_call_target, opaque, operand_shapes_with_layout);
+}
+
/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTuple(
absl::Span<HloInstruction* const> elements) {
std::vector<Shape> element_shapes;
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 374862c4b6..44f776ebac 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -734,6 +734,16 @@ class HloInstruction {
const Shape& shape, absl::Span<HloInstruction* const> operands,
absl::string_view custom_call_target, absl::string_view opaque = "");
+ // Overload which constrains the layouts of the operands and result. 'shape'
+ // and 'operand_shapes_with_layout' must have layouts.
+ // 'operand_shapes_with_layout' must have a compatible element for each
+ // operand.
+ static std::unique_ptr<HloInstruction> CreateCustomCall(
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
+ absl::string_view custom_call_target,
+ absl::Span<const Shape> operand_shapes_with_layout,
+ absl::string_view opaque = "");
+
// Creates a tuple instruction with the given elements. This is a convenience
// wrapper around CreateVariadic.
static std::unique_ptr<HloInstruction> CreateTuple(
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index 152d8eacdb..2ec233eaec 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -1825,7 +1825,24 @@ HloCustomCallInstruction::HloCustomCallInstruction(
: HloInstruction(HloOpcode::kCustomCall, shape),
custom_call_target_(custom_call_target.begin(), custom_call_target.end()),
opaque_(opaque.begin(), opaque.end()),
- feature_group_count_(1) {
+ feature_group_count_(1),
+ layout_constrained_(false) {
+ for (auto operand : operands) {
+ AppendOperand(operand);
+ }
+}
+
+HloCustomCallInstruction::HloCustomCallInstruction(
+ const Shape& shape, absl::Span<HloInstruction* const> operands,
+ absl::string_view custom_call_target, absl::string_view opaque,
+ absl::Span<const Shape> operand_shapes_with_layout)
+ : HloInstruction(HloOpcode::kCustomCall, shape),
+ custom_call_target_(custom_call_target.begin(), custom_call_target.end()),
+ opaque_(opaque.begin(), opaque.end()),
+ feature_group_count_(1),
+ layout_constrained_(true),
+ operand_shapes_with_layout_(operand_shapes_with_layout.begin(),
+ operand_shapes_with_layout.end()) {
for (auto operand : operands) {
AppendOperand(operand);
}
@@ -1843,6 +1860,12 @@ HloInstructionProto HloCustomCallInstruction::ToProto() const {
proto.set_custom_call_target(custom_call_target_);
proto.set_custom_call_opaque(opaque_);
proto.set_feature_group_count(feature_group_count_);
+ if (layout_constrained()) {
+ proto.set_constrain_layout(true);
+ for (const Shape& shape : operand_shapes_with_layout_) {
+ *proto.add_operand_shapes_with_layout() = shape;
+ }
+ }
return proto;
}
@@ -1870,6 +1893,14 @@ std::vector<string> HloCustomCallInstruction::ExtraAttributesToStringImpl(
if (!opaque_.empty()) {
extra.push_back(StrCat("opaque=\"", CEscape(opaque_), "\""));
}
+ if (layout_constrained()) {
+ std::vector<string> shape_strings;
+ for (const Shape& shape : operand_shapes_with_layout_) {
+ shape_strings.push_back(ShapeUtil::HumanStringWithLayout(shape));
+ }
+ extra.push_back(StrCat("operand_layout_constraints={",
+ StrJoin(shape_strings, ", "), "}"));
+ }
return extra;
}
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index e169604072..4c5fc759a3 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -1053,10 +1053,19 @@ class HloSelectAndScatterInstruction : public HloInstruction {
class HloCustomCallInstruction : public HloInstruction {
public:
- explicit HloCustomCallInstruction(const Shape& shape,
- absl::Span<HloInstruction* const> operands,
- absl::string_view custom_call_target,
- absl::string_view opaque);
+ HloCustomCallInstruction(const Shape& shape,
+ absl::Span<HloInstruction* const> operands,
+ absl::string_view custom_call_target,
+ absl::string_view opaque);
+
+ // Constructor for a custom call with constrained layout. 'shape' and
+ // 'operand_shapes_with_layout' must all have layouts.
+ HloCustomCallInstruction(const Shape& shape,
+ absl::Span<HloInstruction* const> operands,
+ absl::string_view custom_call_target,
+ absl::string_view opaque,
+ absl::Span<const Shape> operand_shapes_with_layout);
+
const Window& window() const override {
CHECK(window_ != nullptr);
return *window_;
@@ -1085,6 +1094,16 @@ class HloCustomCallInstruction : public HloInstruction {
// Returns a serialized representation of this instruction.
HloInstructionProto ToProto() const override;
+ // Returns whether the result and operand layouts are constrained.
+ bool layout_constrained() const { return layout_constrained_; }
+
+ // Returns the shapes (with layout) of the operands. CHECKs if this custom
+ // call does not have constrained layouts.
+ const std::vector<Shape>& operand_shapes_with_layout() const {
+ CHECK(layout_constrained());
+ return operand_shapes_with_layout_;
+ }
+
private:
std::vector<string> ExtraAttributesToStringImpl(
const HloPrintOptions& options) const override;
@@ -1106,6 +1125,11 @@ class HloCustomCallInstruction : public HloInstruction {
std::unique_ptr<ConvolutionDimensionNumbers> convolution_dimension_numbers_;
// The number of feature groups. This is used for grouped convolutions.
int64 feature_group_count_;
+ // Whether the result and operand layouts are constrained.
+ bool layout_constrained_;
+ // For layout-constrained custom calls, this vector holds the shape with
+ // layout for each operand.
+ std::vector<Shape> operand_shapes_with_layout_;
};
class HloPadInstruction : public HloInstruction {
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index dd62988bcc..96f9ff6654 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -174,6 +174,7 @@ class HloParser {
kDistribution,
kDomain,
kPrecisionList,
+ kShapeList
};
struct AttrConfig {
@@ -240,6 +241,7 @@ class HloParser {
bool ParseSliceRanges(SliceRanges* result);
bool ParsePrecisionList(std::vector<PrecisionConfig::Precision>* result);
+ bool ParseShapeList(std::vector<Shape>* result);
bool ParseInt64List(const TokKind start, const TokKind end,
const TokKind delim,
std::vector<tensorflow::int64>* result);
@@ -1341,6 +1343,7 @@ bool HloParser::ParseInstruciontRhs(HloComputation::Builder* builder,
optional<Window> window;
optional<ConvolutionDimensionNumbers> dnums;
optional<int64> feature_group_count;
+ optional<std::vector<Shape>> operand_layout_constraints;
attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString,
&custom_call_target};
attrs["opaque"] = {/*required=*/false, AttrTy::kString, &opaque};
@@ -1349,12 +1352,52 @@ bool HloParser::ParseInstruciontRhs(HloComputation::Builder* builder,
AttrTy::kConvolutionDimensionNumbers, &dnums};
attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64,
&feature_group_count};
+ attrs["operand_layout_constraints"] = {
+ /*required=*/false, AttrTy::kShapeList, &operand_layout_constraints};
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
- instruction = builder->AddInstruction(
- HloInstruction::CreateCustomCall(shape, operands, *custom_call_target,
- opaque.has_value() ? *opaque : ""));
+ if (operand_layout_constraints.has_value()) {
+ if (!LayoutUtil::HasLayout(shape)) {
+ return Error(lexer_.GetLoc(),
+ "Layout must be set on layout-constrained custom call");
+ }
+ if (operands.size() != operand_layout_constraints->size()) {
+ return Error(lexer_.GetLoc(),
+ StrCat("Expected ", operands.size(),
+ " operand layout constraints, ",
+ operand_layout_constraints->size(), " given"));
+ }
+ for (int64 i = 0; i < operands.size(); ++i) {
+ const Shape& operand_shape_with_layout =
+ (*operand_layout_constraints)[i];
+ if (!LayoutUtil::HasLayout(operand_shape_with_layout)) {
+ return Error(lexer_.GetLoc(),
+ StrCat("Operand layout constraint shape ",
+ ShapeUtil::HumanStringWithLayout(
+ operand_shape_with_layout),
+ " for operand ", i, " does not have a layout"));
+ }
+ if (!ShapeUtil::Compatible(operand_shape_with_layout,
+ operands[i]->shape())) {
+ return Error(
+ lexer_.GetLoc(),
+ StrCat(
+ "Operand layout constraint shape ",
+ ShapeUtil::HumanStringWithLayout(operand_shape_with_layout),
+ " for operand ", i,
+ " is not compatible with operand shape ",
+ ShapeUtil::HumanStringWithLayout(operands[i]->shape())));
+ }
+ }
+ instruction = builder->AddInstruction(HloInstruction::CreateCustomCall(
+ shape, operands, *custom_call_target, *operand_layout_constraints,
+ opaque.has_value() ? *opaque : ""));
+ } else {
+ instruction = builder->AddInstruction(HloInstruction::CreateCustomCall(
+ shape, operands, *custom_call_target,
+ opaque.has_value() ? *opaque : ""));
+ }
if (window.has_value()) {
instruction->set_window(*window);
}
@@ -2533,6 +2576,15 @@ bool HloParser::ParseAttributeHelper(
->emplace(result);
return true;
}
+ case AttrTy::kShapeList: {
+ std::vector<Shape> result;
+ if (!ParseShapeList(&result)) {
+ return false;
+ }
+ static_cast<optional<std::vector<Shape>>*>(attr_out_ptr)
+ ->emplace(result);
+ return true;
+ }
}
}();
if (!success) {
@@ -2825,6 +2877,23 @@ bool HloParser::ParsePrecisionList(
parse_and_add_item);
}
+// shapelist ::= '{' shapes '}'
+// shape_elements
+// ::= /*empty*/
+// ::= shape (',' shape)*
+bool HloParser::ParseShapeList(std::vector<Shape>* result) {
+ auto parse_and_add_item = [&]() {
+ Shape shape;
+ if (!ParseShape(&shape)) {
+ return false;
+ }
+ result->push_back(std::move(shape));
+ return true;
+ };
+ return ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
+ parse_and_add_item);
+}
+
// int64list ::= start int64_elements end
// int64_elements
// ::= /*empty*/
@@ -2832,23 +2901,15 @@ bool HloParser::ParsePrecisionList(
bool HloParser::ParseInt64List(const TokKind start, const TokKind end,
const TokKind delim,
std::vector<tensorflow::int64>* result) {
- if (!ParseToken(start, StrCat("expects an int64 list starting with ",
- TokKindToString(start)))) {
- return false;
- }
- if (lexer_.GetKind() == end) {
- // empty
- } else {
- do {
- tensorflow::int64 i;
- if (!ParseInt64(&i)) {
- return false;
- }
- result->push_back(i);
- } while (EatIfPresent(delim));
- }
- return ParseToken(
- end, StrCat("expects an int64 list to end with ", TokKindToString(end)));
+ auto parse_and_add_item = [&]() {
+ tensorflow::int64 i;
+ if (!ParseInt64(&i)) {
+ return false;
+ }
+ result->push_back(i);
+ return true;
+ };
+ return ParseList(start, end, delim, parse_and_add_item);
}
bool HloParser::ParseList(const TokKind start, const TokKind end,
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index 255123d331..17538c05bc 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -804,6 +804,43 @@ ENTRY %ConstantUnsignedNoOverflow () -> u64[] {
)"
},
+// CustomCallWithLayoutConstraints
+{
+"CustomCallWithLayoutConstraints",
+R"(HloModule CustomCallWithLayoutConstraints
+
+ENTRY %CustomCallWithLayoutConstraints (p0: f32[42,2,3], p1: f32[123,4]) -> f32[1,2,3] {
+ %p0 = f32[42,2,3]{0,1,2} parameter(0)
+ %p1 = f32[123,4]{0,1} parameter(1)
+ ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[42,2,3]{0,1,2} %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={f32[42,2,3]{0,1,2}, f32[123,4]{1,0}}
+}
+
+)"
+},
+// CustomCallWithLayoutConstraintsNoOperands
+{
+"CustomCallWithLayoutConstraintsNoOperands",
+R"(HloModule CustomCallWithLayoutConstraintsNoOperands
+
+ENTRY %CustomCallWithLayoutConstraints () -> f32[1,2,3] {
+ ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(), custom_call_target="baz", operand_layout_constraints={}
+}
+
+)"
+},
+// CustomCallWithLayoutConstraintsTupleShapes
+{
+"CustomCallWithLayoutConstraintsTupleShapes",
+R"(HloModule CustomCallWithLayoutConstraintsTupleShapes
+
+ENTRY %CustomCallWithLayoutConstraints (p0: (f32[2,2], f32[42,2,3]), p1: f32[123,4]) -> (f32[1,2,3], f32[1,2,3]) {
+ %p0 = (f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) parameter(0)
+ %p1 = f32[123,4]{0,1} parameter(1)
+ ROOT %custom-call = (f32[1,2,3]{0,2,1}, f32[1,2,3]{1,2,0}) custom-call((f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={(f32[2,2]{1,0}, f32[42,2,3]{2,0,1}), f32[123,4]{1,0}}
+}
+
+)"
+},
});
// clang-format on
}
@@ -2069,5 +2106,35 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] {
op::Broadcast(), op::Multiply(), op::Add()));
}
+TEST_F(HloParserTest, CustomCallWrongNumberofOperandConstraints) {
+ const string original = R"(HloModule CustomCallWrongNumberofOperandConstraints
+
+ENTRY %CustomCallWrongNumberofOperandConstraints (p0: f32[42,2,3], p1: f32[123,4]) -> f32[1,2,3] {
+ %p0 = f32[42,2,3]{0,1,2} parameter(0)
+ %p1 = f32[123,4]{0,1} parameter(1)
+ ROOT %custom-call = f32[1,2,3]{0,1,2} custom-call(f32[42,2,3]{0,1,2} %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={f32[42,2,3]{0,1,2}}
+}
+
+)";
+ ExpectHasSubstr(ParseHloString(original).status().error_message(),
+ "Expected 2 operand layout constraints, 1 given");
+}
+
+TEST_F(HloParserTest, CustomCallIncompatibleOperandConstraints) {
+ const string original = R"(HloModule CustomCallIncompatibleOperandConstraints
+
+ENTRY %CustomCallIncompatibleOperandConstraints (p0: f32[42,2,3], p1: f32[123,4]) -> f32[1,2,3] {
+ %p0 = f32[42,2,3]{0,1,2} parameter(0)
+ %p1 = f32[123,4]{0,1} parameter(1)
+ ROOT %custom-call = f32[1,2,3]{0,1,2} custom-call(f32[42,2,3]{0,1,2} %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={f32[42,2,3]{0,1,2}, f32[555,5]{1,0}}
+}
+
+)";
+ ExpectHasSubstr(ParseHloString(original).status().error_message(),
+ "operand 1 is not compatible with operand shape");
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 496fe1795d..be3bee5975 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -360,7 +360,27 @@ Status ShapeVerifier::HandleCall(HloInstruction* call) {
return CheckShape(call, call->to_apply()->root_instruction()->shape());
}
-Status ShapeVerifier::HandleCustomCall(HloInstruction*) { return Status::OK(); }
+Status ShapeVerifier::HandleCustomCall(HloInstruction* instruction) {
+ const HloCustomCallInstruction* custom_call =
+ DynCast<const HloCustomCallInstruction>(instruction);
+ TF_RET_CHECK(custom_call != nullptr);
+ if (custom_call->layout_constrained()) {
+ // If the layout is constrained, verify all the respective shapes have
+ // layouts and that the constrained operand shapes match the shapes of the
+ // operands.
+ TF_RET_CHECK(LayoutUtil::HasLayout(custom_call->shape()));
+ TF_RET_CHECK(custom_call->operand_count() ==
+ custom_call->operand_shapes_with_layout().size());
+ for (int64 i = 0; i < custom_call->operand_count(); ++i) {
+ const Shape& operand_shape_with_layout =
+ custom_call->operand_shapes_with_layout()[i];
+ TF_RET_CHECK(ShapeUtil::Compatible(custom_call->operand(i)->shape(),
+ operand_shape_with_layout));
+ TF_RET_CHECK(LayoutUtil::HasLayout(operand_shape_with_layout));
+ }
+ }
+ return Status::OK();
+}
Status ShapeVerifier::HandleSlice(HloInstruction* slice) {
return CheckShape(slice,
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index cc4a342e9d..ad65b147c1 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -419,6 +419,16 @@ Status LayoutAssignment::BuildHostChannelConstraints(
return Status::OK();
}
+namespace {
+
+bool IsLayoutConstrainedCustomCall(HloInstruction* instruction) {
+ const HloCustomCallInstruction* custom_call =
+ DynCast<HloCustomCallInstruction>(instruction);
+ return custom_call != nullptr && custom_call->layout_constrained();
+}
+
+} // namespace
+
Status LayoutAssignment::AddMandatoryConstraints(
const ComputationLayout* computation_layout,
ChannelLayoutConstraints* channel_constraints, HloComputation* computation,
@@ -434,7 +444,6 @@ Status LayoutAssignment::AddMandatoryConstraints(
// Constrain layouts of instructions which define values with pre-existing
// layouts.
for (auto* instruction : computation->instructions()) {
- Shape const* shape_with_layout = nullptr;
if (instruction->opcode() == HloOpcode::kInfeed) {
// Infeed layouts must match the layout of the original inserted
// instruction.
@@ -456,17 +465,21 @@ Status LayoutAssignment::AddMandatoryConstraints(
if (parameter_layout.LayoutIsSet()) {
// Parameter layouts must match the respective layout in
// ComputationLayout, if there is one.
- shape_with_layout = &parameter_layout.shape();
+ TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
+ parameter_layout.shape(), instruction));
}
}
- }
- if (shape_with_layout != nullptr) {
+ } else if (IsLayoutConstrainedCustomCall(instruction)) {
+ const HloCustomCallInstruction* custom_call =
+ DynCast<HloCustomCallInstruction>(instruction);
TF_RETURN_IF_ERROR(
- constraints->SetInstructionLayout(*shape_with_layout, instruction));
- }
-
- if (instruction->opcode() == HloOpcode::kSend ||
- instruction->opcode() == HloOpcode::kRecv) {
+ constraints->SetInstructionLayout(custom_call->shape(), custom_call));
+ for (int64 i = 0; i < custom_call->operand_count(); ++i) {
+ TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
+ custom_call->operand_shapes_with_layout()[i], custom_call, i));
+ }
+ } else if (instruction->opcode() == HloOpcode::kSend ||
+ instruction->opcode() == HloOpcode::kRecv) {
CHECK(get_channel_constraints(instruction))
<< "Multi-module layout assignment requires ChannelLayoutConstraints";
int64 channel_id = instruction->channel_id();
@@ -621,31 +634,6 @@ Status LayoutAssignment::AddMandatoryConstraints(
TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
false_computation_layout.parameter_shape(0), instruction, 2,
/*mandatory=*/true));
- } else if (instruction->opcode() == HloOpcode::kCustomCall) {
- if (!CustomCallRequiresMajorFirstLayout(instruction)) {
- continue;
- }
- // Add constraints for kCustomCall instruction operands and instructions.
- // For now we only support major-first layouts for all inputs and outputs.
- Shape result_shape = ShapeUtil::MakeShapeWithDescendingLayout(
- instruction->shape().element_type(),
- AsInt64Slice(instruction->shape().dimensions()));
- TF_RETURN_IF_ERROR(
- constraints->SetInstructionLayout(result_shape, instruction));
- for (int64 i = 0; i < instruction->operand_count(); ++i) {
- const Shape& operand_shape = instruction->operand(i)->shape();
- // Opaque operands don't get a layout constraint.
- if (ShapeUtil::IsOpaque(operand_shape)) {
- continue;
- }
-
- Shape row_major_operand_shape =
- ShapeUtil::MakeShapeWithDescendingLayout(
- operand_shape.element_type(),
- AsInt64Slice(operand_shape.dimensions()));
- TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
- row_major_operand_shape, instruction, i));
- }
}
}
// Finally set the result layout to match ComputationLayout, if there is one.
@@ -676,16 +664,18 @@ Status CheckCallLayout(HloInstruction* call,
return Status::OK();
}
-// Custom calls have fixed input and output layouts.
-Status CheckCustomCallLayout(HloInstruction* custom_call) {
- for (const HloInstruction* operand : custom_call->operands()) {
- TF_RET_CHECK(
- ShapeUtil::IsOpaque(operand->shape()) ||
- LayoutUtil::IsMonotonicWithDim0Major(operand->shape().layout()));
+// Operands of layout-constrained custom calls must match the expected
+// constrained layouts.
+Status CheckCustomCallLayout(HloInstruction* instruction) {
+ if (IsLayoutConstrainedCustomCall(instruction)) {
+ const HloCustomCallInstruction* custom_call =
+ DynCast<HloCustomCallInstruction>(instruction);
+ for (int64 i = 0; i < custom_call->operand_count(); ++i) {
+ TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual(
+ custom_call->operand(i)->shape(),
+ custom_call->operand_shapes_with_layout()[i]));
+ }
}
- TF_RET_CHECK(
- ShapeUtil::IsOpaque(custom_call->shape()) ||
- LayoutUtil::IsMonotonicWithDim0Major(custom_call->shape().layout()));
return Status::OK();
}
@@ -932,9 +922,7 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) {
FindOrDie(computation_layouts_, instruction->to_apply())));
break;
case HloOpcode::kCustomCall:
- if (CustomCallRequiresMajorFirstLayout(instruction)) {
- TF_RETURN_IF_ERROR(CheckCustomCallLayout(instruction));
- }
+ TF_RETURN_IF_ERROR(CheckCustomCallLayout(instruction));
break;
case HloOpcode::kFusion:
TF_RETURN_IF_ERROR(CheckFusionLayout(instruction));
@@ -1554,11 +1542,11 @@ Status LayoutAssignment::CalculateComputationLayout(
Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) {
// Clear existing layouts of the instructions. All layouts must be assigned
- // by the LayoutAssignment pass, except for those on infeeds, parameters,
- // and the computation result. The latter two are specified in
- // computation_layout, so we only need to keep the existing layouts for
- // infeeds. Clearing the layouts here avoids hiding potential bugs in the
- // layout assignment pass that may accidentally use the existing layout.
+ // by the LayoutAssignment pass, except for those on parameters, the
+ // computation result, and a couple special cases. The former two are
+ // specified in computation_layout. Clearing the layouts here avoids hiding
+ // potential bugs in the layout assignment pass that may accidentally use the
+ // existing layout.
for (HloInstruction* instruction : computation->instructions()) {
if (instruction->opcode() == HloOpcode::kBitcast) {
// bitcasts are inherently layout sensitive and so a bitcast instruction
@@ -1567,7 +1555,9 @@ Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) {
"Unexpected bitcast operation seen during layout assignment: %s.",
instruction->ToString());
}
- if (instruction->opcode() != HloOpcode::kInfeed) {
+ // Some instructions carry mandatory layouts in their shape.
+ if (instruction->opcode() != HloOpcode::kInfeed &&
+ !IsLayoutConstrainedCustomCall(instruction)) {
LayoutUtil::ClearLayout(instruction->mutable_shape());
}
}
@@ -1802,6 +1792,18 @@ StatusOr<bool> LayoutAssignment::Run(HloModule* module) {
}
TF_RETURN_IF_ERROR(Init());
+ // Verify computation layout is sane.
+ const HloComputation* entry = module->entry_computation();
+ TF_RET_CHECK(entry_computation_layout_->parameter_count() ==
+ entry->num_parameters());
+ for (int64 i = 0; i < entry->num_parameters(); ++i) {
+ TF_RET_CHECK(
+ ShapeUtil::Compatible(entry_computation_layout_->parameter_shape(i),
+ entry->parameter_instruction(i)->shape()));
+ }
+ TF_RET_CHECK(ShapeUtil::Compatible(entry_computation_layout_->result_shape(),
+ entry->root_instruction()->shape()));
+
// We do two passes. The first one we pass a nullptr ComputationLayout to
// the RunOnComputation() calls (for non entry computations), and we register
// the ComputationLayout which are naturally flowing in DFS fashion to the
@@ -1873,7 +1875,6 @@ bool LayoutAssignment::InstructionCanChangeLayout(
case HloOpcode::kCrossReplicaSum:
case HloOpcode::kAllToAll:
case HloOpcode::kCollectivePermute:
- case HloOpcode::kCustomCall:
case HloOpcode::kDivide:
case HloOpcode::kDynamicSlice:
case HloOpcode::kDynamicUpdateSlice:
@@ -1930,6 +1931,7 @@ bool LayoutAssignment::InstructionCanChangeLayout(
case HloOpcode::kConstant:
case HloOpcode::kConvolution:
case HloOpcode::kCopy:
+ case HloOpcode::kCustomCall:
case HloOpcode::kDomain:
case HloOpcode::kDot:
case HloOpcode::kFusion:
diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h
index 2d48e12263..cb56f4cd19 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.h
+++ b/tensorflow/compiler/xla/service/layout_assignment.h
@@ -333,19 +333,6 @@ class LayoutAssignment : public HloModulePass {
const ResultLayoutConstraint& layout_constraint,
LayoutConstraints* constraints);
- // By default LayoutAssignment ensures that inputs and outputs of CustomCalls
- // have the "major-first" layout (i.e. {n, n-1, ..., 0}).
- //
- // If this function returns true, LayoutAssignment does not set a layout for
- // the given CustomCall. It's up to the backend to set one in
- // AddBackendConstraints, if necessary.
- //
- // Precondition: instruction->opcode() == HloOpcode::kCustomCall.
- virtual bool CustomCallRequiresMajorFirstLayout(
- const HloInstruction* /*instruction*/) {
- return true;
- }
-
// Called after layouts of an instruction have been finalized to allow
// subclasses to check for platform specific assumptions.
virtual Status Verify(const HloInstruction* instruction) {
diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc
index 2c549cd872..ff6fdb5e4a 100644
--- a/tensorflow/compiler/xla/service/layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc
@@ -65,6 +65,27 @@ class LayoutAssignmentTest : public HloVerifiedTestBase {
FindInstruction(module, name)->shape().layout().minor_to_major();
return std::vector<int64>(minor_to_major.begin(), minor_to_major.end());
}
+
+ void ExpectLayoutIs(const Shape& shape,
+ absl::Span<const int64> minor_to_major) {
+ const Layout expected = LayoutUtil::MakeLayout(minor_to_major);
+ EXPECT_TRUE(LayoutUtil::Equal(shape.layout(), expected))
+ << "Expected layout " << expected << ", actual " << shape.layout();
+ }
+
+ void ExpectTupleLayoutIs(
+ const Shape& shape,
+ std::initializer_list<absl::Span<const int64>> minor_to_majors) {
+ int i = 0;
+ for (const absl::Span<const int64> minor_to_major : minor_to_majors) {
+ const Layout expected = LayoutUtil::MakeLayout(minor_to_major);
+ const Layout& actual = ShapeUtil::GetTupleElementShape(shape, i).layout();
+ EXPECT_TRUE(LayoutUtil::Equal(actual, expected))
+ << "Expected tuple element " << i << " layout " << expected
+ << ", actual " << actual;
+ ++i;
+ }
+ }
};
TEST_F(LayoutAssignmentTest, ComputationLayout) {
@@ -1102,5 +1123,174 @@ TEST_F(LayoutAssignmentTest, TupleCopyOnLayoutMismatch) {
EXPECT_THAT(LayoutOf(&module(), "next_buf"), ElementsAre(1, 0));
}
+TEST_F(LayoutAssignmentTest, CustomCallNotLayoutConstrained) {
+ const char* module_str = R"(
+HloModule CustomCallNotLayoutConstrained
+
+ENTRY %CustomCallWithNotLayoutConstrained (p: f32[42,2,3]) -> f32[1,2,3,4] {
+ %p = f32[42,2,3] parameter(0)
+ ROOT %custom-call = f32[1,2,3,4] custom-call(f32[42,2,3] %p), custom_call_target="baz"
+}
+)";
+ // Try a couple of different layouts. In each case the custom call's operand
+ // and result layouts should match those of the computation.
+ {
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
+ ComputationLayout computation_layout = module->entry_computation_layout();
+ *computation_layout.mutable_parameter_layout(0) =
+ ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {42, 2, 3}, {0, 2, 1}));
+ *computation_layout.mutable_result_layout() = ShapeLayout(
+ ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {3, 2, 0, 1}));
+ AssignLayouts(module.get(), &computation_layout);
+
+ HloInstruction* root = module->entry_computation()->root_instruction();
+ ASSERT_THAT(root, op::CustomCall(op::Parameter()));
+ ExpectLayoutIs(root->shape(), {3, 2, 0, 1});
+ ExpectLayoutIs(root->operand(0)->shape(), {0, 2, 1});
+ }
+ {
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
+ ComputationLayout computation_layout = module->entry_computation_layout();
+ *computation_layout.mutable_parameter_layout(0) =
+ ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {42, 2, 3}, {0, 1, 2}));
+ *computation_layout.mutable_result_layout() = ShapeLayout(
+ ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {0, 2, 3, 1}));
+ AssignLayouts(module.get(), &computation_layout);
+
+ HloInstruction* root = module->entry_computation()->root_instruction();
+ ASSERT_THAT(root, op::CustomCall(op::Parameter()));
+ ExpectLayoutIs(root->shape(), {0, 2, 3, 1});
+ ExpectLayoutIs(root->operand(0)->shape(), {0, 1, 2});
+ }
+}
+
+TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrained) {
+ const char* module_str = R"(
+HloModule CustomCallLayoutConstrained
+
+ENTRY %CustomCallWithLayoutConstraints (p0: f32[4,4], p1: f32[2,3]) -> f32[1,2,3,4] {
+ %p0 = f32[4,4] parameter(0)
+ %p1 = f32[2,3] parameter(1)
+ ROOT %custom-call = f32[1,2,3,4]{3,2,0,1} custom-call(f32[4,4] %p0, f32[2,3] %p1), custom_call_target="baz", operand_layout_constraints={f32[4,4]{0,1}, f32[2,3]{1,0}}
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
+ ComputationLayout computation_layout = module->entry_computation_layout();
+ *computation_layout.mutable_parameter_layout(0) =
+ ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}));
+ *computation_layout.mutable_parameter_layout(1) =
+ ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0}));
+ *computation_layout.mutable_result_layout() = ShapeLayout(
+ ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3}));
+ AssignLayouts(module.get(), &computation_layout);
+
+ // The custom call should be partially encapsulated in kCopy instructions
+ // because of the layout mismatches.
+ ASSERT_THAT(module->entry_computation()->root_instruction(),
+ op::Copy(op::CustomCall(op::Copy(), op::Parameter())));
+
+ const HloInstruction* custom_call =
+ module->entry_computation()->root_instruction()->operand(0);
+ ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1});
+ ExpectLayoutIs(custom_call->operand(0)->shape(), {0, 1});
+ ExpectLayoutIs(custom_call->operand(1)->shape(), {1, 0});
+}
+
+TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedZeroOperands) {
+ const char* module_str = R"(
+HloModule CustomCallLayoutConstrainedZeroOperands
+
+ENTRY %CustomCallLayoutConstrainedZeroOperands () -> f32[1,2,3,4] {
+ ROOT %custom-call = f32[1,2,3,4]{3,2,0,1} custom-call(), custom_call_target="baz", operand_layout_constraints={}
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
+ ComputationLayout computation_layout = module->entry_computation_layout();
+ *computation_layout.mutable_result_layout() = ShapeLayout(
+ ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3}));
+ AssignLayouts(module.get(), &computation_layout);
+
+ ASSERT_THAT(module->entry_computation()->root_instruction(),
+ op::Copy(op::CustomCall()));
+
+ const HloInstruction* custom_call =
+ module->entry_computation()->root_instruction()->operand(0);
+ ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1});
+}
+
+TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedTupleOperand) {
+ const char* module_str = R"(
+HloModule CustomCallLayoutConstrainedTupleOperand
+
+ENTRY %CustomCallLayoutConstrainedTupleOperand (p0: f32[4,4], p1: f32[2,3]) -> f32[1,2,3,4] {
+ %p0 = f32[4,4] parameter(0)
+ %p1 = f32[2,3] parameter(1)
+ %tuple = (f32[4,4], f32[2,3]) tuple(%p0, %p1)
+ ROOT %custom-call = f32[1,2,3,4]{3,2,0,1} custom-call(%tuple), custom_call_target="baz", operand_layout_constraints={(f32[4,4]{1,0}, f32[2,3]{0,1})}
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
+ ComputationLayout computation_layout = module->entry_computation_layout();
+ *computation_layout.mutable_parameter_layout(0) =
+ ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}));
+ *computation_layout.mutable_parameter_layout(1) =
+ ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0}));
+ *computation_layout.mutable_result_layout() = ShapeLayout(
+ ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3}));
+ AssignLayouts(module.get(), &computation_layout);
+
+ HloInstruction* root = module->entry_computation()->root_instruction();
+ ExpectLayoutIs(root->shape(), {2, 1, 0, 3});
+
+ ASSERT_THAT(module->entry_computation()->root_instruction(),
+ op::Copy(op::CustomCall(op::Tuple())));
+
+ const HloInstruction* custom_call =
+ module->entry_computation()->root_instruction()->operand(0);
+ ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1});
+ ExpectTupleLayoutIs(custom_call->operand(0)->shape(), {{1, 0}, {0, 1}});
+}
+
+TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedTupleResult) {
+ const char* module_str = R"(
+HloModule CustomCallLayoutConstrainedTupleResult
+
+ENTRY %CustomCallLayoutConstrainedTupleResult (p0: f32[4,4]) -> (f32[4,4]{1,0}, f32[2,3]{0,1}) {
+ %p0 = f32[4,4] parameter(0)
+ ROOT %custom-call = (f32[4,4]{1,0}, f32[2,3]{0,1}) custom-call(%p0), custom_call_target="baz", operand_layout_constraints={f32[4,4]{1,0}}
+}
+)";
+ // The constrained tuple layout of the custom call result should be preserved
+ // even though the computation's result layout differs.
+ TF_ASSERT_OK_AND_ASSIGN(
+ std::unique_ptr<VerifiedHloModule> module,
+ ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
+ ComputationLayout computation_layout = module->entry_computation_layout();
+ *computation_layout.mutable_parameter_layout(0) =
+ ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}));
+ *computation_layout.mutable_result_layout() =
+ ShapeLayout(ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}),
+ ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0})}));
+ AssignLayouts(module.get(), &computation_layout);
+
+ ExpectTupleLayoutIs(module->entry_computation()->root_instruction()->shape(),
+ {{1, 0}, {1, 0}});
+
+ const HloInstruction* custom_call =
+ FindInstruction(module.get(), "custom-call");
+ ExpectTupleLayoutIs(custom_call->shape(), {{1, 0}, {0, 1}});
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index d244923532..7f0201942b 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -1645,7 +1645,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
}
std::ostream& operator<<(std::ostream& out, const Shape& shape) {
- out << ShapeUtil::HumanString(shape);
+ out << ShapeUtil::HumanStringWithLayout(shape);
return out;
}
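
The practical effect of this one-line change is that streaming a Shape now
includes its layout (a hedged example; the exact formatting is whatever
HumanStringWithLayout produces):

  // Previously printed "f32[2,3]"; now prints "f32[2,3]{1,0}", with the
  // layout's minor-to-major order in braces.
  std::cout << ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0});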
diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc
index a693fa3595..001490c6a8 100644
--- a/tensorflow/compiler/xla/tests/custom_call_test.cc
+++ b/tensorflow/compiler/xla/tests/custom_call_test.cc
@@ -105,8 +105,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) {
LiteralTestUtil::ExpectR0Near<float>(10.0f, result, error_spec_);
}
-XLA_TEST_F(CustomCallTest,
- DISABLED_ON_GPU(CustomCall_UsedInOtherComputations)) {
+XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(UsedInOtherComputations)) {
auto module = CreateNewModule();
auto b = HloComputation::Builder(TestName());
@@ -130,6 +129,53 @@ XLA_TEST_F(CustomCallTest,
Array3D<float>{{{2, 3}, {4, 5}}, {{3, 4}, {5, 6}}}, result);
}
+XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(InputAndOutputLayoutDiffer)) {
+ auto module = CreateNewModule();
+ auto b = HloComputation::Builder(TestName());
+
+ auto input =
+ b.AddInstruction(HloInstruction::CreateParameter(0, r2f32_, "p"));
+ b.AddInstruction(
+ HloInstruction::CreateCustomCall(r2f32_, {input}, "Add1ToValues"));
+
+ module->AddEntryComputation(b.Build());
+ ForceParameterLayout(module.get(), 0, LayoutUtil::MakeLayout({1, 0}));
+ ForceResultLayout(module.get(), LayoutUtil::MakeLayout({0, 1}));
+
+ Literal argument = LiteralUtil::CreateR2<float>({{1.f, 2.f}, {3.f, 4.f}});
+
+ // Note, the expected result is transposed! This is because the input and
+ // output layouts of the custom call differ and the called function just
+ // blindly adds one to each element.
+ Literal result = ExecuteAndTransfer(std::move(module), {&argument});
+ LiteralTestUtil::ExpectR2Equal<float>({{2.f, 4.f}, {3.f, 5.f}}, result);
+}
+
+XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(LayoutConstrained)) {
+ // The argument and result of the computation are set to different layouts,
+ // but the custom call is layout constrained to a fixed operand and result
+ // layout, so the correct result should be produced.
+ auto module = CreateNewModule();
+ auto b = HloComputation::Builder(TestName());
+
+ auto input =
+ b.AddInstruction(HloInstruction::CreateParameter(0, r2f32_, "p"));
+
+ const Shape& r2f32_dim0_major =
+ ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0});
+ b.AddInstruction(HloInstruction::CreateCustomCall(
+ r2f32_dim0_major, {input}, "Add1ToValues", {r2f32_dim0_major}));
+
+ module->AddEntryComputation(b.Build());
+ ForceParameterLayout(module.get(), 0, LayoutUtil::MakeLayout({1, 0}));
+ ForceResultLayout(module.get(), LayoutUtil::MakeLayout({0, 1}));
+
+ Literal argument = LiteralUtil::CreateR2<float>({{1.f, 2.f}, {3.f, 4.f}});
+
+ Literal result = ExecuteAndTransfer(std::move(module), {&argument});
+ LiteralTestUtil::ExpectR2Equal<float>({{2.f, 3.f}, {4.f, 5.f}}, result);
+}
+
class CustomCallClientAPITest : public ClientLibraryTestBase {};
// When using the client API, CustomCall targets can't begin with '$' -- these