diff options
author | Chris Leary <leary@google.com> | 2018-08-21 17:41:59 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-21 17:45:23 -0700 |
commit | 274bebe3a69ff0af96c721cb9b7cfbc9cd87679d (patch) | |
tree | 72d9ce1bc954ae8dc7cf55c22beeac1e730eaeec /tensorflow/compiler/xla/service/hlo_parser.cc | |
parent | 9ed82b760567d8b4543045603546299e0dd2ae8a (diff) |
[XLA] Parse a single HLO instruction's text into a module.
Converts "free variable" operands into parameters.
PiperOrigin-RevId: 209691261
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_parser.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_parser.cc | 96 |
1 files changed, 88 insertions, 8 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 44180a881e..b4793998ec 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -33,6 +33,7 @@ namespace xla { namespace { +using ::absl::nullopt; using ::absl::optional; using ::tensorflow::StringPiece; using ::tensorflow::str_util::Join; @@ -66,7 +67,21 @@ class HloParser { StatusOr<Window> ParseWindowOnly(); StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbersOnly(); + // Stand-alone parsing utility for a single instruction worth of text. + Status ParseSingleInstruction(HloComputation::Builder* builder, + string* root_name); + private: + // Locates an instruction with the given name in the instruction_pool_ or + // returns nullptr. + // + // If the missing_instruction_hook_ is registered and a "shape" is provided, + // the hook will be called and may satisfy the request for the given + // instruction. This is useful when we reify parameters as they're resolved; + // i.e. for ParseSingleInstruction. + std::pair<HloInstruction*, LocTy>* FindInstruction( + const string& name, const optional<Shape>& shape = nullopt); + // ParseXXX returns false if an error occurred. bool ParseHloModule(); bool ParseComputations(); @@ -267,6 +282,12 @@ class HloParser { std::vector<std::unique_ptr<HloComputation>> computations_; const HloModuleConfig config_; std::vector<string> error_; + + // Function that gets invoked when we try to resolve an instruction + // instruction_pool_ but fail to do so. + std::function<std::pair<HloInstruction*, LocTy>*(string, + const optional<Shape>&)> + missing_instruction_hook_; }; bool HloParser::Error(LocTy loc, StringPiece msg) { @@ -293,6 +314,17 @@ bool HloParser::Run() { return ParseHloModule(); } +std::pair<HloInstruction*, HloParser::LocTy>* HloParser::FindInstruction( + const string& name, const optional<Shape>& shape) { + std::pair<HloInstruction*, LocTy>* instr = + tensorflow::gtl::FindOrNull(instruction_pool_, name); + // Potentially call the missing instruction hook. + if (instr == nullptr && missing_instruction_hook_ != nullptr) { + return missing_instruction_hook_(name, shape); + } + return instr; +} + // ::= 'HloModule' name computations bool HloParser::ParseHloModule() { if (lexer_.GetKind() != TokKind::kw_HloModule) { @@ -372,8 +404,7 @@ bool HloParser::ParseComputation(HloComputation** entry_computation) { return false; } - std::pair<HloInstruction*, LocTy>* root_node = - tensorflow::gtl::FindOrNull(instruction_pool_, root_name); + std::pair<HloInstruction*, LocTy>* root_node = FindInstruction(root_name); // This means some instruction was marked as ROOT but we didn't find it in the // pool, which should not happen. if (!root_name.empty() && root_node == nullptr) { @@ -1525,8 +1556,7 @@ bool HloParser::ParseInstructionNames( if (!ParseName(&name)) { return Error(loc, "expects a instruction name"); } - std::pair<HloInstruction*, LocTy>* instr = - tensorflow::gtl::FindOrNull(instruction_pool_, name); + std::pair<HloInstruction*, LocTy>* instr = FindInstruction(name); if (!instr) { return TokenError( Printf("instruction '%s' is not defined", name.c_str())); @@ -2009,6 +2039,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal, // ::= operand (, operand)* // operand ::= (shape)? name bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands) { + CHECK(operands != nullptr); if (!ParseToken(TokKind::kLparen, "expects '(' at the beginning of operands")) { return false; @@ -2019,9 +2050,10 @@ bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands) { do { LocTy loc = lexer_.GetLoc(); string name; + optional<Shape> shape; if (CanBeShape()) { - Shape shape; - if (!ParseShape(&shape)) { + shape.emplace(); + if (!ParseShape(&shape.value())) { return false; } } @@ -2029,8 +2061,8 @@ bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands) { return false; } std::pair<HloInstruction*, LocTy>* instruction = - tensorflow::gtl::FindOrNull(instruction_pool_, name); - if (!instruction) { + FindInstruction(name, shape); + if (instruction == nullptr) { return Error(loc, StrCat("instruction does not exist: ", name)); } operands->push_back(instruction->first); @@ -2041,6 +2073,7 @@ bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands) { bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands, const int expected_size) { + CHECK(operands != nullptr); LocTy loc = lexer_.GetLoc(); if (!ParseOperands(operands)) { return false; @@ -3029,6 +3062,40 @@ HloParser::ParseConvolutionDimensionNumbersOnly() { return dnums; } +Status HloParser::ParseSingleInstruction(HloComputation::Builder* builder, + string* root_name) { + TF_RET_CHECK(missing_instruction_hook_ == nullptr); + + // The missing instruction hook we register creates the shaped instruction on + // the fly as a parameter and returns it. + int64 parameter_count = 0; + missing_instruction_hook_ = + [this, builder, ¶meter_count]( + string name, + const optional<Shape>& shape) -> std::pair<HloInstruction*, LocTy>* { + if (!shape.has_value()) { + Error(lexer_.GetLoc(), + StrCat("Operand ", name, + " had no shape in HLO text; cannot create parameter for " + "single-instruction module.")); + return nullptr; + } + HloInstruction* parameter = builder->AddInstruction( + HloInstruction::CreateParameter(parameter_count++, *shape, name)); + instruction_pool_[name] = {parameter, lexer_.GetLoc()}; + return tensorflow::gtl::FindOrNull(instruction_pool_, name); + }; + + // Prime the lexer. + lexer_.Lex(); + + // Parse the instruction with the registered hook. + if (!ParseInstruction(builder, root_name)) { + return InvalidArgument("Syntax error:\n%s", GetError().c_str()); + } + return Status::OK(); +} + } // namespace StatusOr<std::unique_ptr<HloModule>> ParseHloString( @@ -3046,6 +3113,19 @@ StatusOr<std::unique_ptr<HloModule>> ParseHloString( return ParseHloString(str, config); } +StatusOr<std::unique_ptr<HloModule>> ParseHloOpToModule( + tensorflow::StringPiece str, tensorflow::StringPiece name) { + HloModuleConfig config; + HloParser parser(str, config); + auto builder = absl::make_unique<HloComputation::Builder>(name.ToString()); + string root_name; + TF_RETURN_IF_ERROR(parser.ParseSingleInstruction(builder.get(), &root_name)); + std::unique_ptr<HloComputation> computation = builder->Build(); + auto module = absl::make_unique<HloModule>(name.ToString(), config); + module->AddEntryComputation(std::move(computation)); + return std::move(module); +} + StatusOr<HloSharding> ParseSharding(tensorflow::StringPiece str) { HloModuleConfig config; HloParser parser(str, config); |