diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-10-02 15:08:52 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-02 15:20:19 -0700 |
commit | bb84d5d5e309204110315f7d0ff8ca0dbb022dd2 (patch) | |
tree | 351ff0255434d39238315db811581cde201380c2 /tensorflow/compiler/xla/service/hlo_parser.cc | |
parent | cfec3aa38db1d2b70045e7b89d82fae87c3fec02 (diff) |
[XLA] Support parsing the canonical format of HLO text.
Also stop truncating operands in the canonical format.
PiperOrigin-RevId: 215466465
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_parser.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_parser.cc | 276 |
1 files changed, 181 insertions, 95 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 25b70740e3..5a125b4c08 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -80,17 +80,23 @@ class HloParser { StatusOr<PaddingConfig> ParsePaddingConfigOnly(); // Stand-alone parsing utility for a single instruction worth of text. - Status ParseSingleInstruction(HloComputation::Builder* builder, - string* root_name); + Status ParseSingleInstruction(HloModule* module); private: - // Locates an instruction with the given name in the instruction_pool_ or + using InstrNameTable = + std::unordered_map<string, std::pair<HloInstruction*, LocTy>>; + + // Returns the map from the instruction name to the instruction itself and its + // location in the current scope. + InstrNameTable& current_name_table() { return scoped_name_tables_.back(); } + + // Locates an instruction with the given name in the current_name_table() or // returns nullptr. // - // If the missing_instruction_hook_ is registered and a "shape" is provided, - // the hook will be called and may satisfy the request for the given - // instruction. This is useful when we reify parameters as they're resolved; - // i.e. for ParseSingleInstruction. + // When the name is not found or name is empty, if create_missing_instruction_ + // hook is registered and a "shape" is provided, the hook will be called to + // create an instruction. This is useful when we reify parameters as they're + // resolved; i.e. for ParseSingleInstruction. std::pair<HloInstruction*, LocTy>* FindInstruction( const string& name, const optional<Shape>& shape = nullopt); @@ -98,9 +104,11 @@ class HloParser { bool ParseHloModule(HloModule* module); bool ParseComputations(HloModule* module); bool ParseComputation(HloComputation** entry_computation); - bool ParseInstructionList(HloComputation::Builder* builder, - string* root_name); + bool ParseInstructionList(HloComputation** computation, + const string& computation_name); bool ParseInstruction(HloComputation::Builder* builder, string* root_name); + bool ParseInstruciontRhs(HloComputation::Builder* builder, const string& name, + LocTy name_loc); bool ParseControlPredecessors(HloInstruction* instruction); bool ParseLiteral(Literal* literal, const Shape& shape); bool ParseTupleLiteral(Literal* literal, const Shape& shape); @@ -281,23 +289,47 @@ class HloParser { bool AddComputation(const string& name, HloComputation* computation, LocTy name_loc); - // The map from the instruction/computation name to the - // instruction/computation itself and it's location. This does not own the - // pointers. - std::unordered_map<string, std::pair<HloInstruction*, LocTy>> - instruction_pool_; + HloLexer lexer_; + + // A stack for the instruction names. The top of the stack stores the + // instruction name table for the current scope. + // + // A instruction's name is unique among its scope (i.e. its parent + // computation), but it's not necessarily unique among all computations in the + // module. When there are multiple levels of nested computations, the same + // name could appear in both an outer computation and an inner computation. So + // we need a stack to make sure a name is only visible within its scope, + std::vector<InstrNameTable> scoped_name_tables_; + + // A helper class which pushes and pops to an InstrNameTable stack via RAII. + class Scope { + public: + explicit Scope(std::vector<InstrNameTable>* scoped_name_tables) + : scoped_name_tables_(scoped_name_tables) { + scoped_name_tables_->emplace_back(); + } + ~Scope() { scoped_name_tables_->pop_back(); } + + private: + std::vector<InstrNameTable>* scoped_name_tables_; + }; + + // Map from the computation name to the computation itself and its location. std::unordered_map<string, std::pair<HloComputation*, LocTy>> computation_pool_; - HloLexer lexer_; std::vector<std::unique_ptr<HloComputation>> computations_; std::vector<string> error_; - // Function that gets invoked when we try to resolve an instruction - // instruction_pool_ but fail to do so. - std::function<std::pair<HloInstruction*, LocTy>*(string, - const optional<Shape>&)> - missing_instruction_hook_; + // When an operand name cannot be resolved, this function is called to create + // a parameter instruction with the given name and shape. It registers the + // name, instruction, and a placeholder location in the name table. It returns + // the newly-created instruction and the placeholder location. If `name` is + // empty, this should create the parameter with a generated name. This is + // supposed to be set and used only in ParseSingleInstruction. + std::function<std::pair<HloInstruction*, LocTy>*(const string& name, + const Shape& shape)> + create_missing_instruction_; }; bool SplitToInt64s(absl::string_view s, char delim, std::vector<int64>* out) { @@ -351,11 +383,21 @@ bool HloParser::Run(HloModule* module) { std::pair<HloInstruction*, HloParser::LocTy>* HloParser::FindInstruction( const string& name, const optional<Shape>& shape) { - std::pair<HloInstruction*, LocTy>* instr = - tensorflow::gtl::FindOrNull(instruction_pool_, name); + std::pair<HloInstruction*, LocTy>* instr = nullptr; + if (!name.empty()) { + instr = tensorflow::gtl::FindOrNull(current_name_table(), name); + } + // Potentially call the missing instruction hook. - if (instr == nullptr && missing_instruction_hook_ != nullptr) { - return missing_instruction_hook_(name, shape); + if (instr == nullptr && create_missing_instruction_ != nullptr && + scoped_name_tables_.size() == 1) { + if (!shape.has_value()) { + Error(lexer_.GetLoc(), + "Operand had no shape in HLO text; cannot create parameter for " + "single-instruction module."); + return nullptr; + } + return create_missing_instruction_(name, *shape); } return instr; } @@ -439,7 +481,6 @@ bool HloParser::ParseComputation(HloComputation** entry_computation) { if (!ParseName(&name)) { return false; } - auto builder = absl::make_unique<HloComputation::Builder>(name); LocTy shape_loc = nullptr; Shape shape; @@ -447,40 +488,21 @@ bool HloParser::ParseComputation(HloComputation** entry_computation) { return false; } - string root_name; - if (!ParseInstructionList(builder.get(), &root_name)) { + HloComputation* computation = nullptr; + if (!ParseInstructionList(&computation, name)) { return false; } - std::pair<HloInstruction*, LocTy>* root_node = FindInstruction(root_name); - // This means some instruction was marked as ROOT but we didn't find it in the - // pool, which should not happen. - if (!root_name.empty() && root_node == nullptr) { - LOG(FATAL) << "instruction " << root_name - << " was marked as ROOT but the parser has not seen it before"; - } - - HloInstruction* root = root_node == nullptr ? nullptr : root_node->first; - // Now root can be either an existing instruction or a nullptr. If it's a - // nullptr, the implementation of Builder will set the last instruction as - // root instruction. - computations_.emplace_back(builder->Build(root)); - HloComputation* computation = computations_.back().get(); - - if (!root) { - root = computation->root_instruction(); - } else { - CHECK_EQ(root, computation->root_instruction()); - } - // If param_list_to_shape was present, check compatibility. - if (shape_loc != nullptr && !ShapeUtil::Compatible(root->shape(), shape)) { + if (shape_loc != nullptr && + !ShapeUtil::Compatible(computation->root_instruction()->shape(), shape)) { return Error( shape_loc, - StrCat("Shape of computation ", name, ", ", - ShapeUtil::HumanString(shape), - ", is not compatible with that of its root instruction ", - root_name, ", ", ShapeUtil::HumanString(root->shape()))); + StrCat( + "Shape of computation ", name, ", ", ShapeUtil::HumanString(shape), + ", is not compatible with that of its root instruction ", + computation->root_instruction()->name(), ", ", + ShapeUtil::HumanString(computation->root_instruction()->shape()))); } if (is_entry_computation) { @@ -489,43 +511,62 @@ bool HloParser::ParseComputation(HloComputation** entry_computation) { } *entry_computation = computation; } - instruction_pool_.clear(); return AddComputation(name, computation, name_loc); } // instruction_list ::= '{' instruction_list1 '}' // instruction_list1 ::= (instruction)+ -bool HloParser::ParseInstructionList(HloComputation::Builder* builder, - string* root_name) { +bool HloParser::ParseInstructionList(HloComputation** computation, + const string& computation_name) { + Scope scope(&scoped_name_tables_); + HloComputation::Builder builder(computation_name); if (!ParseToken(TokKind::kLbrace, "expects '{' at the beginning of instruction list.")) { return false; } + string root_name; do { - if (!ParseInstruction(builder, root_name)) { + if (!ParseInstruction(&builder, &root_name)) { return false; } } while (lexer_.GetKind() != TokKind::kRbrace); - return ParseToken(TokKind::kRbrace, - "expects '}' at the end of instruction list."); + if (!ParseToken(TokKind::kRbrace, + "expects '}' at the end of instruction list.")) { + return false; + } + HloInstruction* root = nullptr; + if (!root_name.empty()) { + std::pair<HloInstruction*, LocTy>* root_node = + tensorflow::gtl::FindOrNull(current_name_table(), root_name); + + // This means some instruction was marked as ROOT but we didn't find it in + // the pool, which should not happen. + if (root_node == nullptr) { + LOG(FATAL) << "instruction " << root_name + << " was marked as ROOT but the parser has not seen it before"; + } + root = root_node->first; + } + + // Now root can be either an existing instruction or a nullptr. If it's a + // nullptr, the implementation of Builder will set the last instruction as + // the root instruction. + computations_.emplace_back(builder.Build(root)); + *computation = computations_.back().get(); + return true; } // instruction ::= ('ROOT')? name '=' shape opcode operands (attribute)* bool HloParser::ParseInstruction(HloComputation::Builder* builder, string* root_name) { string name; - Shape shape; - HloOpcode opcode; - std::vector<HloInstruction*> operands; - LocTy maybe_root_loc = lexer_.GetLoc(); bool is_root = EatIfPresent(TokKind::kw_ROOT); const LocTy name_loc = lexer_.GetLoc(); if (!ParseName(&name) || - !ParseToken(TokKind::kEqual, "expects '=' in instruction") || - !ParseShape(&shape) || !ParseOpcode(&opcode)) { + !ParseToken(TokKind::kEqual, "expects '=' in instruction")) { return false; } @@ -536,6 +577,19 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, *root_name = name; } + return ParseInstruciontRhs(builder, name, name_loc); +} + +bool HloParser::ParseInstruciontRhs(HloComputation::Builder* builder, + const string& name, LocTy name_loc) { + Shape shape; + HloOpcode opcode; + std::vector<HloInstruction*> operands; + + if (!ParseShape(&shape) || !ParseOpcode(&opcode)) { + return false; + } + // Add optional attributes. std::unordered_map<string, AttrConfig> attrs; optional<OpSharding> sharding; @@ -2146,7 +2200,20 @@ bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands) { } } if (!ParseName(&name)) { - return false; + // When parsing a single instruction (as opposed to a whole module), an + // HLO may have one or more operands with a shape but no name: + // + // foo = add(f32[10], f32[10]) + // + // create_missing_instruction_ is always non-null when parsing a single + // instruction, and is responsible for creating kParameter instructions + // for these operands. + if (shape.has_value() && create_missing_instruction_ != nullptr && + scoped_name_tables_.size() == 1) { + name = ""; + } else { + return false; + } } std::pair<HloInstruction*, LocTy>* instruction = FindInstruction(name, shape); @@ -2299,9 +2366,17 @@ bool HloParser::ParseAttributeHelper( return true; } case AttrTy::kHloComputation: { - HloComputation* result; - if (!ParseComputationName(&result)) { - return false; + HloComputation* result = nullptr; + if (lexer_.GetKind() == TokKind::kLbrace) { + // This means it is a nested computation. + if (!ParseInstructionList(&result, /*computation_name=*/"_")) { + return false; + } + } else { + // This means it is a computation name. + if (!ParseComputationName(&result)) { + return false; + } } static_cast<optional<HloComputation*>*>(attr_out_ptr)->emplace(result); return true; @@ -3134,7 +3209,7 @@ bool HloParser::EatIfPresent(TokKind kind) { bool HloParser::AddInstruction(const string& name, HloInstruction* instruction, LocTy name_loc) { - auto result = instruction_pool_.insert({name, {instruction, name_loc}}); + auto result = current_name_table().insert({name, {instruction, name_loc}}); if (!result.second) { Error(name_loc, StrCat("instruction already exists: ", name)); return Error(/*loc=*/result.first->second.second, @@ -3204,36 +3279,51 @@ StatusOr<PaddingConfig> HloParser::ParsePaddingConfigOnly() { return padding_config; } -Status HloParser::ParseSingleInstruction(HloComputation::Builder* builder, - string* root_name) { - TF_RET_CHECK(missing_instruction_hook_ == nullptr); +Status HloParser::ParseSingleInstruction(HloModule* module) { + TF_RET_CHECK(create_missing_instruction_ == nullptr); + TF_RET_CHECK(scoped_name_tables_.empty()); + HloComputation::Builder builder(module->name()); // The missing instruction hook we register creates the shaped instruction on // the fly as a parameter and returns it. int64 parameter_count = 0; - missing_instruction_hook_ = - [this, builder, ¶meter_count]( - string name, - const optional<Shape>& shape) -> std::pair<HloInstruction*, LocTy>* { - if (!shape.has_value()) { - Error(lexer_.GetLoc(), - StrCat("Operand ", name, - " had no shape in HLO text; cannot create parameter for " - "single-instruction module.")); - return nullptr; - } - HloInstruction* parameter = builder->AddInstruction( - HloInstruction::CreateParameter(parameter_count++, *shape, name)); - instruction_pool_[name] = {parameter, lexer_.GetLoc()}; - return tensorflow::gtl::FindOrNull(instruction_pool_, name); + create_missing_instruction_ = + [this, &builder, ¶meter_count]( + const string& name, + const Shape& shape) -> std::pair<HloInstruction*, LocTy>* { + string new_name = name.empty() ? StrCat("_", parameter_count) : name; + HloInstruction* parameter = builder.AddInstruction( + HloInstruction::CreateParameter(parameter_count++, shape, new_name)); + current_name_table()[new_name] = {parameter, lexer_.GetLoc()}; + return tensorflow::gtl::FindOrNull(current_name_table(), new_name); }; // Prime the lexer. lexer_.Lex(); // Parse the instruction with the registered hook. - if (!ParseInstruction(builder, root_name)) { - return InvalidArgument("Syntax error:\n%s", GetError()); + Scope scope(&scoped_name_tables_); + if (CanBeShape()) { + // This means that the instruction's left-hand side is probably omitted, + // e.g. + // + // f32[10] fusion(...), calls={...} + if (!ParseInstruciontRhs(&builder, module->name(), lexer_.GetLoc())) { + return InvalidArgument("Syntax error:\n%s", GetError()); + } + } else { + // This means that the instruction's left-hand side might exist, e.g. + // + // foo = f32[10] fusion(...), calls={...} + string root_name; + if (!ParseInstruction(&builder, &root_name)) { + return InvalidArgument("Syntax error:\n%s", GetError()); + } + } + + module->AddEntryComputation(builder.Build()); + for (auto& comp : computations_) { + module->AddEmbeddedComputation(std::move(comp)); } return Status::OK(); } @@ -3271,12 +3361,8 @@ Status ParseHloString(absl::string_view str, HloModule* module) { StatusOr<std::unique_ptr<HloModule>> ParseHloOpToModule( absl::string_view str, absl::string_view name) { HloParser parser(str); - auto builder = absl::make_unique<HloComputation::Builder>(string(name)); - string root_name; - TF_RETURN_IF_ERROR(parser.ParseSingleInstruction(builder.get(), &root_name)); - std::unique_ptr<HloComputation> computation = builder->Build(); auto module = absl::make_unique<HloModule>(string(name), HloModuleConfig()); - module->AddEntryComputation(std::move(computation)); + TF_RETURN_IF_ERROR(parser.ParseSingleInstruction(module.get())); return std::move(module); } |