aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_parser.cc
diff options
context:
space:
mode:
authorGravatar Avijit <Avijit.Chakraborty@intel.com>2018-08-12 16:21:41 -0700
committerGravatar Avijit <Avijit.Chakraborty@intel.com>2018-08-12 16:21:41 -0700
commit9523a98466d16cf01fc76a67b489f1124cf626ac (patch)
treebd4c460b67fab60c2fb1a6c56bf22d1cbb5391e6 /tensorflow/compiler/xla/service/hlo_parser.cc
parent93e950c308071071f35d6dcb35b9f91b8a34876c (diff)
parent1a22b0b982fa1a953651b98af8f3cd30542048fd (diff)
Merge remote-tracking branch 'upstream/master'
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_parser.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser.cc82
1 files changed, 69 insertions, 13 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index 93cc884e3a..2a8c6ecd92 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -125,6 +125,7 @@ class HloParser {
kFloat,
kString,
kBracedInt64List,
+ kBracedInt64ListList,
kHloComputation,
kFftType,
kWindow,
@@ -205,6 +206,10 @@ class HloParser {
bool ParseInt64List(const TokKind start, const TokKind end,
const TokKind delim,
std::vector<tensorflow::int64>* result);
+ // 'parse_and_add_item' is an lambda to parse an element in the list and add
+ // the parsed element to the result. It's supposed to capture the result.
+ bool ParseList(const TokKind start, const TokKind end, const TokKind delim,
+ const std::function<bool()>& parse_and_add_item);
bool ParseParamListToShape(Shape* shape, LocTy* shape_loc);
bool ParseParamList();
@@ -619,6 +624,28 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
}
break;
}
+ case HloOpcode::kAllToAll: {
+ optional<std::vector<std::vector<int64>>> tmp_groups;
+ optional<string> barrier;
+ attrs["replica_groups"] = {/*required=*/false,
+ AttrTy::kBracedInt64ListList, &tmp_groups};
+ attrs["barrier"] = {/*required=*/false, AttrTy::kString, &barrier};
+ if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
+ return false;
+ }
+ std::vector<ReplicaGroup> replica_groups;
+ if (tmp_groups) {
+ c_transform(*tmp_groups, std::back_inserter(replica_groups),
+ [](const std::vector<int64>& ids) {
+ ReplicaGroup group;
+ *group.mutable_replica_ids() = {ids.begin(), ids.end()};
+ return group;
+ });
+ }
+ instruction = builder->AddInstruction(HloInstruction::CreateAllToAll(
+ shape, operands, replica_groups, barrier ? *barrier : ""));
+ break;
+ }
case HloOpcode::kReshape: {
if (!ParseOperands(&operands, /*expected_size=*/1) ||
!ParseAttributes(attrs)) {
@@ -1383,7 +1410,6 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding,
bool replicated = false;
std::vector<tensorflow::int64> devices;
std::vector<tensorflow::int64> tile_assignment_dimensions;
- Shape tile_shape;
while (lexer_.GetKind() != TokKind::kRbrace) {
switch (lexer_.GetKind()) {
case TokKind::kw_maximal:
@@ -1434,7 +1460,8 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding,
break;
}
case TokKind::kShape:
- tile_shape = lexer_.GetShapeVal();
+ // TODO(b/112302613): Left here for backward compatibility to ignore the
+ // removed tile shape data.
lexer_.Lex();
break;
case TokKind::kRbrace:
@@ -1449,19 +1476,12 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding,
return Error(loc,
"replicated shardings should not have any devices assigned");
}
- if (!ShapeUtil::Equal(tile_shape, Shape())) {
- return Error(loc,
- "replicated shardings should not have any tile shape set");
- }
sharding->set_type(OpSharding::Type::OpSharding_Type_REPLICATED);
} else if (maximal) {
if (devices.size() != 1) {
return Error(loc,
"maximal shardings should have exactly one device assigned");
}
- if (!ShapeUtil::Equal(tile_shape, Shape())) {
- return Error(loc, "maximal shardings should not have any tile shape set");
- }
sharding->set_type(OpSharding::Type::OpSharding_Type_MAXIMAL);
sharding->add_tile_assignment_devices(devices[0]);
} else {
@@ -1469,9 +1489,6 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding,
return Error(
loc, "non-maximal shardings must have more than one device assigned");
}
- if (ShapeUtil::Equal(tile_shape, Shape())) {
- return Error(loc, "non-maximal shardings should have a tile shape set");
- }
if (tile_assignment_dimensions.empty()) {
return Error(
loc,
@@ -1479,7 +1496,6 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding,
"dimensions");
}
sharding->set_type(OpSharding::Type::OpSharding_Type_OTHER);
- *sharding->mutable_tile_shape() = tile_shape;
for (tensorflow::int64 dim : tile_assignment_dimensions) {
sharding->add_tile_assignment_dimensions(dim);
}
@@ -2255,6 +2271,26 @@ bool HloParser::ParseAttributeHelper(
->emplace(result);
return true;
}
+ case AttrTy::kBracedInt64ListList: {
+ std::vector<std::vector<tensorflow::int64>> result;
+ auto parse_and_add_item = [&]() {
+ std::vector<tensorflow::int64> item;
+ if (!ParseInt64List(TokKind::kLbrace, TokKind::kRbrace,
+ TokKind::kComma, &item)) {
+ return false;
+ }
+ result.push_back(item);
+ return true;
+ };
+ if (!ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
+ parse_and_add_item)) {
+ return false;
+ }
+ static_cast<optional<std::vector<std::vector<tensorflow::int64>>>*>(
+ attr_out_ptr)
+ ->emplace(result);
+ return true;
+ }
case AttrTy::kSliceRanges: {
SliceRanges result;
if (!ParseSliceRanges(&result)) {
@@ -2597,6 +2633,26 @@ bool HloParser::ParseInt64List(const TokKind start, const TokKind end,
end, StrCat("expects an int64 list to end with ", TokKindToString(end)));
}
+bool HloParser::ParseList(const TokKind start, const TokKind end,
+ const TokKind delim,
+ const std::function<bool()>& parse_and_add_item) {
+ if (!ParseToken(start, StrCat("expects a list starting with ",
+ TokKindToString(start)))) {
+ return false;
+ }
+ if (lexer_.GetKind() == end) {
+ // empty
+ } else {
+ do {
+ if (!parse_and_add_item()) {
+ return false;
+ }
+ } while (EatIfPresent(delim));
+ }
+ return ParseToken(
+ end, StrCat("expects a list to end with ", TokKindToString(end)));
+}
+
// param_list_to_shape ::= param_list '->' shape
bool HloParser::ParseParamListToShape(Shape* shape, LocTy* shape_loc) {
if (!ParseParamList() || !ParseToken(TokKind::kArrow, "expects '->'")) {