From 396a8a4105edd409d0821c4d5d0b920b315ffb72 Mon Sep 17 00:00:00 2001
From: Mark Heffernan
Date: Mon, 8 Oct 2018 14:26:43 -0700
Subject: Add custom call with layout constraints.

Add a variant of CustomCall which specifies arbitrary layout constraints
on the operands and result. The existing non-layout-constrained CustomCall
is changed to have no layout preference and can now be assigned arbitrary
layouts by layout assignment. (A brief usage sketch of the new API follows
at the end of the patch.)

PiperOrigin-RevId: 216249615
---
 .../compiler/tf2xla/kernels/index_ops_cpu.cc       |  22 ++-
 tensorflow/compiler/xla/client/xla_builder.cc      |  43 ++++-
 tensorflow/compiler/xla/client/xla_builder.h       |  22 ++-
 tensorflow/compiler/xla/layout_util.cc             |   6 +
 tensorflow/compiler/xla/layout_util.h              |   4 +
 .../xla/service/gpu/gpu_layout_assignment.cc       |  10 --
 .../xla/service/gpu/gpu_layout_assignment.h        |   2 -
 tensorflow/compiler/xla/service/hlo.proto          |   9 +-
 tensorflow/compiler/xla/service/hlo_instruction.cc |  28 ++-
 tensorflow/compiler/xla/service/hlo_instruction.h  |  10 ++
 .../compiler/xla/service/hlo_instructions.cc       |  33 +++-
 tensorflow/compiler/xla/service/hlo_instructions.h |  32 +++-
 tensorflow/compiler/xla/service/hlo_parser.cc      | 101 ++++++++---
 tensorflow/compiler/xla/service/hlo_parser_test.cc |  67 ++++++++
 tensorflow/compiler/xla/service/hlo_verifier.cc    |  22 ++-
 .../compiler/xla/service/layout_assignment.cc      | 108 ++++++------
 .../compiler/xla/service/layout_assignment.h       |  13 --
 .../xla/service/layout_assignment_test.cc          | 190 +++++++++++++++++++++
 tensorflow/compiler/xla/shape_util.cc              |   2 +-
 tensorflow/compiler/xla/tests/custom_call_test.cc  |  50 +++++-
 20 files changed, 650 insertions(+), 124 deletions(-)

diff --git a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
index 3d81ae9eb8..f210bfbd88 100644
--- a/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
+++ b/tensorflow/compiler/tf2xla/kernels/index_ops_cpu.cc
@@ -88,20 +88,30 @@ class ArgMaxCustomCallOp : public XlaOpKernel {
           xla::ConstantLiteral(&b, xla::LiteralUtil::CreateR0<int32>(dim)));
     }
 
-    xla::Shape xla_shape =
-        xla::ShapeUtil::MakeShape(xla::S64, output_shape.dim_sizes());
+    // The argmax function expects row-major layout.
+    xla::Shape xla_shape = xla::ShapeUtil::MakeShapeWithDescendingLayout(
+        xla::S64, output_shape.dim_sizes());
+    std::vector<xla::Shape> arg_shapes;
+    for (const xla::XlaOp& arg : args) {
+      auto shape_status = b.GetShape(arg);
+      OP_REQUIRES_OK(ctx, shape_status.status());
+      xla::Shape arg_shape = shape_status.ConsumeValueOrDie();
+      *arg_shape.mutable_layout() = xla::LayoutUtil::MakeDescendingLayout(
+          xla::ShapeUtil::Rank(arg_shape));
+      arg_shapes.push_back(std::move(arg_shape));
+    }
 
     // Tell XLA to call the custom code, defined in
     // index_ops_kernel_argmax_float_1d.cc.
    xla::XlaOp output;
    switch (input_shape.dims()) {
      case 1:
-        output =
-            xla::CustomCall(&b, "argmax_float_1d_xla_impl", args, xla_shape);
+        output = xla::CustomCallWithLayout(&b, "argmax_float_1d_xla_impl",
+                                           args, xla_shape, arg_shapes);
        break;
      case 2:
-        output =
-            xla::CustomCall(&b, "argmax_float_2d_xla_impl", args, xla_shape);
+        output = xla::CustomCallWithLayout(&b, "argmax_float_2d_xla_impl",
+                                           args, xla_shape, arg_shapes);
        break;
      default:
        OP_REQUIRES(ctx, false,
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 6b31831010..e7cf9ae363 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -1279,9 +1279,10 @@ XlaOp XlaBuilder::AfterAll(absl::Span<const XlaOp> tokens) {
   });
 }
 
-XlaOp XlaBuilder::CustomCall(const string& call_target_name,
-                             absl::Span<const XlaOp> operands,
-                             const Shape& shape, const string& opaque) {
+XlaOp XlaBuilder::CustomCall(
+    const string& call_target_name, absl::Span<const XlaOp> operands,
+    const Shape& shape, const string& opaque,
+    absl::optional<absl::Span<const Shape>> operand_shapes_with_layout) {
   return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
     HloInstructionProto instr;
     if (absl::StartsWith(call_target_name, "$")) {
@@ -1293,6 +1294,31 @@ XlaOp XlaBuilder::CustomCall(const string& call_target_name,
     *instr.mutable_shape() = shape;
     instr.set_custom_call_target(call_target_name);
     instr.set_custom_call_opaque(opaque);
+    if (operand_shapes_with_layout.has_value()) {
+      if (!LayoutUtil::HasLayout(shape)) {
+        return InvalidArgument(
+            "Result shape must have layout for custom call with constrained "
+            "layout.");
+      }
+      if (operands.size() != operand_shapes_with_layout->size()) {
+        return InvalidArgument(
+            "Must specify a shape with layout for each operand for custom "
+            "call with constrained layout; given %d shapes, expected %d",
+            operand_shapes_with_layout->size(), operands.size());
+      }
+      instr.set_constrain_layout(true);
+      int64 operand_num = 0;
+      for (const Shape& operand_shape : *operand_shapes_with_layout) {
+        if (!LayoutUtil::HasLayout(operand_shape)) {
+          return InvalidArgument(
+              "No layout specified for operand %d for custom call with "
+              "constrained layout.",
+              operand_num);
+        }
+        *instr.add_operand_shapes_with_layout() = operand_shape;
+        ++operand_num;
+      }
+    }
     return AddInstruction(std::move(instr), HloOpcode::kCustomCall, operands);
   });
 }
@@ -2690,7 +2716,16 @@ XlaOp Call(XlaBuilder* builder, const XlaComputation& computation,
 XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
                  absl::Span<const XlaOp> operands, const Shape& shape,
                  const string& opaque) {
-  return builder->CustomCall(call_target_name, operands, shape, opaque);
+  return builder->CustomCall(call_target_name, operands, shape, opaque,
+                             /*operand_shapes_with_layout=*/absl::nullopt);
+}
+
+XlaOp CustomCallWithLayout(XlaBuilder* builder, const string& call_target_name,
+                           absl::Span<const XlaOp> operands, const Shape& shape,
+                           absl::Span<const Shape> operand_shapes_with_layout,
+                           const string& opaque) {
+  return builder->CustomCall(call_target_name, operands, shape, opaque,
+                             operand_shapes_with_layout);
 }
 
 XlaOp Complex(const XlaOp& real, const XlaOp& imag,
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index 2e14e47a35..9ceede7a79 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -577,9 +577,10 @@ class XlaBuilder {
                absl::Span<const XlaOp> operands);
 
   // Enqueues a custom call instruction onto the computation.
-  XlaOp CustomCall(const string& call_target_name,
-                   absl::Span<const XlaOp> operands, const Shape& shape,
-                   const string& opaque);
+  XlaOp CustomCall(
+      const string& call_target_name, absl::Span<const XlaOp> operands,
+      const Shape& shape_with_layout, const string& opaque,
+      absl::optional<absl::Span<const Shape>> operand_shapes_with_layout);
 
   // The following methods enqueue element-wise binary arithmetic operations
   // onto the computation. The shapes of the operands have to match unless one
@@ -1197,6 +1198,10 @@ class XlaBuilder {
   friend XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
                           absl::Span<const XlaOp> operands, const Shape& shape,
                           const string& opaque);
+  friend XlaOp CustomCallWithLayout(
+      XlaBuilder* builder, const string& call_target_name,
+      absl::Span<const XlaOp> operands, const Shape& shape_with_layout,
+      absl::Span<const Shape> operand_shapes_with_layout,
+      const string& opaque);
   friend XlaOp Complex(const XlaOp& real, const XlaOp& imag,
                        absl::Span<const int64> broadcast_dimensions);
   friend XlaOp Conj(const XlaOp& operand);
@@ -1732,6 +1737,17 @@ XlaOp CustomCall(XlaBuilder* builder, const string& call_target_name,
                  absl::Span<const XlaOp> operands, const Shape& shape,
                  const string& opaque = "");
 
+// Overload which constructs a custom call with fixed layouts. The operands
+// will have the layouts specified by |operand_shapes_with_layout| when
+// provided to external code, and the external code is expected to produce a
+// result with the layout specified by |shape_with_layout|. All shapes in
+// |shape_with_layout| and |operand_shapes_with_layout| must have layouts.
+XlaOp CustomCallWithLayout(XlaBuilder* builder, const string& call_target_name,
+                           absl::Span<const XlaOp> operands,
+                           const Shape& shape_with_layout,
+                           absl::Span<const Shape> operand_shapes_with_layout,
+                           const string& opaque = "");
+
 // The following methods enqueue element-wise binary arithmetic operations
 // onto the computation. The shapes of the operands have to match unless one
 // of the operands is a scalar, or an explicit broadcast dimension is given
diff --git a/tensorflow/compiler/xla/layout_util.cc b/tensorflow/compiler/xla/layout_util.cc
index d310335618..3c8db9aa45 100644
--- a/tensorflow/compiler/xla/layout_util.cc
+++ b/tensorflow/compiler/xla/layout_util.cc
@@ -65,6 +65,12 @@ void SetDefaultLayoutToContainer(
   return layout;
 }
 
+/* static */ Layout LayoutUtil::MakeDescendingLayout(int64 rank) {
+  std::vector<int64> layout(rank);
+  std::iota(layout.rbegin(), layout.rend(), static_cast<int64>(0));
+  return MakeLayout(layout);
+}
+
 /* static */ Layout LayoutUtil::MakeLayoutFromMajorToMinor(
     absl::Span<const int64> major_to_minor) {
   Layout layout;
diff --git a/tensorflow/compiler/xla/layout_util.h b/tensorflow/compiler/xla/layout_util.h
index b78883c2d8..af032b1cae 100644
--- a/tensorflow/compiler/xla/layout_util.h
+++ b/tensorflow/compiler/xla/layout_util.h
@@ -40,6 +40,10 @@ class LayoutUtil {
   static Layout MakeLayoutFromMajorToMinor(
       absl::Span<const int64> major_to_minor);
 
+  // Returns a layout with descending (i.e. {n, n-1, ..., 0}) minor-to-major
+  // dimensions.
+  static Layout MakeDescendingLayout(int64 rank);
+
   // Creates a sparse layout with the given maximum number of elements. (This
   // is a convenience function for protobuf construction.)
static Layout MakeSparseLayout(int64 max_sparse_elements); diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc index 1ffe855750..8c9a8adc61 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc @@ -213,16 +213,6 @@ Status GpuLayoutAssignment::AddBackendConstraints( return Status::OK(); } -bool GpuLayoutAssignment::CustomCallRequiresMajorFirstLayout( - const HloInstruction* instruction) { - // - Inputs to cudnn batchnorm custom calls don't need the major-first layout - // (i.e. {n, n-1, ...0}) -- we can handle any layout. - // - Inputs to cudnn convolution require custom layouts handled in - // AddBackendConstraints. - return !IsCustomCallToDnnBatchNorm(*instruction) && - !IsCustomCallToDnnConvolution(*instruction); -} - Status GpuLayoutAssignment::PropagateOperandConstraint( const OperandLayoutConstraint& layout_constraint, LayoutConstraints* constraints) { diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h index 4ba7989e9c..6a48e55fd2 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h +++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.h @@ -46,8 +46,6 @@ class GpuLayoutAssignment : public LayoutAssignment { Status PropagateBufferConstraint( const BufferLayoutConstraint& buffer_constraint, LayoutConstraints* constraints) override; - bool CustomCallRequiresMajorFirstLayout( - const HloInstruction* instruction) override; private: Status AddBackendConstraintsToDnnConvCustomCall( diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto index 1ea26ddd5b..a0eb9e6ddc 100644 --- a/tensorflow/compiler/xla/service/hlo.proto +++ b/tensorflow/compiler/xla/service/hlo.proto @@ -34,7 +34,7 @@ import "tensorflow/compiler/xla/xla_data.proto"; option cc_enable_arenas = true; // Serialization of HloInstruction. -// Next ID: 56 +// Next ID: 58 message HloInstructionProto { reserved 10; reserved "parameter_name"; @@ -184,6 +184,13 @@ message HloInstructionProto { // Sharding for kDomain instructions. xla.OpSharding domain_entry_sharding = 54; xla.OpSharding domain_exit_sharding = 55; + + // For custom call this indicates that the layouts are constrained. If + // constrain_layout is true then the 'shape' field must contain a layout, and + // 'operand_shapes_with_layout' must contain a shape with layout for each + // operand. + bool constrain_layout = 56; + repeated Shape operand_shapes_with_layout = 57; } // Serialization of HloComputation. diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 2f6db7cd7c..5c3908a9a4 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -396,9 +396,22 @@ StatusOr> HloInstruction::CreateFromProto( operands(1), operands(2), computations(1)); break; case HloOpcode::kCustomCall: - instruction = CreateCustomCall(proto.shape(), all_operands(), - proto.custom_call_target(), - proto.custom_call_opaque()); + if (proto.constrain_layout()) { + // A proto RepeatedPtrField cannot be converted to a Span (it is a + // vector of pointers essentially) so create a vector of shapes to pass + // in. 
+        std::vector<Shape> operand_shapes;
+        for (const Shape& shape : proto.operand_shapes_with_layout()) {
+          operand_shapes.push_back(shape);
+        }
+        instruction = CreateCustomCall(
+            proto.shape(), all_operands(), proto.custom_call_target(),
+            operand_shapes, proto.custom_call_opaque());
+      } else {
+        instruction = CreateCustomCall(proto.shape(), all_operands(),
+                                       proto.custom_call_target(),
+                                       proto.custom_call_opaque());
+      }
       if (proto.has_window()) {
         static_cast<HloCustomCallInstruction*>(instruction.get())
             ->set_window(proto.window());
@@ -1142,6 +1155,15 @@ bool HloInstruction::HasSideEffect() const {
       shape, operands, custom_call_target, opaque);
 }
 
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateCustomCall(
+    const Shape& shape, absl::Span<HloInstruction* const> operands,
+    absl::string_view custom_call_target,
+    absl::Span<const Shape> operand_shapes_with_layout,
+    absl::string_view opaque) {
+  return absl::make_unique<HloCustomCallInstruction>(
+      shape, operands, custom_call_target, opaque, operand_shapes_with_layout);
+}
+
 /* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateTuple(
     absl::Span<HloInstruction* const> elements) {
   std::vector<Shape> element_shapes;
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 374862c4b6..44f776ebac 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -734,6 +734,16 @@ class HloInstruction {
       const Shape& shape, absl::Span<HloInstruction* const> operands,
       absl::string_view custom_call_target, absl::string_view opaque = "");
 
+  // Overload which constrains the layouts of the operands and result. 'shape'
+  // and 'operand_shapes_with_layout' must all have layouts.
+  // 'operand_shapes_with_layout' must have a compatible element for each
+  // operand.
+  static std::unique_ptr<HloInstruction> CreateCustomCall(
+      const Shape& shape, absl::Span<HloInstruction* const> operands,
+      absl::string_view custom_call_target,
+      absl::Span<const Shape> operand_shapes_with_layout,
+      absl::string_view opaque = "");
+
   // Creates a tuple instruction with the given elements. This is a convenience
   // wrapper around CreateVariadic.
static std::unique_ptr CreateTuple( diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc index 152d8eacdb..2ec233eaec 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.cc +++ b/tensorflow/compiler/xla/service/hlo_instructions.cc @@ -1825,7 +1825,24 @@ HloCustomCallInstruction::HloCustomCallInstruction( : HloInstruction(HloOpcode::kCustomCall, shape), custom_call_target_(custom_call_target.begin(), custom_call_target.end()), opaque_(opaque.begin(), opaque.end()), - feature_group_count_(1) { + feature_group_count_(1), + layout_constrained_(false) { + for (auto operand : operands) { + AppendOperand(operand); + } +} + +HloCustomCallInstruction::HloCustomCallInstruction( + const Shape& shape, absl::Span operands, + absl::string_view custom_call_target, absl::string_view opaque, + absl::Span operand_shapes_with_layout) + : HloInstruction(HloOpcode::kCustomCall, shape), + custom_call_target_(custom_call_target.begin(), custom_call_target.end()), + opaque_(opaque.begin(), opaque.end()), + feature_group_count_(1), + layout_constrained_(true), + operand_shapes_with_layout_(operand_shapes_with_layout.begin(), + operand_shapes_with_layout.end()) { for (auto operand : operands) { AppendOperand(operand); } @@ -1843,6 +1860,12 @@ HloInstructionProto HloCustomCallInstruction::ToProto() const { proto.set_custom_call_target(custom_call_target_); proto.set_custom_call_opaque(opaque_); proto.set_feature_group_count(feature_group_count_); + if (layout_constrained()) { + proto.set_constrain_layout(true); + for (const Shape& shape : operand_shapes_with_layout_) { + *proto.add_operand_shapes_with_layout() = shape; + } + } return proto; } @@ -1870,6 +1893,14 @@ std::vector HloCustomCallInstruction::ExtraAttributesToStringImpl( if (!opaque_.empty()) { extra.push_back(StrCat("opaque=\"", CEscape(opaque_), "\"")); } + if (layout_constrained()) { + std::vector shape_strings; + for (const Shape& shape : operand_shapes_with_layout_) { + shape_strings.push_back(ShapeUtil::HumanStringWithLayout(shape)); + } + extra.push_back(StrCat("operand_layout_constraints={", + StrJoin(shape_strings, ", "), "}")); + } return extra; } diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h index e169604072..4c5fc759a3 100644 --- a/tensorflow/compiler/xla/service/hlo_instructions.h +++ b/tensorflow/compiler/xla/service/hlo_instructions.h @@ -1053,10 +1053,19 @@ class HloSelectAndScatterInstruction : public HloInstruction { class HloCustomCallInstruction : public HloInstruction { public: - explicit HloCustomCallInstruction(const Shape& shape, - absl::Span operands, - absl::string_view custom_call_target, - absl::string_view opaque); + HloCustomCallInstruction(const Shape& shape, + absl::Span operands, + absl::string_view custom_call_target, + absl::string_view opaque); + + // Constructor for a custom call with constrained layout. 'shape' and + // 'operands_with_layout' must all have layouts. + HloCustomCallInstruction(const Shape& shape, + absl::Span operands, + absl::string_view custom_call_target, + absl::string_view opaque, + absl::Span operand_shapes_with_layout); + const Window& window() const override { CHECK(window_ != nullptr); return *window_; @@ -1085,6 +1094,16 @@ class HloCustomCallInstruction : public HloInstruction { // Returns a serialized representation of this instruction. 
HloInstructionProto ToProto() const override; + // Returns whether the result and operand layouts are constrained. + bool layout_constrained() const { return layout_constrained_; } + + // Returns the shapes (with layout) of the operands. CHECKs if this custom + // call does not have constrained layouts. + const std::vector& operand_shapes_with_layout() const { + CHECK(layout_constrained()); + return operand_shapes_with_layout_; + } + private: std::vector ExtraAttributesToStringImpl( const HloPrintOptions& options) const override; @@ -1106,6 +1125,11 @@ class HloCustomCallInstruction : public HloInstruction { std::unique_ptr convolution_dimension_numbers_; // The number of feature groups. This is used for grouped convolutions. int64 feature_group_count_; + // Whether the result and operand layouts are constrained. + bool layout_constrained_; + // For layout-constrained custom calls, this vector holds the shape with + // layout for each operand. + std::vector operand_shapes_with_layout_; }; class HloPadInstruction : public HloInstruction { diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index dd62988bcc..96f9ff6654 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -174,6 +174,7 @@ class HloParser { kDistribution, kDomain, kPrecisionList, + kShapeList }; struct AttrConfig { @@ -240,6 +241,7 @@ class HloParser { bool ParseSliceRanges(SliceRanges* result); bool ParsePrecisionList(std::vector* result); + bool ParseShapeList(std::vector* result); bool ParseInt64List(const TokKind start, const TokKind end, const TokKind delim, std::vector* result); @@ -1341,6 +1343,7 @@ bool HloParser::ParseInstruciontRhs(HloComputation::Builder* builder, optional window; optional dnums; optional feature_group_count; + optional> operand_layout_constraints; attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString, &custom_call_target}; attrs["opaque"] = {/*required=*/false, AttrTy::kString, &opaque}; @@ -1349,12 +1352,52 @@ bool HloParser::ParseInstruciontRhs(HloComputation::Builder* builder, AttrTy::kConvolutionDimensionNumbers, &dnums}; attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64, &feature_group_count}; + attrs["operand_layout_constraints"] = { + /*required=*/false, AttrTy::kShapeList, &operand_layout_constraints}; if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } - instruction = builder->AddInstruction( - HloInstruction::CreateCustomCall(shape, operands, *custom_call_target, - opaque.has_value() ? 
*opaque : "")); + if (operand_layout_constraints.has_value()) { + if (!LayoutUtil::HasLayout(shape)) { + return Error(lexer_.GetLoc(), + "Layout must be set on layout-constrained custom call"); + } + if (operands.size() != operand_layout_constraints->size()) { + return Error(lexer_.GetLoc(), + StrCat("Expected ", operands.size(), + " operand layout constraints, ", + operand_layout_constraints->size(), " given")); + } + for (int64 i = 0; i < operands.size(); ++i) { + const Shape& operand_shape_with_layout = + (*operand_layout_constraints)[i]; + if (!LayoutUtil::HasLayout(operand_shape_with_layout)) { + return Error(lexer_.GetLoc(), + StrCat("Operand layout constraint shape ", + ShapeUtil::HumanStringWithLayout( + operand_shape_with_layout), + " for operand ", i, " does not have a layout")); + } + if (!ShapeUtil::Compatible(operand_shape_with_layout, + operands[i]->shape())) { + return Error( + lexer_.GetLoc(), + StrCat( + "Operand layout constraint shape ", + ShapeUtil::HumanStringWithLayout(operand_shape_with_layout), + " for operand ", i, + " is not compatible with operand shape ", + ShapeUtil::HumanStringWithLayout(operands[i]->shape()))); + } + } + instruction = builder->AddInstruction(HloInstruction::CreateCustomCall( + shape, operands, *custom_call_target, *operand_layout_constraints, + opaque.has_value() ? *opaque : "")); + } else { + instruction = builder->AddInstruction(HloInstruction::CreateCustomCall( + shape, operands, *custom_call_target, + opaque.has_value() ? *opaque : "")); + } if (window.has_value()) { instruction->set_window(*window); } @@ -2533,6 +2576,15 @@ bool HloParser::ParseAttributeHelper( ->emplace(result); return true; } + case AttrTy::kShapeList: { + std::vector result; + if (!ParseShapeList(&result)) { + return false; + } + static_cast>*>(attr_out_ptr) + ->emplace(result); + return true; + } } }(); if (!success) { @@ -2825,6 +2877,23 @@ bool HloParser::ParsePrecisionList( parse_and_add_item); } +// shapelist ::= '{' shapes '}' +// precision_elements +// ::= /*empty*/ +// ::= shape (',' shape)* +bool HloParser::ParseShapeList(std::vector* result) { + auto parse_and_add_item = [&]() { + Shape shape; + if (!ParseShape(&shape)) { + return false; + } + result->push_back(std::move(shape)); + return true; + }; + return ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma, + parse_and_add_item); +} + // int64list ::= start int64_elements end // int64_elements // ::= /*empty*/ @@ -2832,23 +2901,15 @@ bool HloParser::ParsePrecisionList( bool HloParser::ParseInt64List(const TokKind start, const TokKind end, const TokKind delim, std::vector* result) { - if (!ParseToken(start, StrCat("expects an int64 list starting with ", - TokKindToString(start)))) { - return false; - } - if (lexer_.GetKind() == end) { - // empty - } else { - do { - tensorflow::int64 i; - if (!ParseInt64(&i)) { - return false; - } - result->push_back(i); - } while (EatIfPresent(delim)); - } - return ParseToken( - end, StrCat("expects an int64 list to end with ", TokKindToString(end))); + auto parse_and_add_item = [&]() { + tensorflow::int64 i; + if (!ParseInt64(&i)) { + return false; + } + result->push_back(i); + return true; + }; + return ParseList(start, end, delim, parse_and_add_item); } bool HloParser::ParseList(const TokKind start, const TokKind end, diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index 255123d331..17538c05bc 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ 
b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -802,6 +802,43 @@ ENTRY %ConstantUnsignedNoOverflow () -> u64[] { ROOT %constant = u64[] constant(9223372036854775807) } +)" +}, +// CustomCallWithLayoutConstraints +{ +"CustomCallWithLayoutConstraints", +R"(HloModule CustomCallWithLayoutConstraints + +ENTRY %CustomCallWithLayoutConstraints (p0: f32[42,2,3], p1: f32[123,4]) -> f32[1,2,3] { + %p0 = f32[42,2,3]{0,1,2} parameter(0) + %p1 = f32[123,4]{0,1} parameter(1) + ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(f32[42,2,3]{0,1,2} %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={f32[42,2,3]{0,1,2}, f32[123,4]{1,0}} +} + +)" +}, +// CustomCallWithLayoutConstraintsNoOperands +{ +"CustomCallWithLayoutConstraintsNoOperands", +R"(HloModule CustomCallWithLayoutConstraintsNoOperands + +ENTRY %CustomCallWithLayoutConstraints () -> f32[1,2,3] { + ROOT %custom-call = f32[1,2,3]{0,2,1} custom-call(), custom_call_target="baz", operand_layout_constraints={} +} + +)" +}, +// CustomCallWithLayoutConstraintsTupleShapes +{ +"CustomCallWithLayoutConstraintsTupleShapes", +R"(HloModule CustomCallWithLayoutConstraintsTupleShapes + +ENTRY %CustomCallWithLayoutConstraints (p0: (f32[2,2], f32[42,2,3]), p1: f32[123,4]) -> (f32[1,2,3], f32[1,2,3]) { + %p0 = (f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) parameter(0) + %p1 = f32[123,4]{0,1} parameter(1) + ROOT %custom-call = (f32[1,2,3]{0,2,1}, f32[1,2,3]{1,2,0}) custom-call((f32[2,2]{0,1}, f32[42,2,3]{0,1,2}) %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={(f32[2,2]{1,0}, f32[42,2,3]{2,0,1}), f32[123,4]{1,0}} +} + )" }, }); @@ -2069,5 +2106,35 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { op::Broadcast(), op::Multiply(), op::Add())); } +TEST_F(HloParserTest, CustomCallWrongNumberofOperandConstraints) { + const string original = R"(HloModule CustomCallWrongNumberofOperandConstraints + +ENTRY %CustomCallWrongNumberofOperandConstraints (p0: f32[42,2,3], p1: f32[123,4]) -> f32[1,2,3] { + %p0 = f32[42,2,3]{0,1,2} parameter(0) + %p1 = f32[123,4]{0,1} parameter(1) + ROOT %custom-call = f32[1,2,3]{0,1,2} custom-call(f32[42,2,3]{0,1,2} %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={f32[42,2,3]{0,1,2}} +} + +)"; + ExpectHasSubstr(ParseHloString(original).status().error_message(), + "Expected 2 operand layout constraints, 1 given"); +} + +TEST_F(HloParserTest, CustomCallIncompatibleOperandConstraints) { + const string original = R"(HloModule CustomCallIncompatibleOperandConstraints + +ENTRY %CustomCallIncompatibleOperandConstraints (p0: f32[42,2,3], p1: f32[123,4]) -> f32[1,2,3] { + %p0 = f32[42,2,3]{0,1,2} parameter(0) + %p1 = f32[123,4]{0,1} parameter(1) + ROOT %custom-call = f32[1,2,3]{0,1,2} custom-call(f32[42,2,3]{0,1,2} %p0, f32[123,4]{0,1} %p1), custom_call_target="baz", operand_layout_constraints={f32[42,2,3]{0,1,2}, f32[555,5]{1,0}} +} + +)"; + ExpectHasSubstr(ParseHloString(original).status().error_message(), + "operand 1 is not compatible with operand shape"); +} + +// custom call incompatible shape. 
+ } // namespace } // namespace xla diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc index 496fe1795d..be3bee5975 100644 --- a/tensorflow/compiler/xla/service/hlo_verifier.cc +++ b/tensorflow/compiler/xla/service/hlo_verifier.cc @@ -360,7 +360,27 @@ Status ShapeVerifier::HandleCall(HloInstruction* call) { return CheckShape(call, call->to_apply()->root_instruction()->shape()); } -Status ShapeVerifier::HandleCustomCall(HloInstruction*) { return Status::OK(); } +Status ShapeVerifier::HandleCustomCall(HloInstruction* instruction) { + const HloCustomCallInstruction* custom_call = + DynCast(instruction); + TF_RET_CHECK(custom_call != nullptr); + if (custom_call->layout_constrained()) { + // If the layout is constrained, verify all the respective shapes have + // layouts and that the constrained operand shapes match the shapes of the + // operands. + TF_RET_CHECK(LayoutUtil::HasLayout(custom_call->shape())); + TF_RET_CHECK(custom_call->operand_count() == + custom_call->operand_shapes_with_layout().size()); + for (int64 i = 0; i < custom_call->operand_count(); ++i) { + const Shape& operand_shape_with_layout = + custom_call->operand_shapes_with_layout()[i]; + TF_RET_CHECK(ShapeUtil::Compatible(custom_call->operand(i)->shape(), + operand_shape_with_layout)); + TF_RET_CHECK(LayoutUtil::HasLayout(operand_shape_with_layout)); + } + } + return Status::OK(); +} Status ShapeVerifier::HandleSlice(HloInstruction* slice) { return CheckShape(slice, diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc index cc4a342e9d..ad65b147c1 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.cc +++ b/tensorflow/compiler/xla/service/layout_assignment.cc @@ -419,6 +419,16 @@ Status LayoutAssignment::BuildHostChannelConstraints( return Status::OK(); } +namespace { + +bool IsLayoutConstrainedCustomCall(HloInstruction* instruction) { + const HloCustomCallInstruction* custom_call = + DynCast(instruction); + return custom_call != nullptr && custom_call->layout_constrained(); +} + +} // namespace + Status LayoutAssignment::AddMandatoryConstraints( const ComputationLayout* computation_layout, ChannelLayoutConstraints* channel_constraints, HloComputation* computation, @@ -434,7 +444,6 @@ Status LayoutAssignment::AddMandatoryConstraints( // Constrain layouts of instructions which define values with pre-existing // layouts. for (auto* instruction : computation->instructions()) { - Shape const* shape_with_layout = nullptr; if (instruction->opcode() == HloOpcode::kInfeed) { // Infeed layouts must match the layout of the original inserted // instruction. @@ -456,17 +465,21 @@ Status LayoutAssignment::AddMandatoryConstraints( if (parameter_layout.LayoutIsSet()) { // Parameter layouts must match the respective layout in // ComputationLayout, if there is one. 
- shape_with_layout = ¶meter_layout.shape(); + TF_RETURN_IF_ERROR(constraints->SetInstructionLayout( + parameter_layout.shape(), instruction)); } } - } - if (shape_with_layout != nullptr) { + } else if (IsLayoutConstrainedCustomCall(instruction)) { + const HloCustomCallInstruction* custom_call = + DynCast(instruction); TF_RETURN_IF_ERROR( - constraints->SetInstructionLayout(*shape_with_layout, instruction)); - } - - if (instruction->opcode() == HloOpcode::kSend || - instruction->opcode() == HloOpcode::kRecv) { + constraints->SetInstructionLayout(custom_call->shape(), custom_call)); + for (int64 i = 0; i < custom_call->operand_count(); ++i) { + TF_RETURN_IF_ERROR(constraints->SetOperandLayout( + custom_call->operand_shapes_with_layout()[i], custom_call, i)); + } + } else if (instruction->opcode() == HloOpcode::kSend || + instruction->opcode() == HloOpcode::kRecv) { CHECK(get_channel_constraints(instruction)) << "Multi-module layout assignment requires ChannelLayoutConstraints"; int64 channel_id = instruction->channel_id(); @@ -621,31 +634,6 @@ Status LayoutAssignment::AddMandatoryConstraints( TF_RETURN_IF_ERROR(constraints->SetOperandLayout( false_computation_layout.parameter_shape(0), instruction, 2, /*mandatory=*/true)); - } else if (instruction->opcode() == HloOpcode::kCustomCall) { - if (!CustomCallRequiresMajorFirstLayout(instruction)) { - continue; - } - // Add constraints for kCustomCall instruction operands and instructions. - // For now we only support major-first layouts for all inputs and outputs. - Shape result_shape = ShapeUtil::MakeShapeWithDescendingLayout( - instruction->shape().element_type(), - AsInt64Slice(instruction->shape().dimensions())); - TF_RETURN_IF_ERROR( - constraints->SetInstructionLayout(result_shape, instruction)); - for (int64 i = 0; i < instruction->operand_count(); ++i) { - const Shape& operand_shape = instruction->operand(i)->shape(); - // Opaque operands don't get a layout constraint. - if (ShapeUtil::IsOpaque(operand_shape)) { - continue; - } - - Shape row_major_operand_shape = - ShapeUtil::MakeShapeWithDescendingLayout( - operand_shape.element_type(), - AsInt64Slice(operand_shape.dimensions())); - TF_RETURN_IF_ERROR(constraints->SetOperandLayout( - row_major_operand_shape, instruction, i)); - } } } // Finally set the result layout to match ComputationLayout, if there is one. @@ -676,16 +664,18 @@ Status CheckCallLayout(HloInstruction* call, return Status::OK(); } -// Custom calls have fixed input and output layouts. -Status CheckCustomCallLayout(HloInstruction* custom_call) { - for (const HloInstruction* operand : custom_call->operands()) { - TF_RET_CHECK( - ShapeUtil::IsOpaque(operand->shape()) || - LayoutUtil::IsMonotonicWithDim0Major(operand->shape().layout())); +// Operands of layout-constrained custom calls must match the expected +// constrained layouts. 
+Status CheckCustomCallLayout(HloInstruction* instruction) { + if (IsLayoutConstrainedCustomCall(instruction)) { + const HloCustomCallInstruction* custom_call = + DynCast(instruction); + for (int64 i = 0; i < custom_call->operand_count(); ++i) { + TF_RET_CHECK(LayoutUtil::LayoutsInShapesEqual( + custom_call->operand(i)->shape(), + custom_call->operand_shapes_with_layout()[i])); + } } - TF_RET_CHECK( - ShapeUtil::IsOpaque(custom_call->shape()) || - LayoutUtil::IsMonotonicWithDim0Major(custom_call->shape().layout())); return Status::OK(); } @@ -932,9 +922,7 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) { FindOrDie(computation_layouts_, instruction->to_apply()))); break; case HloOpcode::kCustomCall: - if (CustomCallRequiresMajorFirstLayout(instruction)) { - TF_RETURN_IF_ERROR(CheckCustomCallLayout(instruction)); - } + TF_RETURN_IF_ERROR(CheckCustomCallLayout(instruction)); break; case HloOpcode::kFusion: TF_RETURN_IF_ERROR(CheckFusionLayout(instruction)); @@ -1554,11 +1542,11 @@ Status LayoutAssignment::CalculateComputationLayout( Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) { // Clear existing layouts of the instructions. All layouts must be assigned - // by the LayoutAssignment pass, except for those on infeeds, parameters, - // and the computation result. The latter two are specified in - // computation_layout, so we only need to keep the existing layouts for - // infeeds. Clearing the layouts here avoids hiding potential bugs in the - // layout assignment pass that may accidentally use the existing layout. + // by the LayoutAssignment pass, except for those on parameters, the + // computation result, and a couple special cases. The former two are + // specified in computation_layout. Clearing the layouts here avoids hiding + // potential bugs in the layout assignment pass that may accidentally use the + // existing layout. for (HloInstruction* instruction : computation->instructions()) { if (instruction->opcode() == HloOpcode::kBitcast) { // bitcasts are inherently layout sensitive and so a bitcast instruction @@ -1567,7 +1555,9 @@ Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) { "Unexpected bitcast operation seen during layout assignment: %s.", instruction->ToString()); } - if (instruction->opcode() != HloOpcode::kInfeed) { + // Some instructions carry mandatory layouts in their shape. + if (instruction->opcode() != HloOpcode::kInfeed && + !IsLayoutConstrainedCustomCall(instruction)) { LayoutUtil::ClearLayout(instruction->mutable_shape()); } } @@ -1802,6 +1792,18 @@ StatusOr LayoutAssignment::Run(HloModule* module) { } TF_RETURN_IF_ERROR(Init()); + // Verify computation layout is sane. + const HloComputation* entry = module->entry_computation(); + TF_RET_CHECK(entry_computation_layout_->parameter_count() == + entry->num_parameters()); + for (int64 i = 0; i < entry->num_parameters(); ++i) { + TF_RET_CHECK( + ShapeUtil::Compatible(entry_computation_layout_->parameter_shape(i), + entry->parameter_instruction(i)->shape())); + } + TF_RET_CHECK(ShapeUtil::Compatible(entry_computation_layout_->result_shape(), + entry->root_instruction()->shape())); + // We do two passes. 
The first one we pass a nullptr ComputationLayout to // the RunOnComputation() calls (for non entry computations), and we register // the ComputationLayout which are naturally flowing in DFS fashion to the @@ -1873,7 +1875,6 @@ bool LayoutAssignment::InstructionCanChangeLayout( case HloOpcode::kCrossReplicaSum: case HloOpcode::kAllToAll: case HloOpcode::kCollectivePermute: - case HloOpcode::kCustomCall: case HloOpcode::kDivide: case HloOpcode::kDynamicSlice: case HloOpcode::kDynamicUpdateSlice: @@ -1930,6 +1931,7 @@ bool LayoutAssignment::InstructionCanChangeLayout( case HloOpcode::kConstant: case HloOpcode::kConvolution: case HloOpcode::kCopy: + case HloOpcode::kCustomCall: case HloOpcode::kDomain: case HloOpcode::kDot: case HloOpcode::kFusion: diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h index 2d48e12263..cb56f4cd19 100644 --- a/tensorflow/compiler/xla/service/layout_assignment.h +++ b/tensorflow/compiler/xla/service/layout_assignment.h @@ -333,19 +333,6 @@ class LayoutAssignment : public HloModulePass { const ResultLayoutConstraint& layout_constraint, LayoutConstraints* constraints); - // By default LayoutAssignment ensures that inputs and outputs of CustomCalls - // have the "major-first" layout (i.e. {n, n-1, ..., 0}). - // - // If this function returns true, LayoutAssignment does not set a layout for - // the given CustomCall. It's up to the backend to set one in - // AddBackendConstraints, if necessary. - // - // Precondition: instruction->opcode() == HloOpcode::kCustomCall. - virtual bool CustomCallRequiresMajorFirstLayout( - const HloInstruction* /*instruction*/) { - return true; - } - // Called after layouts of an instruction have been finalized to allow // subclasses to check for platform specific assumptions. 
virtual Status Verify(const HloInstruction* instruction) { diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc index 2c549cd872..ff6fdb5e4a 100644 --- a/tensorflow/compiler/xla/service/layout_assignment_test.cc +++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc @@ -65,6 +65,27 @@ class LayoutAssignmentTest : public HloVerifiedTestBase { FindInstruction(module, name)->shape().layout().minor_to_major(); return std::vector(minor_to_major.begin(), minor_to_major.end()); } + + void ExpectLayoutIs(const Shape& shape, + absl::Span minor_to_major) { + const Layout expected = LayoutUtil::MakeLayout(minor_to_major); + EXPECT_TRUE(LayoutUtil::Equal(shape.layout(), expected)) + << "Expected layout " << expected << ", actual " << shape.layout(); + } + + void ExpectTupleLayoutIs( + const Shape& shape, + std::initializer_list> minor_to_majors) { + int i = 0; + for (const absl::Span minor_to_major : minor_to_majors) { + const Layout expected = LayoutUtil::MakeLayout(minor_to_major); + const Layout& actual = ShapeUtil::GetTupleElementShape(shape, i).layout(); + EXPECT_TRUE(LayoutUtil::Equal(actual, expected)) + << "Expected tuple element " << i << " layout " << expected + << ", actual " << actual; + ++i; + } + } }; TEST_F(LayoutAssignmentTest, ComputationLayout) { @@ -1102,5 +1123,174 @@ TEST_F(LayoutAssignmentTest, TupleCopyOnLayoutMismatch) { EXPECT_THAT(LayoutOf(&module(), "next_buf"), ElementsAre(1, 0)); } +TEST_F(LayoutAssignmentTest, CustomCallNotLayoutConstrained) { + const char* module_str = R"( +HloModule CustomCallNotLayoutConstrained + +ENTRY %CustomCallWithNotLayoutConstrained (p: f32[42,2,3]) -> f32[1,2,3,4] { + %p = f32[42,2,3] parameter(0) + ROOT %custom-call = f32[1,2,3,4] custom-call(f32[42,2,3] %p), custom_call_target="baz" +} +)"; + // Try with a couple different layouts. In each case the custom calls operand + // and result layout should match that of the computation. 
+ { + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())); + ComputationLayout computation_layout = module->entry_computation_layout(); + *computation_layout.mutable_parameter_layout(0) = + ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {42, 2, 3}, {0, 2, 1})); + *computation_layout.mutable_result_layout() = ShapeLayout( + ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {3, 2, 0, 1})); + AssignLayouts(module.get(), &computation_layout); + + HloInstruction* root = module->entry_computation()->root_instruction(); + ASSERT_THAT(root, op::CustomCall(op::Parameter())); + ExpectLayoutIs(root->shape(), {3, 2, 0, 1}); + ExpectLayoutIs(root->operand(0)->shape(), {0, 2, 1}); + } + { + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())); + ComputationLayout computation_layout = module->entry_computation_layout(); + *computation_layout.mutable_parameter_layout(0) = + ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {42, 2, 3}, {0, 1, 2})); + *computation_layout.mutable_result_layout() = ShapeLayout( + ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {0, 2, 3, 1})); + AssignLayouts(module.get(), &computation_layout); + + HloInstruction* root = module->entry_computation()->root_instruction(); + ASSERT_THAT(root, op::CustomCall(op::Parameter())); + ExpectLayoutIs(root->shape(), {0, 2, 3, 1}); + ExpectLayoutIs(root->operand(0)->shape(), {0, 1, 2}); + } +} + +TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrained) { + const char* module_str = R"( +HloModule CustomCallLayoutConstrained + +ENTRY %CustomCallWithLayoutConstraints (p0: f32[4,4], p1: f32[2,3]) -> f32[1,2,3,4] { + %p0 = f32[4,4] parameter(0) + %p1 = f32[2,3] parameter(1) + ROOT %custom-call = f32[1,2,3,4]{3,2,0,1} custom-call(f32[4,4] %p0, f32[2,3] %p1), custom_call_target="baz", operand_layout_constraints={f32[4,4]{0,1}, f32[2,3]{1,0}} +} +)"; + TF_ASSERT_OK_AND_ASSIGN( + std::unique_ptr module, + ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest())); + ComputationLayout computation_layout = module->entry_computation_layout(); + *computation_layout.mutable_parameter_layout(0) = + ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0})); + *computation_layout.mutable_parameter_layout(1) = + ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0})); + *computation_layout.mutable_result_layout() = ShapeLayout( + ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3})); + AssignLayouts(module.get(), &computation_layout); + + // The custom call should be partially encapsulated in kCopy instructions + // because of the layout mismatches. 
+  ASSERT_THAT(module->entry_computation()->root_instruction(),
+              op::Copy(op::CustomCall(op::Copy(), op::Parameter())));
+
+  const HloInstruction* custom_call =
+      module->entry_computation()->root_instruction()->operand(0);
+  ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1});
+  ExpectLayoutIs(custom_call->operand(0)->shape(), {0, 1});
+  ExpectLayoutIs(custom_call->operand(1)->shape(), {1, 0});
+}
+
+TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedZeroOperands) {
+  const char* module_str = R"(
+HloModule CustomCallLayoutConstrainedZeroOperands
+
+ENTRY %CustomCallLayoutConstrainedZeroOperands () -> f32[1,2,3,4] {
+  ROOT %custom-call = f32[1,2,3,4]{3,2,0,1} custom-call(), custom_call_target="baz", operand_layout_constraints={}
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<HloModule> module,
+      ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
+  ComputationLayout computation_layout = module->entry_computation_layout();
+  *computation_layout.mutable_result_layout() = ShapeLayout(
+      ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3}));
+  AssignLayouts(module.get(), &computation_layout);
+
+  ASSERT_THAT(module->entry_computation()->root_instruction(),
+              op::Copy(op::CustomCall()));
+
+  const HloInstruction* custom_call =
+      module->entry_computation()->root_instruction()->operand(0);
+  ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1});
+}
+
+TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedTupleOperand) {
+  const char* module_str = R"(
+HloModule CustomCallLayoutConstrainedTupleOperand
+
+ENTRY %CustomCallLayoutConstrainedTupleOperand (p0: f32[4,4], p1: f32[2,3]) -> f32[1,2,3,4] {
+  %p0 = f32[4,4] parameter(0)
+  %p1 = f32[2,3] parameter(1)
+  %tuple = (f32[4,4], f32[2,3]) tuple(%p0, %p1)
+  ROOT %custom-call = f32[1,2,3,4]{3,2,0,1} custom-call(%tuple), custom_call_target="baz", operand_layout_constraints={(f32[4,4]{1,0}, f32[2,3]{0,1})}
+}
+)";
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<HloModule> module,
+      ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
+  ComputationLayout computation_layout = module->entry_computation_layout();
+  *computation_layout.mutable_parameter_layout(0) =
+      ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}));
+  *computation_layout.mutable_parameter_layout(1) =
+      ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0}));
+  *computation_layout.mutable_result_layout() = ShapeLayout(
+      ShapeUtil::MakeShapeWithLayout(F32, {1, 2, 3, 4}, {2, 1, 0, 3}));
+  AssignLayouts(module.get(), &computation_layout);
+
+  HloInstruction* root = module->entry_computation()->root_instruction();
+  ExpectLayoutIs(root->shape(), {2, 1, 0, 3});
+
+  ASSERT_THAT(module->entry_computation()->root_instruction(),
+              op::Copy(op::CustomCall(op::Tuple())));
+
+  const HloInstruction* custom_call =
+      module->entry_computation()->root_instruction()->operand(0);
+  ExpectLayoutIs(custom_call->shape(), {3, 2, 0, 1});
+  ExpectTupleLayoutIs(custom_call->operand(0)->shape(), {{1, 0}, {0, 1}});
+}
+
+TEST_F(LayoutAssignmentTest, CustomCallLayoutConstrainedTupleResult) {
+  const char* module_str = R"(
+HloModule CustomCallLayoutConstrainedTupleResult
+
+ENTRY %CustomCallLayoutConstrainedTupleResult (p0: f32[4,4]) -> (f32[4,4]{1,0}, f32[2,3]{0,1}) {
+  %p0 = f32[4,4] parameter(0)
+  ROOT %custom-call = (f32[4,4]{1,0}, f32[2,3]{0,1}) custom-call(%p0), custom_call_target="baz", operand_layout_constraints={f32[4,4]{1,0}}
+}
+)";
+  // The constrained tuple result layout ({1,0} and {0,1}) should be preserved
+  // on the custom call even though the computation requests a {1,0},{1,0}
+  // result; a copy reconciles the difference.
+  TF_ASSERT_OK_AND_ASSIGN(
+      std::unique_ptr<HloModule> module,
+      ParseAndReturnVerifiedModule(module_str, GetModuleConfigForTest()));
+  ComputationLayout computation_layout = module->entry_computation_layout();
+  *computation_layout.mutable_parameter_layout(0) =
+      ShapeLayout(ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}));
+  *computation_layout.mutable_result_layout() =
+      ShapeLayout(ShapeUtil::MakeTupleShape(
+          {ShapeUtil::MakeShapeWithLayout(F32, {4, 4}, {1, 0}),
+           ShapeUtil::MakeShapeWithLayout(F32, {2, 3}, {1, 0})}));
+  AssignLayouts(module.get(), &computation_layout);
+
+  ExpectTupleLayoutIs(module->entry_computation()->root_instruction()->shape(),
+                      {{1, 0}, {1, 0}});
+
+  const HloInstruction* custom_call =
+      FindInstruction(module.get(), "custom-call");
+  ExpectTupleLayoutIs(custom_call->shape(), {{1, 0}, {0, 1}});
+}
+
 }  // namespace
 }  // namespace xla
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index d244923532..7f0201942b 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -1645,7 +1645,7 @@ ShapeUtil::DimensionsUnmodifiedByReshape(const Shape& input_shape,
 }
 
 std::ostream& operator<<(std::ostream& out, const Shape& shape) {
-  out << ShapeUtil::HumanString(shape);
+  out << ShapeUtil::HumanStringWithLayout(shape);
   return out;
 }
 
diff --git a/tensorflow/compiler/xla/tests/custom_call_test.cc b/tensorflow/compiler/xla/tests/custom_call_test.cc
index a693fa3595..001490c6a8 100644
--- a/tensorflow/compiler/xla/tests/custom_call_test.cc
+++ b/tensorflow/compiler/xla/tests/custom_call_test.cc
@@ -105,8 +105,7 @@ XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(CustomCallR2F32Reduce)) {
   LiteralTestUtil::ExpectR0Near<float>(10.0f, result, error_spec_);
 }
 
-XLA_TEST_F(CustomCallTest,
-           DISABLED_ON_GPU(CustomCall_UsedInOtherComputations)) {
+XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(UsedInOtherComputations)) {
   auto module = CreateNewModule();
   auto b = HloComputation::Builder(TestName());
 
@@ -130,6 +129,53 @@ XLA_TEST_F(CustomCallTest,
       Array3D<float>{{{2, 3}, {4, 5}}, {{3, 4}, {5, 6}}}, result);
 }
 
+XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(InputAndOutputLayoutDiffer)) {
+  auto module = CreateNewModule();
+  auto b = HloComputation::Builder(TestName());
+
+  auto input =
+      b.AddInstruction(HloInstruction::CreateParameter(0, r2f32_, "p"));
+  b.AddInstruction(
+      HloInstruction::CreateCustomCall(r2f32_, {input}, "Add1ToValues"));
+
+  module->AddEntryComputation(b.Build());
+  ForceParameterLayout(module.get(), 0, LayoutUtil::MakeLayout({1, 0}));
+  ForceResultLayout(module.get(), LayoutUtil::MakeLayout({0, 1}));
+
+  Literal argument = LiteralUtil::CreateR2<float>({{1.f, 2.f}, {3.f, 4.f}});
+
+  // Note, the expected result is transposed! This is because the input and
+  // output layouts of the custom call differ and the called function just
+  // blindly adds one to each element.
+  Literal result = ExecuteAndTransfer(std::move(module), {&argument});
+  LiteralTestUtil::ExpectR2Equal<float>({{2.f, 4.f}, {3.f, 5.f}}, result);
+}
+
+XLA_TEST_F(CustomCallTest, DISABLED_ON_GPU(LayoutConstrained)) {
+  // The argument and result of the computation are set to different layouts,
+  // but the custom call is layout-constrained to a fixed operand and result
+  // layout, so the correct result should be produced.
+  auto module = CreateNewModule();
+  auto b = HloComputation::Builder(TestName());
+
+  auto input =
+      b.AddInstruction(HloInstruction::CreateParameter(0, r2f32_, "p"));
+
+  const Shape& r2f32_dim0_major =
+      ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0});
+  b.AddInstruction(HloInstruction::CreateCustomCall(
+      r2f32_dim0_major, {input}, "Add1ToValues", {r2f32_dim0_major}));
+
+  module->AddEntryComputation(b.Build());
+  ForceParameterLayout(module.get(), 0, LayoutUtil::MakeLayout({1, 0}));
+  ForceResultLayout(module.get(), LayoutUtil::MakeLayout({0, 1}));
+
+  Literal argument = LiteralUtil::CreateR2<float>({{1.f, 2.f}, {3.f, 4.f}});
+
+  Literal result = ExecuteAndTransfer(std::move(module), {&argument});
+  LiteralTestUtil::ExpectR2Equal<float>({{2.f, 3.f}, {4.f, 5.f}}, result);
+}
+
 class CustomCallClientAPITest : public ClientLibraryTestBase {};
 
 // When using the client API, CustomCall targets can't begin with '$' -- these
 // are reserved for internal use.
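
As promised in the commit message, here is a minimal usage sketch of the new
builder entry point. CustomCallWithLayout and ShapeUtil::MakeShapeWithLayout
come from this patch and the existing XLA client library; the dimensions and
the "my_custom_fn" target are hypothetical, invented purely for illustration
(a real target must still be resolvable by the backend at run time).

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Pin the operand to row-major ({1,0}) and the result to column-major
// ({0,1}). Layout assignment now inserts copies around the call as needed
// to satisfy these constraints, instead of forcing major-first layouts on
// every custom call as it did before this patch.
xla::XlaOp ConstrainedCall(xla::XlaBuilder* b, const xla::XlaOp& arg) {
  xla::Shape operand_shape = xla::ShapeUtil::MakeShapeWithLayout(
      xla::F32, /*dimensions=*/{2, 3}, /*minor_to_major=*/{1, 0});
  xla::Shape result_shape = xla::ShapeUtil::MakeShapeWithLayout(
      xla::F32, /*dimensions=*/{2, 3}, /*minor_to_major=*/{0, 1});
  // "my_custom_fn" is a hypothetical custom-call target name.
  return xla::CustomCallWithLayout(b, "my_custom_fn", {arg}, result_shape,
                                   /*operand_shapes_with_layout=*/
                                   {operand_shape});
}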
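
In HLO text form the same constraint round-trips through the
operand_layout_constraints attribute added to the parser above. The module
below is a hypothetical minimal example modeled directly on the parser tests
in this patch (the target name is again invented):

HloModule ConstrainedCallExample

ENTRY %ConstrainedCallExample (p: f32[2,3]) -> f32[2,3] {
  %p = f32[2,3]{1,0} parameter(0)
  ROOT %custom-call = f32[2,3]{0,1} custom-call(f32[2,3]{1,0} %p), custom_call_target="my_custom_fn", operand_layout_constraints={f32[2,3]{1,0}}
}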