diff options
Diffstat (limited to 'tensorflow/compiler/xla/tools/parser/hlo_parser.cc')
-rw-r--r-- | tensorflow/compiler/xla/tools/parser/hlo_parser.cc | 1017 |
1 files changed, 124 insertions, 893 deletions
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc index fed0492a54..6c2e37e3b5 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc @@ -28,9 +28,6 @@ namespace tools { namespace { using tensorflow::StringPiece; -using tensorflow::gtl::optional; -using tensorflow::str_util::Split; -using tensorflow::str_util::SplitAndParseAsInts; using tensorflow::strings::Printf; using tensorflow::strings::StrAppend; using tensorflow::strings::StrCat; @@ -60,6 +57,7 @@ class HloParser { bool ParseInstructionList(HloComputation::Builder* builder, string* root_name); bool ParseInstruction(HloComputation::Builder* builder, string* root_name); + bool ParseSharding(HloInstruction* instruction); bool ParseControlPredecessors(HloInstruction* instruction); bool ParseLiteral(std::unique_ptr<Literal>* literal, const Shape& shape); bool ParseTupleLiteral(std::unique_ptr<Literal>* literal, const Shape& shape); @@ -80,73 +78,10 @@ class HloParser { bool ParseOperands(std::vector<HloInstruction*>* operands, const int expected_size); - // Describes the start, limit, and stride on every dimension of the operand - // being sliced. - struct SliceRanges { - std::vector<int64> starts; - std::vector<int64> limits; - std::vector<int64> strides; - }; - - // Types of attributes. - enum class AttrTy { - kInt64, - kFloat, - kBracedInt64List, - kHloComputation, - kWindow, - kConvolutionDimensionNumbers, - kSharding, - kInstructionList, - kSliceRanges, - kPaddingConfig, - }; - - struct AttrConfig { - bool required; // whether it's required or optional - AttrTy attr_type; // what type it is - void* result; // where to store the parsed result. - }; - - // Parses attributes given names and configs of the attributes. Each parsed - // result is passed back through the result pointer in corresponding - // AttrConfig. Note that the result pointer must point to a optional<T> typed - // variable which outlives this function. Returns false on error. You should - // not use the any of the results if this function failed. - // - // Example usage: - // - // std::unordered_map<string, AttrConfig> attrs; - // optional<int64> foo; - // attrs["foo"] = {/*required=*/false, AttrTy::kInt64, &foo}; - // optional<Window> bar; - // attrs["bar"] = {/*required=*/true, AttrTy::kWindow, &bar}; - // if (!ParseAttribute(attrs)) { - // return false; // Do not use 'foo' 'bar' if failed. - // } - // // Do something with 'bar'. - // if (foo) { // If attr foo is seen, do something with 'foo'. } - // - bool ParseAttributes(const std::unordered_map<string, AttrConfig>& attrs); - - // Parses a name and finds the corresponding hlo computation. - bool ParseComputationName(HloComputation** value); - // Parses a list of names and finds the corresponding hlo instructions. - bool ParseInstructionNames(std::vector<HloInstruction*>* instructions); - bool ParseWindow(Window* window); - bool ParseConvolutionDimensionNumbers(ConvolutionDimensionNumbers* dnums); - bool ParsePaddingConfig(PaddingConfig* padding); - bool ParseSharding(OpSharding* sharding); - bool ParseSingleSharding(OpSharding* sharding, bool lbrace_pre_lexed); - - // Parses a sub-attribute of the window attribute, e.g.,size=1x2x3. - bool ParseDxD(const string& name, std::vector<int64>* result); - // Parses window's pad sub-attriute, e.g., pad=0_0x3x3. - bool ParseWindowPad(std::vector<std::vector<int64>>* pad); - - bool ParseSliceRanges(SliceRanges* result); - bool ParseInt64List(const TokKind start, const TokKind end, - const TokKind delim, std::vector<int64>* result); + template <typename T> + bool ParseExtraAttribute(T* value, const string& expected_attribute); + template <typename T> + bool ParseAttributeValue(T* value); bool ParseParamList(); bool ParseName(string* result); @@ -279,7 +214,7 @@ bool HloParser::ParseInstructionList(HloComputation::Builder* builder, "expects '}' at the end of instruction list."); } -// instruction ::= ('ROOT')? name '=' shape opcode operands (attribute)* +// instruction ::= ('ROOT')? name '=' shape opcode operands (extra_attribute)* bool HloParser::ParseInstruction(HloComputation::Builder* builder, string* root_name) { string name; @@ -295,15 +230,6 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, if (is_root) { *root_name = name; } - - // Add optional attributes. - std::unordered_map<string, AttrConfig> attrs; - optional<OpSharding> sharding; - attrs["sharding"] = {/*required=*/false, AttrTy::kSharding, &sharding}; - optional<std::vector<HloInstruction*>> predecessors; - attrs["control-predecessors"] = {/*required=*/false, AttrTy::kInstructionList, - &predecessors}; - HloInstruction* instruction; switch (opcode) { case HloOpcode::kParameter: { @@ -311,8 +237,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, if (!ParseToken(TokKind::kLparen, "expects '(' before parameter number") || !ParseInt64(¶meter_number) || - !ParseToken(TokKind::kRparen, "expects ')' after parameter number") || - !ParseAttributes(attrs)) { + !ParseToken(TokKind::kRparen, "expects ')' after parameter number")) { return false; } instruction = builder->AddInstruction( @@ -324,8 +249,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, if (!ParseToken(TokKind::kLparen, "expects '(' before constant literal") || !ParseLiteral(&literal, shape) || - !ParseToken(TokKind::kRparen, "expects ')' after constant literal") || - !ParseAttributes(attrs)) { + !ParseToken(TokKind::kRparen, "expects ')' after constant literal")) { return false; } instruction = builder->AddInstruction( @@ -351,8 +275,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kSin: case HloOpcode::kSort: case HloOpcode::kTanh: { - if (!ParseOperands(&operands, /*expected_size=*/1) || - !ParseAttributes(attrs)) { + if (!ParseOperands(&operands, /*expected_size=*/1)) { return false; } instruction = builder->AddInstruction( @@ -382,8 +305,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, case HloOpcode::kShiftLeft: case HloOpcode::kShiftRightArithmetic: case HloOpcode::kShiftRightLogical: { - if (!ParseOperands(&operands, /*expected_size=*/2) || - !ParseAttributes(attrs)) { + if (!ParseOperands(&operands, /*expected_size=*/2)) { return false; } instruction = builder->AddInstruction(HloInstruction::CreateBinary( @@ -393,8 +315,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, // Ternary ops. case HloOpcode::kClamp: case HloOpcode::kSelect: { - if (!ParseOperands(&operands, /*expected_size=*/3) || - !ParseAttributes(attrs)) { + if (!ParseOperands(&operands, /*expected_size=*/3)) { return false; } instruction = builder->AddInstruction(HloInstruction::CreateTernary( @@ -403,8 +324,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, } // Other supported ops. case HloOpcode::kConvert: { - if (!ParseOperands(&operands, /*expected_size=*/1) || - !ParseAttributes(attrs)) { + if (!ParseOperands(&operands, /*expected_size=*/1)) { return false; } instruction = builder->AddInstruction( @@ -412,8 +332,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kCrossReplicaSum: { - if (!ParseOperands(&operands, /*expected_size=*/1) || - !ParseAttributes(attrs)) { + if (!ParseOperands(&operands, /*expected_size=*/1)) { return false; } instruction = builder->AddInstruction( @@ -421,8 +340,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kReshape: { - if (!ParseOperands(&operands, /*expected_size=*/1) || - !ParseAttributes(attrs)) { + if (!ParseOperands(&operands, /*expected_size=*/1)) { return false; } instruction = builder->AddInstruction( @@ -430,7 +348,7 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kTuple: { - if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { + if (!ParseOperands(&operands)) { return false; } instruction = @@ -438,376 +356,126 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, break; } case HloOpcode::kWhile: { - optional<HloComputation*> condition; - optional<HloComputation*> body; - attrs["condition"] = {/*required=*/true, AttrTy::kHloComputation, - &condition}; - attrs["body"] = {/*required=*/true, AttrTy::kHloComputation, &body}; + HloComputation* condition; + HloComputation* body; if (!ParseOperands(&operands, /*expected_size=*/1) || - !ParseAttributes(attrs)) { + !ParseExtraAttribute(&condition, + /*expected_attribute=*/"condition") || + !ParseExtraAttribute(&body, /*expected_attribute=*/"body")) { return false; } instruction = builder->AddInstruction(HloInstruction::CreateWhile( - shape, *condition, *body, /*init=*/operands[0])); + shape, condition, body, /*init=*/operands[0])); break; } case HloOpcode::kRecv: { - optional<int64> channel_id; - attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; + int64 channel_id; if (!ParseOperands(&operands, /*expected_size=*/0) || - !ParseAttributes(attrs)) { + !ParseExtraAttribute(&channel_id, + /*expected_attribute=*/"channel_id")) { return false; } instruction = builder->AddInstruction( - HloInstruction::CreateRecv(shape.tuple_shapes(0), *channel_id)); - break; - } - case HloOpcode::kRecvDone: { - optional<int64> channel_id; - attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; - if (!ParseOperands(&operands, /*expected_size=*/1) || - !ParseAttributes(attrs)) { - return false; - } - if (channel_id != operands[0]->channel_id()) { - return false; - } - instruction = - builder->AddInstruction(HloInstruction::CreateRecvDone(operands[0])); + HloInstruction::CreateRecv(shape, channel_id)); break; } case HloOpcode::kSend: { - optional<int64> channel_id; - attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; + int64 channel_id; if (!ParseOperands(&operands, /*expected_size=*/1) || - !ParseAttributes(attrs)) { + !ParseExtraAttribute(&channel_id, + /*expected_attribute=*/"channel_id")) { return false; } instruction = builder->AddInstruction( - HloInstruction::CreateSend(operands[0], *channel_id)); - break; - } - case HloOpcode::kSendDone: { - optional<int64> channel_id; - attrs["channel_id"] = {/*required=*/true, AttrTy::kInt64, &channel_id}; - if (!ParseOperands(&operands, /*expected_size=*/1) || - !ParseAttributes(attrs)) { - return false; - } - if (channel_id != operands[0]->channel_id()) { - return false; - } - instruction = - builder->AddInstruction(HloInstruction::CreateSendDone(operands[0])); + HloInstruction::CreateSend(operands[0], channel_id)); break; } case HloOpcode::kGetTupleElement: { - optional<int64> index; - attrs["index"] = {/*required=*/true, AttrTy::kInt64, &index}; + int64 index; if (!ParseOperands(&operands, /*expected_size=*/1) || - !ParseAttributes(attrs)) { + !ParseExtraAttribute(&index, /*expected_attribute=*/"index")) { return false; } instruction = builder->AddInstruction( - HloInstruction::CreateGetTupleElement(shape, operands[0], *index)); + HloInstruction::CreateGetTupleElement(shape, operands[0], index)); break; } case HloOpcode::kCall: { - optional<HloComputation*> to_apply; - attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, - &to_apply}; - if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { - return false; - } - instruction = builder->AddInstruction( - HloInstruction::CreateCall(shape, operands, *to_apply)); - break; - } - case HloOpcode::kReduceWindow: { - optional<HloComputation*> reduce_computation; - optional<Window> window; - attrs["window"] = {/*required=*/true, AttrTy::kWindow, &window}; - attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, - &reduce_computation}; - if (!ParseOperands(&operands, /*expected_size=*/2) || - !ParseAttributes(attrs)) { - return false; - } - instruction = builder->AddInstruction(HloInstruction::CreateReduceWindow( - shape, /*operand=*/operands[0], /*init_value=*/operands[1], *window, - *reduce_computation)); - break; - } - case HloOpcode::kConvolution: { - optional<Window> window; - optional<ConvolutionDimensionNumbers> dnums; - attrs["window"] = {/*required=*/true, AttrTy::kWindow, &window}; - attrs["dim_labels"] = {/*required=*/true, - AttrTy::kConvolutionDimensionNumbers, &dnums}; - if (!ParseOperands(&operands, /*expected_size=*/2) || - !ParseAttributes(attrs)) { - return false; - } - instruction = builder->AddInstruction(HloInstruction::CreateConvolve( - shape, /*lhs=*/operands[0], /*rhs=*/operands[1], *window, *dnums)); - break; - } - case HloOpcode::kBroadcast: { - optional<std::vector<int64>> broadcast_dimensions; - attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, - &broadcast_dimensions}; - if (!ParseOperands(&operands, /*expected_size=*/1) || - !ParseAttributes(attrs)) { - return false; - } - instruction = builder->AddInstruction(HloInstruction::CreateBroadcast( - shape, operands[0], *broadcast_dimensions)); - break; - } - case HloOpcode::kConcatenate: { - optional<std::vector<int64>> dimensions; - attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, - &dimensions}; - if (!ParseOperands(&operands) || !ParseAttributes(attrs) || - dimensions->size() != 1) { - return false; - } - instruction = builder->AddInstruction(HloInstruction::CreateConcatenate( - shape, operands, dimensions->at(0))); - break; - } - case HloOpcode::kMap: { - optional<HloComputation*> to_apply; - attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, - &to_apply}; - if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { - return false; - } - instruction = builder->AddInstruction( - HloInstruction::CreateMap(shape, operands, *to_apply)); - break; - } - case HloOpcode::kReduce: { - optional<HloComputation*> reduce_computation; - attrs["to_apply"] = {/*required=*/true, AttrTy::kHloComputation, - &reduce_computation}; - optional<std::vector<int64>> dimensions_to_reduce; - attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, - &dimensions_to_reduce}; - if (!ParseOperands(&operands, /*expected_size=*/2) || - !ParseAttributes(attrs)) { - return false; - } - instruction = builder->AddInstruction(HloInstruction::CreateReduce( - shape, /*operand=*/operands[0], /*init_value=*/operands[1], - *dimensions_to_reduce, *reduce_computation)); - break; - } - case HloOpcode::kReverse: { - optional<std::vector<int64>> dimensions; - attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, - &dimensions}; - if (!ParseOperands(&operands, /*expected_size=*/1) || - !ParseAttributes(attrs)) { - return false; - } - instruction = builder->AddInstruction( - HloInstruction::CreateReverse(shape, operands[0], *dimensions)); - break; - } - case HloOpcode::kSelectAndScatter: { - optional<HloComputation*> select; - attrs["select"] = {/*required=*/true, AttrTy::kHloComputation, &select}; - optional<HloComputation*> scatter; - attrs["scatter"] = {/*required=*/true, AttrTy::kHloComputation, &scatter}; - optional<Window> window; - attrs["window"] = {/*required=*/true, AttrTy::kWindow, &window}; - if (!ParseOperands(&operands, /*expected_size=*/3) || - !ParseAttributes(attrs)) { - return false; - } - instruction = - builder->AddInstruction(HloInstruction::CreateSelectAndScatter( - shape, /*operand=*/operands[0], *select, *window, - /*source=*/operands[1], /*init_value=*/operands[2], *scatter)); - break; - } - case HloOpcode::kSlice: { - optional<SliceRanges> slice_ranges; - attrs["slice"] = {/*required=*/true, AttrTy::kSliceRanges, &slice_ranges}; - if (!ParseOperands(&operands, /*expected_size=*/1) || - !ParseAttributes(attrs)) { - return false; - } - instruction = builder->AddInstruction(HloInstruction::CreateSlice( - shape, operands[0], slice_ranges->starts, slice_ranges->limits, - slice_ranges->strides)); - break; - } - case HloOpcode::kDynamicSlice: { - optional<std::vector<int64>> dynamic_slice_sizes; - attrs["dynamic_slice_sizes"] = { - /*required=*/true, AttrTy::kBracedInt64List, &dynamic_slice_sizes}; - if (!ParseOperands(&operands, /*expected_size=*/2) || - !ParseAttributes(attrs)) { - return false; - } - instruction = builder->AddInstruction(HloInstruction::CreateDynamicSlice( - shape, /*operand=*/operands[0], /*start_indices=*/operands[1], - *dynamic_slice_sizes)); - break; - } - case HloOpcode::kDynamicUpdateSlice: { - if (!ParseOperands(&operands, /*expected_size=*/3) || - !ParseAttributes(attrs)) { - return false; - } - instruction = - builder->AddInstruction(HloInstruction::CreateDynamicUpdateSlice( - shape, /*operand=*/operands[0], /*update=*/operands[1], - /*start_indices=*/operands[2])); - break; - } - case HloOpcode::kTranspose: { - optional<std::vector<int64>> dimensions; - attrs["dimensions"] = {/*required=*/true, AttrTy::kBracedInt64List, - &dimensions}; - if (!ParseOperands(&operands, /*expected_size=*/1) || - !ParseAttributes(attrs)) { + HloComputation* to_apply; + if (!ParseOperands(&operands) || + !ParseExtraAttribute(&to_apply, + /*expected_attribute=*/"to_apply")) { return false; } instruction = builder->AddInstruction( - HloInstruction::CreateTranspose(shape, operands[0], *dimensions)); - break; - } - case HloOpcode::kBatchNormTraining: { - optional<float> epsilon; - attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon}; - optional<int64> feature_index; - attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64, - &feature_index}; - if (!ParseOperands(&operands, /*expected_size=*/3) || - !ParseAttributes(attrs)) { - return false; - } - instruction = - builder->AddInstruction(HloInstruction::CreateBatchNormTraining( - shape, /*operand=*/operands[0], /*scale=*/operands[1], - /*offset=*/operands[2], *epsilon, *feature_index)); - break; - } - case HloOpcode::kBatchNormInference: { - optional<float> epsilon; - attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon}; - optional<int64> feature_index; - attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64, - &feature_index}; - if (!ParseOperands(&operands, /*expected_size=*/5) || - !ParseAttributes(attrs)) { - return false; - } - instruction = - builder->AddInstruction(HloInstruction::CreateBatchNormInference( - shape, /*operand=*/operands[0], /*scale=*/operands[1], - /*offset=*/operands[2], /*mean=*/operands[3], - /*variance=*/operands[4], *epsilon, *feature_index)); - break; - } - case HloOpcode::kBatchNormGrad: { - optional<float> epsilon; - attrs["epsilon"] = {/*required=*/true, AttrTy::kFloat, &epsilon}; - optional<int64> feature_index; - attrs["feature_index"] = {/*required=*/true, AttrTy::kInt64, - &feature_index}; - if (!ParseOperands(&operands, /*expected_size=*/5) || - !ParseAttributes(attrs)) { - return false; - } - instruction = builder->AddInstruction(HloInstruction::CreateBatchNormGrad( - shape, /*operand=*/operands[0], /*scale=*/operands[1], - /*mean=*/operands[2], /*variance=*/operands[3], - /*grad_output=*/operands[4], *epsilon, *feature_index)); - break; - } - case HloOpcode::kPad: { - optional<PaddingConfig> padding; - attrs["padding"] = {/*required=*/true, AttrTy::kPaddingConfig, &padding}; - if (!ParseOperands(&operands, /*expected_size=*/2) || - !ParseAttributes(attrs)) { - return false; - } - instruction = builder->AddInstruction(HloInstruction::CreatePad( - shape, operands[0], /*padding_value=*/operands[1], *padding)); + HloInstruction::CreateCall(shape, operands, to_apply)); break; } + case HloOpcode::kBroadcast: case HloOpcode::kCustomCall: + case HloOpcode::kConcatenate: case HloOpcode::kReducePrecision: + case HloOpcode::kConvolution: + case HloOpcode::kMap: + case HloOpcode::kPad: + case HloOpcode::kReduce: + case HloOpcode::kReduceWindow: + case HloOpcode::kSelectAndScatter: + case HloOpcode::kReverse: case HloOpcode::kRng: + case HloOpcode::kSlice: + case HloOpcode::kDynamicSlice: + case HloOpcode::kDynamicUpdateSlice: + case HloOpcode::kTranspose: case HloOpcode::kFusion: + case HloOpcode::kBatchNormTraining: + case HloOpcode::kBatchNormInference: case HloOpcode::kInfeed: case HloOpcode::kOutfeed: + case HloOpcode::kBatchNormGrad: case HloOpcode::kTrace: return TokenError(StrCat("parsing not yet implemented for op: ", HloOpcodeString(opcode))); } - // Add common attrs (sharding, control predecessors) to the instruction, if - // they were seen. - if (sharding) { - instruction->set_sharding( - HloSharding::FromProto(sharding.value()).ValueOrDie()); - } - if (predecessors) { - for (auto* pre : *predecessors) { - Status status = pre->AddControlDependencyTo(instruction); - if (!status.ok()) { - return TokenError(StrCat("error adding control dependency for: ", name, - " status: ", status.ToString())); - } + bool has_sharding = false; + bool has_control = false; + while (EatIfPresent(TokKind::kComma)) { + string attribute_name; + if (!ParseAttributeName(&attribute_name)) { + return TokenError("expects ', sharding=' or ', control-predecessors='"); } - } - return AddInstruction(name, instruction); -} - -// ::= '{' (single_sharding | tuple_sharding) '}' -// -// tuple_sharding ::= single_sharding* (',' single_sharding)* -bool HloParser::ParseSharding(OpSharding* sharding) { - // A single sharding starts with '{' and is not followed by '{'. - // A tuple sharding starts with '{' and is followed by '{', or is '{''}' for - // an empty tuple. - if (!ParseToken(TokKind::kLbrace, - "expected '{' to start sharding attribute")) { - return false; - } - if (lexer_.GetKind() != TokKind::kLbrace && - lexer_.GetKind() != TokKind::kRbrace) { - return ParseSingleSharding(sharding, /*lbrace_pre_lexed=*/true); - } - - // Tuple sharding. - // Allow empty tuple shardings. - if (lexer_.GetKind() != TokKind::kRbrace) { - do { - if (!ParseSingleSharding(sharding->add_tuple_shardings(), - /*lbrace_pre_lexed=*/false)) { + if (attribute_name == "sharding") { + // Parse "sharding=". + if (has_sharding) { + return TokenError("expects at most 1 'sharding='"); + } + has_sharding = true; + if (!ParseSharding(instruction)) { return false; } - } while (EatIfPresent(TokKind::kComma)); + } else if (attribute_name == "control-predecessors") { + // Parse "control-predecessors" + if (has_control) { + return TokenError("expects at most 1 'control-predecessors='"); + } + has_control = true; + if (!ParseControlPredecessors(instruction)) { + return false; + } + } else { + return TokenError(StrCat("unexpected attribute: ", attribute_name)); + } } - sharding->set_type(OpSharding::Type::OpSharding_Type_TUPLE); - return ParseToken(TokKind::kRbrace, "expected '}' to end sharding attribute"); + return AddInstruction(name, instruction); } -// ::= '{' 'replicated'? 'maximal'? ('device=' int)? shape? -// ('devices=' ('[' dims ']')* device_list)? '}' -// dims ::= int_list device_list ::= int_list -bool HloParser::ParseSingleSharding(OpSharding* sharding, - bool lbrace_pre_lexed) { - if (!lbrace_pre_lexed && - !ParseToken(TokKind::kLbrace, +// ::= '{' 'replicated'? 'maximal'? ('device=' int)? shape? ('devices=' ('[' +// dims ']')* device_list)? '}' dims ::= int_list device_list ::= int_list +bool HloParser::ParseSharding(HloInstruction* instruction) { + if (!ParseToken(TokKind::kLbrace, "expected '{' to start sharding attribute")) { return false; } @@ -877,6 +545,7 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, } } + OpSharding sharding; if (replicated) { if (!devices.empty()) { return TokenError( @@ -886,7 +555,7 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, return TokenError( "replicated shardings should not have any tile shape set"); } - sharding->set_type(OpSharding::Type::OpSharding_Type_REPLICATED); + sharding.set_type(OpSharding::Type::OpSharding_Type_REPLICATED); } else if (maximal) { if (devices.size() != 1) { return TokenError( @@ -895,8 +564,8 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, if (!ShapeUtil::Equal(tile_shape, Shape())) { return TokenError("maximal shardings should not have any tile shape set"); } - sharding->set_type(OpSharding::Type::OpSharding_Type_MAXIMAL); - sharding->add_tile_assignment_devices(devices[0]); + sharding.set_type(OpSharding::Type::OpSharding_Type_MAXIMAL); + sharding.add_tile_assignment_devices(devices[0]); } else { if (devices.size() <= 1) { return TokenError( @@ -910,43 +579,47 @@ bool HloParser::ParseSingleSharding(OpSharding* sharding, "non-maximal shardings must have a tile assignment list including " "dimensions"); } - sharding->set_type(OpSharding::Type::OpSharding_Type_OTHER); - *sharding->mutable_tile_shape() = tile_shape; + sharding.set_type(OpSharding::Type::OpSharding_Type_OTHER); + *sharding.mutable_tile_shape() = tile_shape; for (int64 dim : tile_assignment_dimensions) { - sharding->add_tile_assignment_dimensions(dim); + sharding.add_tile_assignment_dimensions(dim); } for (int64 device : devices) { - sharding->add_tile_assignment_devices(device); + sharding.add_tile_assignment_devices(device); } } + instruction->set_sharding(HloSharding::FromProto(sharding).ValueOrDie()); lexer_.Lex(); return true; } // '{' name+ '}' -bool HloParser::ParseInstructionNames( - std::vector<HloInstruction*>* instructions) { +bool HloParser::ParseControlPredecessors(HloInstruction* instruction) { if (!ParseToken(TokKind::kLbrace, - "expects '{' at the beginning of instruction name list")) { + "expects '{' at the beginning of control predecessors")) { return false; } do { string name; if (!ParseName(&name)) { - return TokenError("expects a instruction name"); + return TokenError("expects a control predecessor"); } - HloInstruction* instr = + HloInstruction* pre = tensorflow::gtl::FindPtrOrNull(instruction_pool_, name); - if (!instr) { + if (!pre) { return TokenError( - Printf("instruction '%s' is not defined", name.c_str())); + StrCat("control predecessor ", name, " is not defined: ")); + } + Status status = pre->AddControlDependencyTo(instruction); + if (!status.ok()) { + return TokenError(StrCat("error adding control dependency for: ", name, + " status: ", status.ToString())); } - instructions->push_back(instr); } while (EatIfPresent(TokKind::kComma)); return ParseToken(TokKind::kRbrace, - "expects '}' at the end of control instructions"); + "expects '}' at the end of control predecessors"); } bool HloParser::SetValueInLiteral(int64 value, int64 linear_index, @@ -1284,134 +957,28 @@ bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands, return true; } -bool HloParser::ParseAttributes( - const std::unordered_map<string, AttrConfig>& attrs) { - std::unordered_set<string> seen_attrs; - while (EatIfPresent(TokKind::kComma)) { - string name; - if (!ParseAttributeName(&name)) { - return TokenError("error parsing attributes"); - } - VLOG(1) << "Parsing attribute " << name; - if (!seen_attrs.insert(name).second) { - return TokenError(Printf("attribute %s already exists", name.c_str())); - } - auto attr_it = attrs.find(name); - if (attr_it == attrs.end()) { - return TokenError(Printf("unexpected attribute %s", name.c_str())); - } - AttrTy attr_type = attr_it->second.attr_type; - void* attr_out_ptr = attr_it->second.result; - bool success = [&] { - switch (attr_type) { - case AttrTy::kInt64: { - int64 result; - if (!ParseInt64(&result)) { - return false; - } - static_cast<optional<int64>*>(attr_out_ptr)->emplace(result); - return true; - } - case AttrTy::kFloat: { - double result; - if (!ParseDouble(&result)) { - return false; - } - if (result > std::numeric_limits<float>::max() || - result < std::numeric_limits<float>::lowest()) { - return TokenError("value out of range for float"); - } - static_cast<optional<float>*>(attr_out_ptr) - ->emplace(static_cast<float>(result)); - return true; - } - case AttrTy::kHloComputation: { - HloComputation* result; - if (!ParseComputationName(&result)) { - return false; - } - static_cast<optional<HloComputation*>*>(attr_out_ptr) - ->emplace(result); - return true; - } - case AttrTy::kWindow: { - Window result; - if (!ParseWindow(&result)) { - return false; - } - static_cast<optional<Window>*>(attr_out_ptr)->emplace(result); - return true; - } - case AttrTy::kConvolutionDimensionNumbers: { - ConvolutionDimensionNumbers result; - if (!ParseConvolutionDimensionNumbers(&result)) { - return false; - } - static_cast<optional<ConvolutionDimensionNumbers>*>(attr_out_ptr) - ->emplace(result); - return true; - } - case AttrTy::kSharding: { - OpSharding sharding; - if (!ParseSharding(&sharding)) { - return false; - } - static_cast<optional<OpSharding>*>(attr_out_ptr)->emplace(sharding); - return true; - } - case AttrTy::kInstructionList: { - std::vector<HloInstruction*> result; - if (!ParseInstructionNames(&result)) { - return false; - } - static_cast<optional<std::vector<HloInstruction*>>*>(attr_out_ptr) - ->emplace(result); - return true; - } - case AttrTy::kBracedInt64List: { - std::vector<int64> result; - if (!ParseInt64List(TokKind::kLbrace, TokKind::kRbrace, - TokKind::kComma, &result)) { - return false; - } - static_cast<optional<std::vector<int64>>*>(attr_out_ptr) - ->emplace(result); - return true; - } - case AttrTy::kSliceRanges: { - SliceRanges result; - if (!ParseSliceRanges(&result)) { - return false; - } - static_cast<optional<SliceRanges>*>(attr_out_ptr)->emplace(result); - return true; - } - case AttrTy::kPaddingConfig: { - PaddingConfig result; - if (!ParsePaddingConfig(&result)) { - return false; - } - static_cast<optional<PaddingConfig>*>(attr_out_ptr)->emplace(result); - return true; - } - } - }(); - if (!success) { - return TokenError(Printf("error parsing attribute %s", name.c_str())); - } +// extra_attribute ::= ',' attribute_name value +template <typename T> +bool HloParser::ParseExtraAttribute(T* value, + const string& expected_attribute) { + if (!ParseToken(TokKind::kComma, + "expects ',' in front of an extra attribute")) { + return false; } - // Check that all required attrs were seen. - for (const auto& attr_it : attrs) { - if (attr_it.second.required && - seen_attrs.find(attr_it.first) == seen_attrs.end()) { - return TokenError(Printf("attribute %s is expected but not seen", - attr_it.first.c_str())); - } + string attribute_name; + if (!ParseAttributeName(&attribute_name) && + attribute_name != expected_attribute) { + return TokenError(StrCat("expects attribute name: ", expected_attribute)); + } + if (!ParseAttributeValue(value)) { + return TokenError( + StrCat("expects value for attribute: ", expected_attribute)); } return true; } -bool HloParser::ParseComputationName(HloComputation** value) { +template <> +bool HloParser::ParseAttributeValue<HloComputation*>(HloComputation** value) { string name; if (!ParseName(&name)) { return TokenError("expects computation name"); @@ -1423,269 +990,9 @@ bool HloParser::ParseComputationName(HloComputation** value) { return true; } -// ::= '{' size stride? pad? lhs_dilate? rhs_dilate? '}' -// The subattributes can appear in any order. 'size=' is required, others are -// optional. -bool HloParser::ParseWindow(Window* window) { - if (!ParseToken(TokKind::kLbrace, "expected '{' to start window attribute")) { - return false; - } - - std::vector<int64> size; - std::vector<int64> stride; - std::vector<std::vector<int64>> pad; - std::vector<int64> lhs_dilate; - std::vector<int64> rhs_dilate; - while (lexer_.GetKind() != TokKind::kRbrace) { - string field_name; - if (!ParseAttributeName(&field_name)) { - return TokenError("expects sub-attributes in window"); - } - bool ok = [&] { - if (field_name == "size") { - return ParseDxD("size", &size); - } - if (field_name == "stride") { - return ParseDxD("stride", &stride); - } - if (field_name == "lhs_dilate") { - return ParseDxD("lhs_dilate", &lhs_dilate); - } - if (field_name == "rhs_dilate") { - return ParseDxD("rls_dilate", &rhs_dilate); - } - if (field_name == "pad") { - return ParseWindowPad(&pad); - } - return TokenError(StrCat("unexpected attribute name: ", field_name)); - }(); - if (!ok) { - return false; - } - } - - if (size.empty()) { - return TokenError( - "sub-attribute 'size=' is required in the window attribute"); - } - if (!stride.empty() && stride.size() != size.size()) { - return TokenError("expects 'stride=' has the same size as 'size='"); - } - if (!lhs_dilate.empty() && lhs_dilate.size() != size.size()) { - return TokenError("expects 'lhs_dilate=' has the same size as 'size='"); - } - if (!rhs_dilate.empty() && rhs_dilate.size() != size.size()) { - return TokenError("expects 'rhs_dilate=' has the same size as 'size='"); - } - if (!pad.empty() && pad.size() != size.size()) { - return TokenError("expects 'pad=' has the same size as 'size='"); - } - - for (int i = 0; i < size.size(); i++) { - window->add_dimensions()->set_size(size[i]); - if (!pad.empty()) { - window->mutable_dimensions(i)->set_padding_low(pad[i][0]); - window->mutable_dimensions(i)->set_padding_high(pad[i][1]); - } - // If some field is not present, it has the default value. - window->mutable_dimensions(i)->set_stride(stride.empty() ? 1 : stride[i]); - window->mutable_dimensions(i)->set_base_dilation( - lhs_dilate.empty() ? 1 : lhs_dilate[i]); - window->mutable_dimensions(i)->set_window_dilation( - rhs_dilate.empty() ? 1 : rhs_dilate[i]); - } - return ParseToken(TokKind::kRbrace, "expected '}' to end window attribute"); -} - -// This is the inverse of HloInstruction::ConvolutionDimensionNumbersToString. -// The string looks like "dim_labels=0bf_0io->0bf". -bool HloParser::ParseConvolutionDimensionNumbers( - ConvolutionDimensionNumbers* dnums) { - if (lexer_.GetKind() != TokKind::kDimLabels) { - return TokenError("expects dim labels pattern, e.g., 'bf0_0io->0bf'"); - } - string str = lexer_.GetStrVal(); - - // The str is expected to have 3 items, lhs, rhs, out, and it must looks like - // lhs_rhs->out, that is, the first separator is "_" and the second is "->". - // So we replace the "->" with "_" and then split on "_". - str = tensorflow::str_util::StringReplace(str, /*oldsub=*/"->", - /*newsub=*/"_", - /*replace_all=*/false); - std::vector<string> lhs_rhs_out = Split(str, "_"); - if (lhs_rhs_out.size() != 3) { - LOG(FATAL) << "expects 3 items: lhs, rhs, and output dims, but sees " - << str; - } - - const int64 rank = lhs_rhs_out[0].length(); - if (rank != lhs_rhs_out[1].length() || rank != lhs_rhs_out[2].length()) { - return TokenError( - "convolution lhs, rhs, and output must have the same rank"); - } - if (rank < 3) { - return TokenError("convolution rank must >=3"); - } - - auto is_unique = [](string str) -> bool { - std::sort(str.begin(), str.end()); - return std::unique(str.begin(), str.end()) == str.end(); - }; - - // lhs - { - const string& lhs = lhs_rhs_out[0]; - if (!is_unique(lhs)) { - return TokenError( - StrCat("expects unique lhs dimension numbers, but sees ", lhs)); - } - for (int i = 0; i < rank - 2; i++) { - dnums->add_spatial_dimensions(-1); - } - for (int i = 0; i < rank; i++) { - char c = lhs[i]; - if (c == 'b') { - dnums->set_input_batch_dimension(i); - } else if (c == 'f') { - dnums->set_input_feature_dimension(i); - } else if (c < '0' + rank && c >= '0') { - dnums->set_spatial_dimensions(c - '0', i); - } else { - return TokenError( - Printf("expects [0-%lldbf] in lhs dimension numbers", rank - 1)); - } - } - } - // rhs - { - const string& rhs = lhs_rhs_out[1]; - if (!is_unique(rhs)) { - return TokenError( - StrCat("expects unique rhs dimension numbers, but sees ", rhs)); - } - for (int i = 0; i < rank - 2; i++) { - dnums->add_kernel_spatial_dimensions(-1); - } - for (int i = 0; i < rank; i++) { - char c = rhs[i]; - if (c == 'i') { - dnums->set_kernel_input_feature_dimension(i); - } else if (c == 'o') { - dnums->set_kernel_output_feature_dimension(i); - } else if (c < '0' + rank && c >= '0') { - dnums->set_kernel_spatial_dimensions(c - '0', i); - } else { - return TokenError( - Printf("expects [0-%lldio] in rhs dimension numbers", rank - 1)); - } - } - } - // output - { - const string& out = lhs_rhs_out[2]; - if (!is_unique(out)) { - return TokenError( - StrCat("expects unique output dimension numbers, but sees ", out)); - } - for (int i = 0; i < rank; i++) { - char c = out[i]; - if (c == 'b') { - dnums->set_output_batch_dimension(i); - } else if (c == 'f') { - dnums->set_output_feature_dimension(i); - } else if (c < '0' + rank && c >= '0') { - if (dnums->spatial_dimensions(c - '0') != i) { - return TokenError( - "output spatial dimensions should be the same as input spatial " - "dimensions"); - } - } else { - return TokenError( - Printf("expects [0-%lldbf] in output dimension numbers", rank - 1)); - } - } - } - - lexer_.Lex(); - return true; -} - -// ::= '{' ranges '}' -// ::= /*empty*/ -// ::= range (',' range)* -// range ::= '[' start ':' limit (':' stride)? ']' -// -// The slice ranges are printed as: -// -// {[dim0_start:dim0_limit:dim0stride], [dim1_start:dim1_limit], ...} -// -// This function extracts the starts, limits, and strides as 3 vectors to the -// result. If stride is not present, stride is 1. For example, if the slice -// ranges is printed as: -// -// {[2:3:4], [5:6:7], [8:9]} -// -// The the parsed result will be: -// -// {/*starts=*/{2, 5, 8}, /*limits=*/{3, 6, 9}, /*strides=*/{4, 7, 1}} -// -bool HloParser::ParseSliceRanges(SliceRanges* result) { - if (!ParseToken(TokKind::kLbrace, "expects '{' to start ranges")) { - return false; - } - std::vector<std::vector<int64>> ranges; - if (lexer_.GetKind() == TokKind::kRbrace) { - // empty - return ParseToken(TokKind::kRbrace, "expects '}' to end ranges"); - } - do { - ranges.emplace_back(); - if (!ParseInt64List(TokKind::kLsquare, TokKind::kRsquare, TokKind::kColon, - &ranges.back())) { - return false; - } - } while (EatIfPresent(TokKind::kComma)); - - for (const auto& range : ranges) { - if (range.size() != 2 && range.size() != 3) { - return TokenError(Printf( - "expects [start:limit:step] or [start:limit], but sees %ld elements.", - range.size())); - } - } - - for (const auto& range : ranges) { - result->starts.push_back(range[0]); - result->limits.push_back(range[1]); - result->strides.push_back(range.size() == 3 ? range[2] : 1); - } - return ParseToken(TokKind::kRbrace, "expects '}' to end ranges"); -} - -// int64list ::= start int64_elements end -// int64_elements -// ::= /*empty*/ -// ::= int64_val (delim int64_val)* -bool HloParser::ParseInt64List(const TokKind start, const TokKind end, - const TokKind delim, - std::vector<int64>* result) { - if (!ParseToken(start, StrCat("expects an int64 list starting with ", - TokKindToString(start)))) { - return false; - } - if (lexer_.GetKind() == end) { - // empty - } else { - do { - int64 i; - if (!ParseInt64(&i)) { - return false; - } - result->push_back(i); - } while (EatIfPresent(delim)); - } - return ParseToken( - end, StrCat("expects an int64 list to end with ", TokKindToString(end))); +template <> +bool HloParser::ParseAttributeValue<int64>(int64* value) { + return ParseInt64(value); } // param_list ::= '(' param_list1 ')' @@ -1763,82 +1070,6 @@ bool HloParser::ParseAttributeName(string* result) { return true; } -bool HloParser::ParseDxD(const string& name, std::vector<int64>* result) { - if (!result->empty()) { - return TokenError( - Printf("sub-attribute '%s=' already exists", name.c_str())); - } - // 1D - if (lexer_.GetKind() == TokKind::kInt) { - int64 number; - if (!ParseInt64(&number)) { - return TokenError(Printf("expects sub-attribute '%s=i'", name.c_str())); - } - result->push_back(number); - return true; - } - // 2D or higher. - if (lexer_.GetKind() == TokKind::kDxD) { - string str = lexer_.GetStrVal(); - if (!SplitAndParseAsInts(str, 'x', result)) { - return TokenError( - Printf("expects sub-attribute '%s=ixj...'", name.c_str())); - } - lexer_.Lex(); - return true; - } - return TokenError("expects token type kInt or kDxD"); -} - -bool HloParser::ParseWindowPad(std::vector<std::vector<int64>>* pad) { - if (!pad->empty()) { - return TokenError("sub-attribute 'pad=' already exists"); - } - if (lexer_.GetKind() != TokKind::kPad) { - return TokenError("expects window pad pattern, e.g., '0_0x3_3'"); - } - string str = lexer_.GetStrVal(); - std::vector<string> padding_str = Split(str, 'x'); - for (int i = 0; i < padding_str.size(); i++) { - std::vector<int64> low_high; - if (!SplitAndParseAsInts(padding_str[i], '_', &low_high) || - low_high.size() != 2) { - return TokenError( - "expects padding_low and padding_high separated by '_'"); - } - pad->push_back(low_high); - } - lexer_.Lex(); - return true; -} - -// This is the inverse xla::ToString(PaddingConfig). The padding config string -// looks like "0_0_0x3_3_1". The string is first separated by 'x', each -// substring represents one PaddingConfigDimension. The substring is 3 (or 2) -// numbers joined by '_'. -bool HloParser::ParsePaddingConfig(PaddingConfig* padding) { - if (lexer_.GetKind() != TokKind::kPad) { - return TokenError("expects padding config, e.g., '0_0_0x3_3_1'"); - } - string str = lexer_.GetStrVal(); - std::vector<string> padding_str = Split(str, 'x'); - for (const auto& padding_dim_str : padding_str) { - std::vector<int64> padding_dim; - if (!SplitAndParseAsInts(padding_dim_str, '_', &padding_dim) || - (padding_dim.size() != 2 && padding_dim.size() != 3)) { - return TokenError( - "expects padding config pattern like 'low_high_interior' or " - "'low_high'"); - } - auto* dim = padding->add_dimensions(); - dim->set_edge_padding_low(padding_dim[0]); - dim->set_edge_padding_high(padding_dim[1]); - dim->set_interior_padding(padding_dim.size() == 3 ? padding_dim[2] : 0); - } - lexer_.Lex(); - return true; -} - bool HloParser::ParseOpcode(HloOpcode* result) { VLOG(1) << "ParseOpcode"; if (lexer_.GetKind() != TokKind::kOpcode) { |