diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_parser.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_parser.cc | 101 |
1 files changed, 81 insertions, 20 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index dd62988bcc..96f9ff6654 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -174,6 +174,7 @@ class HloParser { kDistribution, kDomain, kPrecisionList, + kShapeList }; struct AttrConfig { @@ -240,6 +241,7 @@ class HloParser { bool ParseSliceRanges(SliceRanges* result); bool ParsePrecisionList(std::vector<PrecisionConfig::Precision>* result); + bool ParseShapeList(std::vector<Shape>* result); bool ParseInt64List(const TokKind start, const TokKind end, const TokKind delim, std::vector<tensorflow::int64>* result); @@ -1341,6 +1343,7 @@ bool HloParser::ParseInstruciontRhs(HloComputation::Builder* builder, optional<Window> window; optional<ConvolutionDimensionNumbers> dnums; optional<int64> feature_group_count; + optional<std::vector<Shape>> operand_layout_constraints; attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString, &custom_call_target}; attrs["opaque"] = {/*required=*/false, AttrTy::kString, &opaque}; @@ -1349,12 +1352,52 @@ bool HloParser::ParseInstruciontRhs(HloComputation::Builder* builder, AttrTy::kConvolutionDimensionNumbers, &dnums}; attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64, &feature_group_count}; + attrs["operand_layout_constraints"] = { + /*required=*/false, AttrTy::kShapeList, &operand_layout_constraints}; if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { return false; } - instruction = builder->AddInstruction( - HloInstruction::CreateCustomCall(shape, operands, *custom_call_target, - opaque.has_value() ? *opaque : "")); + if (operand_layout_constraints.has_value()) { + if (!LayoutUtil::HasLayout(shape)) { + return Error(lexer_.GetLoc(), + "Layout must be set on layout-constrained custom call"); + } + if (operands.size() != operand_layout_constraints->size()) { + return Error(lexer_.GetLoc(), + StrCat("Expected ", operands.size(), + " operand layout constraints, ", + operand_layout_constraints->size(), " given")); + } + for (int64 i = 0; i < operands.size(); ++i) { + const Shape& operand_shape_with_layout = + (*operand_layout_constraints)[i]; + if (!LayoutUtil::HasLayout(operand_shape_with_layout)) { + return Error(lexer_.GetLoc(), + StrCat("Operand layout constraint shape ", + ShapeUtil::HumanStringWithLayout( + operand_shape_with_layout), + " for operand ", i, " does not have a layout")); + } + if (!ShapeUtil::Compatible(operand_shape_with_layout, + operands[i]->shape())) { + return Error( + lexer_.GetLoc(), + StrCat( + "Operand layout constraint shape ", + ShapeUtil::HumanStringWithLayout(operand_shape_with_layout), + " for operand ", i, + " is not compatible with operand shape ", + ShapeUtil::HumanStringWithLayout(operands[i]->shape()))); + } + } + instruction = builder->AddInstruction(HloInstruction::CreateCustomCall( + shape, operands, *custom_call_target, *operand_layout_constraints, + opaque.has_value() ? *opaque : "")); + } else { + instruction = builder->AddInstruction(HloInstruction::CreateCustomCall( + shape, operands, *custom_call_target, + opaque.has_value() ? *opaque : "")); + } if (window.has_value()) { instruction->set_window(*window); } @@ -2533,6 +2576,15 @@ bool HloParser::ParseAttributeHelper( ->emplace(result); return true; } + case AttrTy::kShapeList: { + std::vector<Shape> result; + if (!ParseShapeList(&result)) { + return false; + } + static_cast<optional<std::vector<Shape>>*>(attr_out_ptr) + ->emplace(result); + return true; + } } }(); if (!success) { @@ -2825,6 +2877,23 @@ bool HloParser::ParsePrecisionList( parse_and_add_item); } +// shapelist ::= '{' shapes '}' +// precision_elements +// ::= /*empty*/ +// ::= shape (',' shape)* +bool HloParser::ParseShapeList(std::vector<Shape>* result) { + auto parse_and_add_item = [&]() { + Shape shape; + if (!ParseShape(&shape)) { + return false; + } + result->push_back(std::move(shape)); + return true; + }; + return ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma, + parse_and_add_item); +} + // int64list ::= start int64_elements end // int64_elements // ::= /*empty*/ @@ -2832,23 +2901,15 @@ bool HloParser::ParsePrecisionList( bool HloParser::ParseInt64List(const TokKind start, const TokKind end, const TokKind delim, std::vector<tensorflow::int64>* result) { - if (!ParseToken(start, StrCat("expects an int64 list starting with ", - TokKindToString(start)))) { - return false; - } - if (lexer_.GetKind() == end) { - // empty - } else { - do { - tensorflow::int64 i; - if (!ParseInt64(&i)) { - return false; - } - result->push_back(i); - } while (EatIfPresent(delim)); - } - return ParseToken( - end, StrCat("expects an int64 list to end with ", TokKindToString(end))); + auto parse_and_add_item = [&]() { + tensorflow::int64 i; + if (!ParseInt64(&i)) { + return false; + } + result->push_back(i); + return true; + }; + return ParseList(start, end, delim, parse_and_add_item); } bool HloParser::ParseList(const TokKind start, const TokKind end, |