aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/compiler/xla/tools/parser/hlo_parser.cc')
-rw-r--r--tensorflow/compiler/xla/tools/parser/hlo_parser.cc1017
1 files changed, 124 insertions, 893 deletions
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
index fed0492a54..6c2e37e3b5 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
@@ -28,9 +28,6 @@ namespace tools {
namespace {
using tensorflow::StringPiece;
-using tensorflow::gtl::optional;
-using tensorflow::str_util::Split;
-using tensorflow::str_util::SplitAndParseAsInts;
using tensorflow::strings::Printf;
using tensorflow::strings::StrAppend;
using tensorflow::strings::StrCat;
@@ -60,6 +57,7 @@ class HloParser {
bool ParseInstructionList(HloComputation::Builder* builder,
string* root_name);
bool ParseInstruction(HloComputation::Builder* builder, string* root_name);
+ bool ParseSharding(HloInstruction* instruction);
bool ParseControlPredecessors(HloInstruction* instruction);
bool ParseLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
bool ParseTupleLiteral(std::unique_ptr<Literal>* literal, const Shape& shape);
@@ -80,73 +78,10 @@ class HloParser {
bool ParseOperands(std::vector<HloInstruction*>* operands,
const int expected_size);
- // Describes the start, limit, and stride on every dimension of the operand
- // being sliced.
- struct SliceRanges {
- std::vector<int64> starts;
- std::vector<int64> limits;
- std::vector<int64> strides;
- };
-
- // Types of attributes.
- enum class AttrTy {
- kInt64,
- kFloat,
- kBracedInt64List,
- kHloComputation,
- kWindow,
- kConvolutionDimensionNumbers,
- kSharding,
- kInstructionList,
- kSliceRanges,
- kPaddingConfig,
- };
-
- struct AttrConfig {
- bool required; // whether it's required or optional
- AttrTy attr_type; // what type it is
- void* result; // where to store the parsed result.
- };
-
- // Parses attributes given names and configs of the attributes. Each parsed
- // result is passed back through the result pointer in corresponding
- // AttrConfig. Note that the result pointer must point to a optional<T> typed
- // variable which outlives this function. Returns false on error. You should
- // not use the any of the results if this function failed.
- //
- // Example usage:
- //
- // std::unordered_map<string, AttrConfig> attrs;
- // optional<int64> foo;
- // attrs["foo"] = {/*required=*/false, AttrTy::kInt64, &foo};
- // optional<Window> bar;
- // attrs["bar"] = {/*required=*/true, AttrTy::kWindow, &bar};
- // if (!ParseAttribute(attrs)) {
- // return false; // Do not use 'foo' 'bar' if failed.
- // }
- // // Do something with 'bar'.
- // if (foo) { // If attr foo is seen, do something with 'foo'. }
- //
- bool ParseAttributes(const std::unordered_map<string, AttrConfig>& attrs);
-
- // Parses a name and finds the corresponding hlo computation.
- bool ParseComputationName(HloComputation** value);
- // Parses a list of names and finds the corresponding hlo instructions.
- bool ParseInstructionNames(std::vector<HloInstruction*>* instructions);
- bool ParseWindow(Window* window);
- bool ParseConvolutionDimensionNumbers(ConvolutionDimensionNumbers* dnums);
- bool ParsePaddingConfig(PaddingConfig* padding);
- bool ParseSharding(OpSharding* sharding);
- bool ParseSingleSharding(OpSharding* sharding, bool lbrace_pre_lexed);
-
- // Parses a sub-attribute of the window attribute, e.g.,size=1x2x3.
- bool ParseDxD(const string& name, std::vector<int64>* result);
- // Parses window's pad sub-attriute, e.g., pad=0_0x3x3.
- bool ParseWindowPad(std::vector<std::vector<int64>>* pad);
-
- bool ParseSliceRanges(SliceRanges* result);
- bool ParseInt64List(const TokKind start, const TokKind end,
- const TokKind delim, std::vector<int64>* result);
+ template <typename T>
+ bool ParseExtraAttribute(T* value, const string& expected_attribute);
+ template <typename T>
+ bool ParseAttributeValue(T* value);
bool ParseParamList();
bool ParseName(string* result);
@@ -279,7 +214,7 @@ bool HloParser::ParseInstructionList(HloComputation::Builder* builder,
"expects '}' at the end of instruction list.");
}
-// instruction ::= ('ROOT')? name '=' shape opcode operands (attribute)*
+// instruction ::= ('ROOT')? name '=' shape opcode operands (extra_attribute)*
bool HloParser::ParseInstruction(HloComputation::Builder* builder,
string* root_name) {
string name;
@@ -295,15 +230,6 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
if (is_root) {
*root_name = name;
}
-
- // Add optional attributes.
- std::unordered_map<string, AttrConfig> attrs;
- optional<OpSharding> sharding;
- attrs["sharding"] = {/*required=*/false, AttrTy::kSharding, &sharding};
- optional<std::vector<HloInstruction*>> predecessors;
- attrs["control-predecessors"] = {/*required=*/false, AttrTy::kInstructionList,
- &predecessors};
-
HloInstruction* instruction;
switch (opcode) {
case HloOpcode::kParameter: {
@@ -311,8 +237,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
if (!ParseToken(TokKind::kLparen,
"expects '(' before parameter number") ||
!ParseInt64(&parameter_number) ||
- !ParseToken(TokKind::kRparen, "expects ')' after parameter number") ||
- !ParseAttributes(attrs)) {
+ !ParseToken(TokKind::kRparen, "expects ')' after parameter number")) {
return false;
}
instruction = builder->AddInstruction(
@@ -324,8 +249,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
if (!ParseToken(TokKind::kLparen,
"expects '(' before constant literal") ||
!ParseLiteral(&literal, shape) ||
- !ParseToken(TokKind::kRparen, "expects ')' after constant literal") ||
- !ParseAttributes(attrs)) {
+ !ParseToken(TokKind::kRparen, "expects ')' after constant literal")) {
return false;
}
instruction = builder->AddInstruction(
@@ -351,8 +275,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
case HloOpcode::kSin:
case HloOpcode::kSort:
case HloOpcode::kTanh: {
- if (!ParseOperands(&operands, /*expected_size=*/1) ||
- !ParseAttributes(attrs)) {
+ if (!ParseOperands(&operands, /*expected_size=*/1)) {
return false;
}
instruction = builder->AddInstruction(
@@ -382,8 +305,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
case HloOpcode::kShiftLeft:
case HloOpcode::kShiftRightArithmetic:
case HloOpcode::kShiftRightLogical: {
- if (!ParseOperands(&operands, /*expected_size=*/2) ||
- !ParseAttributes(attrs)) {
+ if (!ParseOperands(&operands, /*expected_size=*/2)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateBinary(
@@ -393,8 +315,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
// Ternary ops.
case HloOpcode::kClamp:
case HloOpcode::kSelect: {
- if (!ParseOperands(&operands, /*expected_size=*/3) ||
- !ParseAttributes(attrs)) {
+ if (!ParseOperands(&operands, /*expected_size=*/3)) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateTernary(
@@ -403,8 +324,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
}
// Other supported ops.
case HloOpcode::kConvert: {
- if (!ParseOperands(&operands, /*expected_size=*/1) ||
- !ParseAttributes(attrs)) {
+ if (!ParseOperands(&operands, /*expected_size=*/1)) {
return false;
}
instruction = builder->AddInstruction(
@@ -412,8 +332,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kCrossReplicaSum: {
- if (!ParseOperands(&operands, /*expected_size=*/1) ||
- !ParseAttributes(attrs)) {
+ if (!ParseOperands(&operands, /*expected_size=*/1)) {
return false;
}
instruction = builder->AddInstruction(
@@ -421,8 +340,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kReshape: {
- if (!ParseOperands(&operands, /*expected_size=*/1) ||
- !ParseAttributes(attrs)) {
+ if (!ParseOperands(&operands, /*expected_size=*/1)) {
return false;
}
instruction = builder->AddInstruction(
@@ -430,7 +348,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kTuple: {
- if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
+ if (!ParseOperands(&operands)) {
return false;
}
instruction =
@@ -438,376 +356,126 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
break;
}
case HloOpcode::kWhile: {
- optional<HloComputation*> condition;
- optional<HloComputation*> body;
- attrs["condition"] = {/*required=*/true, AttrTy::kHloComputation,
- &condition};
- attrs["body"] = {/*required=*/true, AttrTy::kHloComputation, &body};
+ HloComputation* condition;
+ HloComputation* body;
if (!ParseOperands(&operands, /*expected_size=*/1) ||
- !ParseAttributes(attrs)) {
+ !ParseExtraAttribute(&condition,
+ /*expected_attribute=*/"condition") ||
+ !ParseExtraAttribute(&body, /*expected_attribute=*/"body")) {
return false;
}
instruction = builder->AddInstruction(HloInstruction::CreateWhile(
- shape, *condition, *body, /*init=*/operands[0]));
+ shape, condition, body, /*init=*/operands[0]));
break;
}
case HloOpcode::kRecv: {
- optional<int64> channel_id;
- attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
+ int64 channel_id;
if (!ParseOperands(&operands, /*expected_size=*/0) ||
- !ParseAttributes(attrs)) {
+ !ParseExtraAttribute(&channel_id,
+ /*expected_attribute=*/"channel_id")) {
return false;
}
instruction = builder->AddInstruction(
- HloInstruction::CreateRecv(shape.tuple_shapes(0), *channel_id));
- break;
- }
- case HloOpcode::kRecvDone: {
- optional<int64> channel_id;
- attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
- if (!ParseOperands(&operands, /*expected_size=*/1) ||
- !ParseAttributes(attrs)) {
- return false;
- }
- if (channel_id != operands[0]->channel_id()) {
- return false;
- }
- instruction =
- builder->AddInstruction(HloInstruction::CreateRecvDone(operands[0]));
+ HloInstruction::CreateRecv(shape, channel_id));
break;
}
case HloOpcode::kSend: {
- optional<int64> channel_id;
- attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
+ int64 channel_id;
if (!ParseOperands(&operands, /*expected_size=*/1) ||
- !ParseAttributes(attrs)) {
+ !ParseExtraAttribute(&channel_id,
+ /*expected_attribute=*/"channel_id")) {
return false;
}
instruction = builder->AddInstruction(
- HloInstruction::CreateSend(operands[0], *channel_id));
- break;
- }
- case HloOpcode::kSendDone: {
- optional<int64> channel_id;
- attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id};
- if (!ParseOperands(&operands, /*expected_size=*/1) ||
- !ParseAttributes(attrs)) {
- return false;
- }
- if (channel_id != operands[0]->channel_id()) {
- return false;
- }
- instruction =
- builder->AddInstruction(HloInstruction::CreateSendDone(operands[0]));
+ HloInstruction::CreateSend(operands[0], channel_id));
break;
}
case HloOpcode::kGetTupleElement: {
- optional<int64> index;
- attrs["index"] = {/*required=*/true, AttrTy::kInt64, &index};
+ int64 index;
if (!ParseOperands(&operands, /*expected_size=*/1) ||
- !ParseAttributes(attrs)) {
+ !ParseExtraAttribute(&index, /*expected_attribute=*/"index")) {
return false;
}
instruction = builder->AddInstruction(
- HloInstruction::CreateGetTupleElement(shape, operands[0], *index));
+ HloInstruction::CreateGetTupleElement(shape, operands[0], index));
break;
}
case HloOpcode::kCall: {
- optional<HloComputation*> to_apply;
- attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
- &to_apply};
- if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
- return false;
- }
- instruction = builder->AddInstruction(
- HloInstruction::CreateCall(shape, operands, *to_apply));
- break;
- }
- case HloOpcode::kReduceWindow: {
- optional<HloComputation*> reduce_computation;
- optional<Window> window;
- attrs["window"] = {/*required=*/true, AttrTy::kWindow, &window};
- attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
- &reduce_computation};
- if (!ParseOperands(&operands, /*expected_size=*/2) ||
- !ParseAttributes(attrs)) {
- return false;
- }
- instruction = builder->AddInstruction(HloInstruction::CreateReduceWindow(
- shape, /*operand=*/operands[0], /*init_value=*/operands[1], *window,
- *reduce_computation));
- break;
- }
- case HloOpcode::kConvolution: {
- optional<Window> window;
- optional<ConvolutionDimensionNumbers> dnums;
- attrs["window"] = {/*required=*/true, AttrTy::kWindow, &window};
- attrs["dim_labels"] = {/*required=*/true,
- AttrTy::kConvolutionDimensionNumbers, &dnums};
- if (!ParseOperands(&operands, /*expected_size=*/2) ||
- !ParseAttributes(attrs)) {
- return false;
- }
- instruction = builder->AddInstruction(HloInstruction::CreateConvolve(
- shape, /*lhs=*/operands[0], /*rhs=*/operands[1], *window, *dnums));
- break;
- }
- case HloOpcode::kBroadcast: {
- optional<std::vector<int64>> broadcast_dimensions;
- attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
- &broadcast_dimensions};
- if (!ParseOperands(&operands, /*expected_size=*/1) ||
- !ParseAttributes(attrs)) {
- return false;
- }
- instruction = builder->AddInstruction(HloInstruction::CreateBroadcast(
- shape, operands[0], *broadcast_dimensions));
- break;
- }
- case HloOpcode::kConcatenate: {
- optional<std::vector<int64>> dimensions;
- attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
- &dimensions};
- if (!ParseOperands(&operands) || !ParseAttributes(attrs) ||
- dimensions->size() != 1) {
- return false;
- }
- instruction = builder->AddInstruction(HloInstruction::CreateConcatenate(
- shape, operands, dimensions->at(0)));
- break;
- }
- case HloOpcode::kMap: {
- optional<HloComputation*> to_apply;
- attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
- &to_apply};
- if (!ParseOperands(&operands) || !ParseAttributes(attrs)) {
- return false;
- }
- instruction = builder->AddInstruction(
- HloInstruction::CreateMap(shape, operands, *to_apply));
- break;
- }
- case HloOpcode::kReduce: {
- optional<HloComputation*> reduce_computation;
- attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation,
- &reduce_computation};
- optional<std::vector<int64>> dimensions_to_reduce;
- attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
- &dimensions_to_reduce};
- if (!ParseOperands(&operands, /*expected_size=*/2) ||
- !ParseAttributes(attrs)) {
- return false;
- }
- instruction = builder->AddInstruction(HloInstruction::CreateReduce(
- shape, /*operand=*/operands[0], /*init_value=*/operands[1],
- *dimensions_to_reduce, *reduce_computation));
- break;
- }
- case HloOpcode::kReverse: {
- optional<std::vector<int64>> dimensions;
- attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
- &dimensions};
- if (!ParseOperands(&operands, /*expected_size=*/1) ||
- !ParseAttributes(attrs)) {
- return false;
- }
- instruction = builder->AddInstruction(
- HloInstruction::CreateReverse(shape, operands[0], *dimensions));
- break;
- }
- case HloOpcode::kSelectAndScatter: {
- optional<HloComputation*> select;
- attrs["select"] = {/*required=*/true, AttrTy::kHloComputation, &select};
- optional<HloComputation*> scatter;
- attrs["scatter"] = {/*required=*/true, AttrTy::kHloComputation, &scatter};
- optional<Window> window;
- attrs["window"] = {/*required=*/true, AttrTy::kWindow, &window};
- if (!ParseOperands(&operands, /*expected_size=*/3) ||
- !ParseAttributes(attrs)) {
- return false;
- }
- instruction =
- builder->AddInstruction(HloInstruction::CreateSelectAndScatter(
- shape, /*operand=*/operands[0], *select, *window,
- /*source=*/operands[1], /*init_value=*/operands[2], *scatter));
- break;
- }
- case HloOpcode::kSlice: {
- optional<SliceRanges> slice_ranges;
- attrs["slice"] = {/*required=*/true, AttrTy::kSliceRanges, &slice_ranges};
- if (!ParseOperands(&operands, /*expected_size=*/1) ||
- !ParseAttributes(attrs)) {
- return false;
- }
- instruction = builder->AddInstruction(HloInstruction::CreateSlice(
- shape, operands[0], slice_ranges->starts, slice_ranges->limits,
- slice_ranges->strides));
- break;
- }
- case HloOpcode::kDynamicSlice: {
- optional<std::vector<int64>> dynamic_slice_sizes;
- attrs["dynamic_slice_sizes"] = {
- /*required=*/true, AttrTy::kBracedInt64List, &dynamic_slice_sizes};
- if (!ParseOperands(&operands, /*expected_size=*/2) ||
- !ParseAttributes(attrs)) {
- return false;
- }
- instruction = builder->AddInstruction(HloInstruction::CreateDynamicSlice(
- shape, /*operand=*/operands[0], /*start_indices=*/operands[1],
- *dynamic_slice_sizes));
- break;
- }
- case HloOpcode::kDynamicUpdateSlice: {
- if (!ParseOperands(&operands, /*expected_size=*/3) ||
- !ParseAttributes(attrs)) {
- return false;
- }
- instruction =
- builder->AddInstruction(HloInstruction::CreateDynamicUpdateSlice(
- shape, /*operand=*/operands[0], /*update=*/operands[1],
- /*start_indices=*/operands[2]));
- break;
- }
- case HloOpcode::kTranspose: {
- optional<std::vector<int64>> dimensions;
- attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List,
- &dimensions};
- if (!ParseOperands(&operands, /*expected_size=*/1) ||
- !ParseAttributes(attrs)) {
+ HloComputation* to_apply;
+ if (!ParseOperands(&operands) ||
+ !ParseExtraAttribute(&to_apply,
+ /*expected_attribute=*/"to_apply")) {
return false;
}
instruction = builder->AddInstruction(
- HloInstruction::CreateTranspose(shape, operands[0], *dimensions));
- break;
- }
- case HloOpcode::kBatchNormTraining: {
- optional<float> epsilon;
- attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon};
- optional<int64> feature_index;
- attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64,
- &feature_index};
- if (!ParseOperands(&operands, /*expected_size=*/3) ||
- !ParseAttributes(attrs)) {
- return false;
- }
- instruction =
- builder->AddInstruction(HloInstruction::CreateBatchNormTraining(
- shape, /*operand=*/operands[0], /*scale=*/operands[1],
- /*offset=*/operands[2], *epsilon, *feature_index));
- break;
- }
- case HloOpcode::kBatchNormInference: {
- optional<float> epsilon;
- attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon};
- optional<int64> feature_index;
- attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64,
- &feature_index};
- if (!ParseOperands(&operands, /*expected_size=*/5) ||
- !ParseAttributes(attrs)) {
- return false;
- }
- instruction =
- builder->AddInstruction(HloInstruction::CreateBatchNormInference(
- shape, /*operand=*/operands[0], /*scale=*/operands[1],
- /*offset=*/operands[2], /*mean=*/operands[3],
- /*variance=*/operands[4], *epsilon, *feature_index));
- break;
- }
- case HloOpcode::kBatchNormGrad: {
- optional<float> epsilon;
- attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon};
- optional<int64> feature_index;
- attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64,
- &feature_index};
- if (!ParseOperands(&operands, /*expected_size=*/5) ||
- !ParseAttributes(attrs)) {
- return false;
- }
- instruction = builder->AddInstruction(HloInstruction::CreateBatchNormGrad(
- shape, /*operand=*/operands[0], /*scale=*/operands[1],
- /*mean=*/operands[2], /*variance=*/operands[3],
- /*grad_output=*/operands[4], *epsilon, *feature_index));
- break;
- }
- case HloOpcode::kPad: {
- optional<PaddingConfig> padding;
- attrs["padding"] = {/*required=*/true, AttrTy::kPaddingConfig, &padding};
- if (!ParseOperands(&operands, /*expected_size=*/2) ||
- !ParseAttributes(attrs)) {
- return false;
- }
- instruction = builder->AddInstruction(HloInstruction::CreatePad(
- shape, operands[0], /*padding_value=*/operands[1], *padding));
+ HloInstruction::CreateCall(shape, operands, to_apply));
break;
}
+ case HloOpcode::kBroadcast:
case HloOpcode::kCustomCall:
+ case HloOpcode::kConcatenate:
case HloOpcode::kReducePrecision:
+ case HloOpcode::kConvolution:
+ case HloOpcode::kMap:
+ case HloOpcode::kPad:
+ case HloOpcode::kReduce:
+ case HloOpcode::kReduceWindow:
+ case HloOpcode::kSelectAndScatter:
+ case HloOpcode::kReverse:
case HloOpcode::kRng:
+ case HloOpcode::kSlice:
+ case HloOpcode::kDynamicSlice:
+ case HloOpcode::kDynamicUpdateSlice:
+ case HloOpcode::kTranspose:
case HloOpcode::kFusion:
+ case HloOpcode::kBatchNormTraining:
+ case HloOpcode::kBatchNormInference:
case HloOpcode::kInfeed:
case HloOpcode::kOutfeed:
+ case HloOpcode::kBatchNormGrad:
case HloOpcode::kTrace:
return TokenError(StrCat("parsing not yet implemented for op: ",
HloOpcodeString(opcode)));
}
- // Add common attrs (sharding, control predecessors) to the instruction, if
- // they were seen.
- if (sharding) {
- instruction->set_sharding(
- HloSharding::FromProto(sharding.value()).ValueOrDie());
- }
- if (predecessors) {
- for (auto* pre : *predecessors) {
- Status status = pre->AddControlDependencyTo(instruction);
- if (!status.ok()) {
- return TokenError(StrCat("error adding control dependency for: ", name,
- " status: ", status.ToString()));
- }
+ bool has_sharding = false;
+ bool has_control = false;
+ while (EatIfPresent(TokKind::kComma)) {
+ string attribute_name;
+ if (!ParseAttributeName(&attribute_name)) {
+ return TokenError("expects ', sharding=' or ', control-predecessors='");
}
- }
- return AddInstruction(name, instruction);
-}
-
-// ::= '{' (single_sharding | tuple_sharding) '}'
-//
-// tuple_sharding ::= single_sharding* (',' single_sharding)*
-bool HloParser::ParseSharding(OpSharding* sharding) {
- // A single sharding starts with '{' and is not followed by '{'.
- // A tuple sharding starts with '{' and is followed by '{', or is '{''}' for
- // an empty tuple.
- if (!ParseToken(TokKind::kLbrace,
- "expected '{' to start sharding attribute")) {
- return false;
- }
- if (lexer_.GetKind() != TokKind::kLbrace &&
- lexer_.GetKind() != TokKind::kRbrace) {
- return ParseSingleSharding(sharding, /*lbrace_pre_lexed=*/true);
- }
-
- // Tuple sharding.
- // Allow empty tuple shardings.
- if (lexer_.GetKind() != TokKind::kRbrace) {
- do {
- if (!ParseSingleSharding(sharding->add_tuple_shardings(),
- /*lbrace_pre_lexed=*/false)) {
+ if (attribute_name == "sharding") {
+ // Parse "sharding=".
+ if (has_sharding) {
+ return TokenError("expects at most 1 'sharding='");
+ }
+ has_sharding = true;
+ if (!ParseSharding(instruction)) {
return false;
}
- } while (EatIfPresent(TokKind::kComma));
+ } else if (attribute_name == "control-predecessors") {
+ // Parse "control-predecessors"
+ if (has_control) {
+ return TokenError("expects at most 1 'control-predecessors='");
+ }
+ has_control = true;
+ if (!ParseControlPredecessors(instruction)) {
+ return false;
+ }
+ } else {
+ return TokenError(StrCat("unexpected attribute: ", attribute_name));
+ }
}
- sharding->set_type(OpSharding::Type::OpSharding_Type_TUPLE);
- return ParseToken(TokKind::kRbrace, "expected '}' to end sharding attribute");
+ return AddInstruction(name, instruction);
}
-// ::= '{' 'replicated'? 'maximal'? ('device=' int)? shape?
-// ('devices=' ('[' dims ']')* device_list)? '}'
-// dims ::= int_list device_list ::= int_list
-bool HloParser::ParseSingleSharding(OpSharding* sharding,
- bool lbrace_pre_lexed) {
- if (!lbrace_pre_lexed &&
- !ParseToken(TokKind::kLbrace,
+// ::= '{' 'replicated'? 'maximal'? ('device=' int)? shape? ('devices=' ('['
+// dims ']')* device_list)? '}' dims ::= int_list device_list ::= int_list
+bool HloParser::ParseSharding(HloInstruction* instruction) {
+ if (!ParseToken(TokKind::kLbrace,
"expected '{' to start sharding attribute")) {
return false;
}
@@ -877,6 +545,7 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding,
}
}
+ OpSharding sharding;
if (replicated) {
if (!devices.empty()) {
return TokenError(
@@ -886,7 +555,7 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding,
return TokenError(
"replicated shardings should not have any tile shape set");
}
- sharding->set_type(OpSharding::Type::OpSharding_Type_REPLICATED);
+ sharding.set_type(OpSharding::Type::OpSharding_Type_REPLICATED);
} else if (maximal) {
if (devices.size() != 1) {
return TokenError(
@@ -895,8 +564,8 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding,
if (!ShapeUtil::Equal(tile_shape, Shape())) {
return TokenError("maximal shardings should not have any tile shape set");
}
- sharding->set_type(OpSharding::Type::OpSharding_Type_MAXIMAL);
- sharding->add_tile_assignment_devices(devices[0]);
+ sharding.set_type(OpSharding::Type::OpSharding_Type_MAXIMAL);
+ sharding.add_tile_assignment_devices(devices[0]);
} else {
if (devices.size() <= 1) {
return TokenError(
@@ -910,43 +579,47 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding,
"non-maximal shardings must have a tile assignment list including "
"dimensions");
}
- sharding->set_type(OpSharding::Type::OpSharding_Type_OTHER);
- *sharding->mutable_tile_shape() = tile_shape;
+ sharding.set_type(OpSharding::Type::OpSharding_Type_OTHER);
+ *sharding.mutable_tile_shape() = tile_shape;
for (int64 dim : tile_assignment_dimensions) {
- sharding->add_tile_assignment_dimensions(dim);
+ sharding.add_tile_assignment_dimensions(dim);
}
for (int64 device : devices) {
- sharding->add_tile_assignment_devices(device);
+ sharding.add_tile_assignment_devices(device);
}
}
+ instruction->set_sharding(HloSharding::FromProto(sharding).ValueOrDie());
lexer_.Lex();
return true;
}
// '{' name+ '}'
-bool HloParser::ParseInstructionNames(
- std::vector<HloInstruction*>* instructions) {
+bool HloParser::ParseControlPredecessors(HloInstruction* instruction) {
if (!ParseToken(TokKind::kLbrace,
- "expects '{' at the beginning of instruction name list")) {
+ "expects '{' at the beginning of control predecessors")) {
return false;
}
do {
string name;
if (!ParseName(&name)) {
- return TokenError("expects a instruction name");
+ return TokenError("expects a control predecessor");
}
- HloInstruction* instr =
+ HloInstruction* pre =
tensorflow::gtl::FindPtrOrNull(instruction_pool_, name);
- if (!instr) {
+ if (!pre) {
return TokenError(
- Printf("instruction '%s' is not defined", name.c_str()));
+ StrCat("control predecessor ", name, " is not defined: "));
+ }
+ Status status = pre->AddControlDependencyTo(instruction);
+ if (!status.ok()) {
+ return TokenError(StrCat("error adding control dependency for: ", name,
+ " status: ", status.ToString()));
}
- instructions->push_back(instr);
} while (EatIfPresent(TokKind::kComma));
return ParseToken(TokKind::kRbrace,
- "expects '}' at the end of control instructions");
+ "expects '}' at the end of control predecessors");
}
bool HloParser::SetValueInLiteral(int64 value, int64 linear_index,
@@ -1284,134 +957,28 @@ bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands,
return true;
}
-bool HloParser::ParseAttributes(
- const std::unordered_map<string, AttrConfig>& attrs) {
- std::unordered_set<string> seen_attrs;
- while (EatIfPresent(TokKind::kComma)) {
- string name;
- if (!ParseAttributeName(&name)) {
- return TokenError("error parsing attributes");
- }
- VLOG(1) << "Parsing attribute " << name;
- if (!seen_attrs.insert(name).second) {
- return TokenError(Printf("attribute %s already exists", name.c_str()));
- }
- auto attr_it = attrs.find(name);
- if (attr_it == attrs.end()) {
- return TokenError(Printf("unexpected attribute %s", name.c_str()));
- }
- AttrTy attr_type = attr_it->second.attr_type;
- void* attr_out_ptr = attr_it->second.result;
- bool success = [&] {
- switch (attr_type) {
- case AttrTy::kInt64: {
- int64 result;
- if (!ParseInt64(&result)) {
- return false;
- }
- static_cast<optional<int64>*>(attr_out_ptr)->emplace(result);
- return true;
- }
- case AttrTy::kFloat: {
- double result;
- if (!ParseDouble(&result)) {
- return false;
- }
- if (result > std::numeric_limits<float>::max() ||
- result < std::numeric_limits<float>::lowest()) {
- return TokenError("value out of range for float");
- }
- static_cast<optional<float>*>(attr_out_ptr)
- ->emplace(static_cast<float>(result));
- return true;
- }
- case AttrTy::kHloComputation: {
- HloComputation* result;
- if (!ParseComputationName(&result)) {
- return false;
- }
- static_cast<optional<HloComputation*>*>(attr_out_ptr)
- ->emplace(result);
- return true;
- }
- case AttrTy::kWindow: {
- Window result;
- if (!ParseWindow(&result)) {
- return false;
- }
- static_cast<optional<Window>*>(attr_out_ptr)->emplace(result);
- return true;
- }
- case AttrTy::kConvolutionDimensionNumbers: {
- ConvolutionDimensionNumbers result;
- if (!ParseConvolutionDimensionNumbers(&result)) {
- return false;
- }
- static_cast<optional<ConvolutionDimensionNumbers>*>(attr_out_ptr)
- ->emplace(result);
- return true;
- }
- case AttrTy::kSharding: {
- OpSharding sharding;
- if (!ParseSharding(&sharding)) {
- return false;
- }
- static_cast<optional<OpSharding>*>(attr_out_ptr)->emplace(sharding);
- return true;
- }
- case AttrTy::kInstructionList: {
- std::vector<HloInstruction*> result;
- if (!ParseInstructionNames(&result)) {
- return false;
- }
- static_cast<optional<std::vector<HloInstruction*>>*>(attr_out_ptr)
- ->emplace(result);
- return true;
- }
- case AttrTy::kBracedInt64List: {
- std::vector<int64> result;
- if (!ParseInt64List(TokKind::kLbrace, TokKind::kRbrace,
- TokKind::kComma, &result)) {
- return false;
- }
- static_cast<optional<std::vector<int64>>*>(attr_out_ptr)
- ->emplace(result);
- return true;
- }
- case AttrTy::kSliceRanges: {
- SliceRanges result;
- if (!ParseSliceRanges(&result)) {
- return false;
- }
- static_cast<optional<SliceRanges>*>(attr_out_ptr)->emplace(result);
- return true;
- }
- case AttrTy::kPaddingConfig: {
- PaddingConfig result;
- if (!ParsePaddingConfig(&result)) {
- return false;
- }
- static_cast<optional<PaddingConfig>*>(attr_out_ptr)->emplace(result);
- return true;
- }
- }
- }();
- if (!success) {
- return TokenError(Printf("error parsing attribute %s", name.c_str()));
- }
+// extra_attribute ::= ',' attribute_name value
+template <typename T>
+bool HloParser::ParseExtraAttribute(T* value,
+ const string& expected_attribute) {
+ if (!ParseToken(TokKind::kComma,
+ "expects ',' in front of an extra attribute")) {
+ return false;
}
- // Check that all required attrs were seen.
- for (const auto& attr_it : attrs) {
- if (attr_it.second.required &&
- seen_attrs.find(attr_it.first) == seen_attrs.end()) {
- return TokenError(Printf("attribute %s is expected but not seen",
- attr_it.first.c_str()));
- }
+ string attribute_name;
+ if (!ParseAttributeName(&attribute_name) &&
+ attribute_name != expected_attribute) {
+ return TokenError(StrCat("expects attribute name: ", expected_attribute));
+ }
+ if (!ParseAttributeValue(value)) {
+ return TokenError(
+ StrCat("expects value for attribute: ", expected_attribute));
}
return true;
}
-bool HloParser::ParseComputationName(HloComputation** value) {
+template <>
+bool HloParser::ParseAttributeValue<HloComputation*>(HloComputation** value) {
string name;
if (!ParseName(&name)) {
return TokenError("expects computation name");
@@ -1423,269 +990,9 @@ bool HloParser::ParseComputationName(HloComputation** value) {
return true;
}
-// ::= '{' size stride? pad? lhs_dilate? rhs_dilate? '}'
-// The subattributes can appear in any order. 'size=' is required, others are
-// optional.
-bool HloParser::ParseWindow(Window* window) {
- if (!ParseToken(TokKind::kLbrace, "expected '{' to start window attribute")) {
- return false;
- }
-
- std::vector<int64> size;
- std::vector<int64> stride;
- std::vector<std::vector<int64>> pad;
- std::vector<int64> lhs_dilate;
- std::vector<int64> rhs_dilate;
- while (lexer_.GetKind() != TokKind::kRbrace) {
- string field_name;
- if (!ParseAttributeName(&field_name)) {
- return TokenError("expects sub-attributes in window");
- }
- bool ok = [&] {
- if (field_name == "size") {
- return ParseDxD("size", &size);
- }
- if (field_name == "stride") {
- return ParseDxD("stride", &stride);
- }
- if (field_name == "lhs_dilate") {
- return ParseDxD("lhs_dilate", &lhs_dilate);
- }
- if (field_name == "rhs_dilate") {
- return ParseDxD("rls_dilate", &rhs_dilate);
- }
- if (field_name == "pad") {
- return ParseWindowPad(&pad);
- }
- return TokenError(StrCat("unexpected attribute name: ", field_name));
- }();
- if (!ok) {
- return false;
- }
- }
-
- if (size.empty()) {
- return TokenError(
- "sub-attribute 'size=' is required in the window attribute");
- }
- if (!stride.empty() && stride.size() != size.size()) {
- return TokenError("expects 'stride=' has the same size as 'size='");
- }
- if (!lhs_dilate.empty() && lhs_dilate.size() != size.size()) {
- return TokenError("expects 'lhs_dilate=' has the same size as 'size='");
- }
- if (!rhs_dilate.empty() && rhs_dilate.size() != size.size()) {
- return TokenError("expects 'rhs_dilate=' has the same size as 'size='");
- }
- if (!pad.empty() && pad.size() != size.size()) {
- return TokenError("expects 'pad=' has the same size as 'size='");
- }
-
- for (int i = 0; i < size.size(); i++) {
- window->add_dimensions()->set_size(size[i]);
- if (!pad.empty()) {
- window->mutable_dimensions(i)->set_padding_low(pad[i][0]);
- window->mutable_dimensions(i)->set_padding_high(pad[i][1]);
- }
- // If some field is not present, it has the default value.
- window->mutable_dimensions(i)->set_stride(stride.empty() ? 1 : stride[i]);
- window->mutable_dimensions(i)->set_base_dilation(
- lhs_dilate.empty() ? 1 : lhs_dilate[i]);
- window->mutable_dimensions(i)->set_window_dilation(
- rhs_dilate.empty() ? 1 : rhs_dilate[i]);
- }
- return ParseToken(TokKind::kRbrace, "expected '}' to end window attribute");
-}
-
-// This is the inverse of HloInstruction::ConvolutionDimensionNumbersToString.
-// The string looks like "dim_labels=0bf_0io->0bf".
-bool HloParser::ParseConvolutionDimensionNumbers(
- ConvolutionDimensionNumbers* dnums) {
- if (lexer_.GetKind() != TokKind::kDimLabels) {
- return TokenError("expects dim labels pattern, e.g., 'bf0_0io->0bf'");
- }
- string str = lexer_.GetStrVal();
-
- // The str is expected to have 3 items, lhs, rhs, out, and it must looks like
- // lhs_rhs->out, that is, the first separator is "_" and the second is "->".
- // So we replace the "->" with "_" and then split on "_".
- str = tensorflow::str_util::StringReplace(str, /*oldsub=*/"->",
- /*newsub=*/"_",
- /*replace_all=*/false);
- std::vector<string> lhs_rhs_out = Split(str, "_");
- if (lhs_rhs_out.size() != 3) {
- LOG(FATAL) << "expects 3 items: lhs, rhs, and output dims, but sees "
- << str;
- }
-
- const int64 rank = lhs_rhs_out[0].length();
- if (rank != lhs_rhs_out[1].length() || rank != lhs_rhs_out[2].length()) {
- return TokenError(
- "convolution lhs, rhs, and output must have the same rank");
- }
- if (rank < 3) {
- return TokenError("convolution rank must >=3");
- }
-
- auto is_unique = [](string str) -> bool {
- std::sort(str.begin(), str.end());
- return std::unique(str.begin(), str.end()) == str.end();
- };
-
- // lhs
- {
- const string& lhs = lhs_rhs_out[0];
- if (!is_unique(lhs)) {
- return TokenError(
- StrCat("expects unique lhs dimension numbers, but sees ", lhs));
- }
- for (int i = 0; i < rank - 2; i++) {
- dnums->add_spatial_dimensions(-1);
- }
- for (int i = 0; i < rank; i++) {
- char c = lhs[i];
- if (c == 'b') {
- dnums->set_input_batch_dimension(i);
- } else if (c == 'f') {
- dnums->set_input_feature_dimension(i);
- } else if (c < '0' + rank && c >= '0') {
- dnums->set_spatial_dimensions(c - '0', i);
- } else {
- return TokenError(
- Printf("expects [0-%lldbf] in lhs dimension numbers", rank - 1));
- }
- }
- }
- // rhs
- {
- const string& rhs = lhs_rhs_out[1];
- if (!is_unique(rhs)) {
- return TokenError(
- StrCat("expects unique rhs dimension numbers, but sees ", rhs));
- }
- for (int i = 0; i < rank - 2; i++) {
- dnums->add_kernel_spatial_dimensions(-1);
- }
- for (int i = 0; i < rank; i++) {
- char c = rhs[i];
- if (c == 'i') {
- dnums->set_kernel_input_feature_dimension(i);
- } else if (c == 'o') {
- dnums->set_kernel_output_feature_dimension(i);
- } else if (c < '0' + rank && c >= '0') {
- dnums->set_kernel_spatial_dimensions(c - '0', i);
- } else {
- return TokenError(
- Printf("expects [0-%lldio] in rhs dimension numbers", rank - 1));
- }
- }
- }
- // output
- {
- const string& out = lhs_rhs_out[2];
- if (!is_unique(out)) {
- return TokenError(
- StrCat("expects unique output dimension numbers, but sees ", out));
- }
- for (int i = 0; i < rank; i++) {
- char c = out[i];
- if (c == 'b') {
- dnums->set_output_batch_dimension(i);
- } else if (c == 'f') {
- dnums->set_output_feature_dimension(i);
- } else if (c < '0' + rank && c >= '0') {
- if (dnums->spatial_dimensions(c - '0') != i) {
- return TokenError(
- "output spatial dimensions should be the same as input spatial "
- "dimensions");
- }
- } else {
- return TokenError(
- Printf("expects [0-%lldbf] in output dimension numbers", rank - 1));
- }
- }
- }
-
- lexer_.Lex();
- return true;
-}
-
-// ::= '{' ranges '}'
-// ::= /*empty*/
-// ::= range (',' range)*
-// range ::= '[' start ':' limit (':' stride)? ']'
-//
-// The slice ranges are printed as:
-//
-// {[dim0_start:dim0_limit:dim0stride], [dim1_start:dim1_limit], ...}
-//
-// This function extracts the starts, limits, and strides as 3 vectors to the
-// result. If stride is not present, stride is 1. For example, if the slice
-// ranges is printed as:
-//
-// {[2:3:4], [5:6:7], [8:9]}
-//
-// The the parsed result will be:
-//
-// {/*starts=*/{2, 5, 8}, /*limits=*/{3, 6, 9}, /*strides=*/{4, 7, 1}}
-//
-bool HloParser::ParseSliceRanges(SliceRanges* result) {
- if (!ParseToken(TokKind::kLbrace, "expects '{' to start ranges")) {
- return false;
- }
- std::vector<std::vector<int64>> ranges;
- if (lexer_.GetKind() == TokKind::kRbrace) {
- // empty
- return ParseToken(TokKind::kRbrace, "expects '}' to end ranges");
- }
- do {
- ranges.emplace_back();
- if (!ParseInt64List(TokKind::kLsquare, TokKind::kRsquare, TokKind::kColon,
- &ranges.back())) {
- return false;
- }
- } while (EatIfPresent(TokKind::kComma));
-
- for (const auto& range : ranges) {
- if (range.size() != 2 && range.size() != 3) {
- return TokenError(Printf(
- "expects [start:limit:step] or [start:limit], but sees %ld elements.",
- range.size()));
- }
- }
-
- for (const auto& range : ranges) {
- result->starts.push_back(range[0]);
- result->limits.push_back(range[1]);
- result->strides.push_back(range.size() == 3 ? range[2] : 1);
- }
- return ParseToken(TokKind::kRbrace, "expects '}' to end ranges");
-}
-
-// int64list ::= start int64_elements end
-// int64_elements
-// ::= /*empty*/
-// ::= int64_val (delim int64_val)*
-bool HloParser::ParseInt64List(const TokKind start, const TokKind end,
- const TokKind delim,
- std::vector<int64>* result) {
- if (!ParseToken(start, StrCat("expects an int64 list starting with ",
- TokKindToString(start)))) {
- return false;
- }
- if (lexer_.GetKind() == end) {
- // empty
- } else {
- do {
- int64 i;
- if (!ParseInt64(&i)) {
- return false;
- }
- result->push_back(i);
- } while (EatIfPresent(delim));
- }
- return ParseToken(
- end, StrCat("expects an int64 list to end with ", TokKindToString(end)));
+template <>
+bool HloParser::ParseAttributeValue<int64>(int64* value) {
+ return ParseInt64(value);
}
// param_list ::= '(' param_list1 ')'
@@ -1763,82 +1070,6 @@ bool HloParser::ParseAttributeName(string* result) {
return true;
}
-bool HloParser::ParseDxD(const string& name, std::vector<int64>* result) {
- if (!result->empty()) {
- return TokenError(
- Printf("sub-attribute '%s=' already exists", name.c_str()));
- }
- // 1D
- if (lexer_.GetKind() == TokKind::kInt) {
- int64 number;
- if (!ParseInt64(&number)) {
- return TokenError(Printf("expects sub-attribute '%s=i'", name.c_str()));
- }
- result->push_back(number);
- return true;
- }
- // 2D or higher.
- if (lexer_.GetKind() == TokKind::kDxD) {
- string str = lexer_.GetStrVal();
- if (!SplitAndParseAsInts(str, 'x', result)) {
- return TokenError(
- Printf("expects sub-attribute '%s=ixj...'", name.c_str()));
- }
- lexer_.Lex();
- return true;
- }
- return TokenError("expects token type kInt or kDxD");
-}
-
-bool HloParser::ParseWindowPad(std::vector<std::vector<int64>>* pad) {
- if (!pad->empty()) {
- return TokenError("sub-attribute 'pad=' already exists");
- }
- if (lexer_.GetKind() != TokKind::kPad) {
- return TokenError("expects window pad pattern, e.g., '0_0x3_3'");
- }
- string str = lexer_.GetStrVal();
- std::vector<string> padding_str = Split(str, 'x');
- for (int i = 0; i < padding_str.size(); i++) {
- std::vector<int64> low_high;
- if (!SplitAndParseAsInts(padding_str[i], '_', &low_high) ||
- low_high.size() != 2) {
- return TokenError(
- "expects padding_low and padding_high separated by '_'");
- }
- pad->push_back(low_high);
- }
- lexer_.Lex();
- return true;
-}
-
-// This is the inverse xla::ToString(PaddingConfig). The padding config string
-// looks like "0_0_0x3_3_1". The string is first separated by 'x', each
-// substring represents one PaddingConfigDimension. The substring is 3 (or 2)
-// numbers joined by '_'.
-bool HloParser::ParsePaddingConfig(PaddingConfig* padding) {
- if (lexer_.GetKind() != TokKind::kPad) {
- return TokenError("expects padding config, e.g., '0_0_0x3_3_1'");
- }
- string str = lexer_.GetStrVal();
- std::vector<string> padding_str = Split(str, 'x');
- for (const auto& padding_dim_str : padding_str) {
- std::vector<int64> padding_dim;
- if (!SplitAndParseAsInts(padding_dim_str, '_', &padding_dim) ||
- (padding_dim.size() != 2 && padding_dim.size() != 3)) {
- return TokenError(
- "expects padding config pattern like 'low_high_interior' or "
- "'low_high'");
- }
- auto* dim = padding->add_dimensions();
- dim->set_edge_padding_low(padding_dim[0]);
- dim->set_edge_padding_high(padding_dim[1]);
- dim->set_interior_padding(padding_dim.size() == 3 ? padding_dim[2] : 0);
- }
- lexer_.Lex();
- return true;
-}
-
bool HloParser::ParseOpcode(HloOpcode* result) {
VLOG(1) << "ParseOpcode";
if (lexer_.GetKind() != TokKind::kOpcode) {