aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_parser.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_parser.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser.cc101
1 files changed, 81 insertions, 20 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index dd62988bcc..96f9ff6654 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -174,6 +174,7 @@ class HloParser {
kDistribution,
kDomain,
kPrecisionList,
+ kShapeList
};
struct AttrConfig {
@@ -240,6 +241,7 @@ class HloParser {
bool ParseSliceRanges(SliceRanges* result);
bool ParsePrecisionList(std::vector<PrecisionConfig::Precision>* result);
+ bool ParseShapeList(std::vector<Shape>* result);
bool ParseInt64List(const TokKind start, const TokKind end,
const TokKind delim,
std::vector<tensorflow::int64>* result);
@@ -1341,6 +1343,7 @@ bool HloParser::ParseInstruciontRhs(HloComputation::Builder* builder,
optional<Window> window;
optional<ConvolutionDimensionNumbers> dnums;
optional<int64> feature_group_count;
+ optional<std::vector<Shape>> operand_layout_constraints;
attrs["custom_call_target"] = {/*required=*/true, AttrTy::kString,
&custom_call_target};
attrs["opaque"] = {/*required=*/false, AttrTy::kString, &opaque};
@@ -1349,12 +1352,52 @@ bool HloParser::ParseInstruciontRhs(HloComputation::Builder* builder,
AttrTy::kConvolutionDimensionNumbers, &dnums};
attrs["feature_group_count"] = {/*required=*/false, AttrTy::kInt64,
&feature_group_count};
+ attrs["operand_layout_constraints"] = {
+ /*required=*/false, AttrTy::kShapeList, &operand_layout_constraints};
if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
return false;
}
- instruction = builder->AddInstruction(
- HloInstruction::CreateCustomCall(shape, operands, *custom_call_target,
- opaque.has_value() ? *opaque : ""));
+ if (operand_layout_constraints.has_value()) {
+ if (!LayoutUtil::HasLayout(shape)) {
+ return Error(lexer_.GetLoc(),
+ "Layout must be set on layout-constrained custom call");
+ }
+ if (operands.size() != operand_layout_constraints->size()) {
+ return Error(lexer_.GetLoc(),
+ StrCat("Expected ", operands.size(),
+ " operand layout constraints, ",
+ operand_layout_constraints->size(), " given"));
+ }
+ for (int64 i = 0; i < operands.size(); ++i) {
+ const Shape& operand_shape_with_layout =
+ (*operand_layout_constraints)[i];
+ if (!LayoutUtil::HasLayout(operand_shape_with_layout)) {
+ return Error(lexer_.GetLoc(),
+ StrCat("Operand layout constraint shape ",
+ ShapeUtil::HumanStringWithLayout(
+ operand_shape_with_layout),
+ " for operand ", i, " does not have a layout"));
+ }
+ if (!ShapeUtil::Compatible(operand_shape_with_layout,
+ operands[i]->shape())) {
+ return Error(
+ lexer_.GetLoc(),
+ StrCat(
+ "Operand layout constraint shape ",
+ ShapeUtil::HumanStringWithLayout(operand_shape_with_layout),
+ " for operand ", i,
+ " is not compatible with operand shape ",
+ ShapeUtil::HumanStringWithLayout(operands[i]->shape())));
+ }
+ }
+ instruction = builder->AddInstruction(HloInstruction::CreateCustomCall(
+ shape, operands, *custom_call_target, *operand_layout_constraints,
+ opaque.has_value() ? *opaque : ""));
+ } else {
+ instruction = builder->AddInstruction(HloInstruction::CreateCustomCall(
+ shape, operands, *custom_call_target,
+ opaque.has_value() ? *opaque : ""));
+ }
if (window.has_value()) {
instruction->set_window(*window);
}
@@ -2533,6 +2576,15 @@ bool HloParser::ParseAttributeHelper(
->emplace(result);
return true;
}
+ case AttrTy::kShapeList: {
+ std::vector<Shape> result;
+ if (!ParseShapeList(&result)) {
+ return false;
+ }
+ static_cast<optional<std::vector<Shape>>*>(attr_out_ptr)
+ ->emplace(result);
+ return true;
+ }
}
}();
if (!success) {
@@ -2825,6 +2877,23 @@ bool HloParser::ParsePrecisionList(
parse_and_add_item);
}
+// shapelist ::= '{' shapes '}'
+// precision_elements
+// ::= /*empty*/
+// ::= shape (',' shape)*
+bool HloParser::ParseShapeList(std::vector<Shape>* result) {
+ auto parse_and_add_item = [&]() {
+ Shape shape;
+ if (!ParseShape(&shape)) {
+ return false;
+ }
+ result->push_back(std::move(shape));
+ return true;
+ };
+ return ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
+ parse_and_add_item);
+}
+
// int64list ::= start int64_elements end
// int64_elements
// ::= /*empty*/
@@ -2832,23 +2901,15 @@ bool HloParser::ParsePrecisionList(
bool HloParser::ParseInt64List(const TokKind start, const TokKind end,
const TokKind delim,
std::vector<tensorflow::int64>* result) {
- if (!ParseToken(start, StrCat("expects an int64 list starting with ",
- TokKindToString(start)))) {
- return false;
- }
- if (lexer_.GetKind() == end) {
- // empty
- } else {
- do {
- tensorflow::int64 i;
- if (!ParseInt64(&i)) {
- return false;
- }
- result->push_back(i);
- } while (EatIfPresent(delim));
- }
- return ParseToken(
- end, StrCat("expects an int64 list to end with ", TokKindToString(end)));
+ auto parse_and_add_item = [&]() {
+ tensorflow::int64 i;
+ if (!ParseInt64(&i)) {
+ return false;
+ }
+ result->push_back(i);
+ return true;
+ };
+ return ParseList(start, end, delim, parse_and_add_item);
}
bool HloParser::ParseList(const TokKind start, const TokKind end,