diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-10-02 19:28:27 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-02 19:32:18 -0700 |
commit | fa61b939bec50d731b86f40c79054503d629e29b (patch) | |
tree | 17a6bcedbe4878fc81014a0c3a2f77579ecb6241 /tensorflow/compiler/xla/service | |
parent | 8dc7bc7764150253c03a666eee84fc48f867d6a2 (diff) |
[XLA] Merge the single instruction parsing and the full module parsing in one function.
PiperOrigin-RevId: 215501702
Diffstat (limited to 'tensorflow/compiler/xla/service')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_parser.cc | 66 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_parser.h | 6 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_parser_test.cc | 22 |
3 files changed, 45 insertions, 49 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 5a125b4c08..0440f1b54f 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -68,7 +68,7 @@ class HloParser { // Runs the parser and constructs the resulting HLO in the given (empty) // HloModule. Returns false if an error occurred. - bool Run(HloModule* module); + Status Run(HloModule* module); // Returns the error information. string GetError() const { return StrJoin(error_, "\n"); } @@ -79,9 +79,6 @@ class HloParser { StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbersOnly(); StatusOr<PaddingConfig> ParsePaddingConfigOnly(); - // Stand-alone parsing utility for a single instruction worth of text. - Status ParseSingleInstruction(HloModule* module); - private: using InstrNameTable = std::unordered_map<string, std::pair<HloInstruction*, LocTy>>; @@ -100,8 +97,12 @@ class HloParser { std::pair<HloInstruction*, LocTy>* FindInstruction( const string& name, const optional<Shape>& shape = nullopt); + // Parse a single instruction worth of text. + bool ParseSingleInstruction(HloModule* module); + // ParseXXX returns false if an error occurred. bool ParseHloModule(HloModule* module); + bool ParseComputations(HloModule* module); bool ParseComputation(HloComputation** entry_computation); bool ParseInstructionList(HloComputation** computation, @@ -376,9 +377,25 @@ bool HloParser::TokenError(absl::string_view msg) { return Error(lexer_.GetLoc(), msg); } -bool HloParser::Run(HloModule* module) { +Status HloParser::Run(HloModule* module) { lexer_.Lex(); - return ParseHloModule(module); + if (lexer_.GetKind() == TokKind::kw_HloModule) { + // This means that the text contains a full HLO module. + if (!ParseHloModule(module)) { + return InvalidArgument( + "Syntax error when trying to parse the text as a HloModule:\n%s", + GetError()); + } + return Status::OK(); + } + // This means that the text is a single HLO instruction. + if (!ParseSingleInstruction(module)) { + return InvalidArgument( + "Syntax error when trying to parse the text as single " + "HloInstruction:\n%s", + GetError()); + } + return Status::OK(); } std::pair<HloInstruction*, HloParser::LocTy>* HloParser::FindInstruction( @@ -3279,9 +3296,11 @@ StatusOr<PaddingConfig> HloParser::ParsePaddingConfigOnly() { return padding_config; } -Status HloParser::ParseSingleInstruction(HloModule* module) { - TF_RET_CHECK(create_missing_instruction_ == nullptr); - TF_RET_CHECK(scoped_name_tables_.empty()); +bool HloParser::ParseSingleInstruction(HloModule* module) { + if (create_missing_instruction_ != nullptr || !scoped_name_tables_.empty()) { + LOG(FATAL) << "Parser state is not clean. Please do not call any other " + "methods before calling ParseSingleInstruction."; + } HloComputation::Builder builder(module->name()); // The missing instruction hook we register creates the shaped instruction on @@ -3298,9 +3317,6 @@ Status HloParser::ParseSingleInstruction(HloModule* module) { return tensorflow::gtl::FindOrNull(current_name_table(), new_name); }; - // Prime the lexer. - lexer_.Lex(); - // Parse the instruction with the registered hook. Scope scope(&scoped_name_tables_); if (CanBeShape()) { @@ -3309,7 +3325,7 @@ Status HloParser::ParseSingleInstruction(HloModule* module) { // // f32[10] fusion(...), calls={...} if (!ParseInstruciontRhs(&builder, module->name(), lexer_.GetLoc())) { - return InvalidArgument("Syntax error:\n%s", GetError()); + return false; } } else { // This means that the instruction's left-hand side might exist, e.g. @@ -3317,7 +3333,7 @@ Status HloParser::ParseSingleInstruction(HloModule* module) { // foo = f32[10] fusion(...), calls={...} string root_name; if (!ParseInstruction(&builder, &root_name)) { - return InvalidArgument("Syntax error:\n%s", GetError()); + return false; } } @@ -3325,7 +3341,7 @@ Status HloParser::ParseSingleInstruction(HloModule* module) { for (auto& comp : computations_) { module->AddEmbeddedComputation(std::move(comp)); } - return Status::OK(); + return true; } } // namespace @@ -3334,38 +3350,24 @@ StatusOr<std::unique_ptr<HloModule>> ParseHloString( absl::string_view str, const HloModuleConfig& config) { auto module = absl::make_unique<HloModule>(/*name=*/"", config); HloParser parser(str); - if (!parser.Run(module.get())) { - return InvalidArgument("Syntax error:\n%s", parser.GetError()); - } + TF_RETURN_IF_ERROR(parser.Run(module.get())); return std::move(module); } StatusOr<std::unique_ptr<HloModule>> ParseHloString(absl::string_view str) { auto module = absl::make_unique<HloModule>(/*name=*/"", HloModuleConfig()); HloParser parser(str); - if (!parser.Run(module.get())) { - return InvalidArgument("Syntax error:\n%s", parser.GetError()); - } + TF_RETURN_IF_ERROR(parser.Run(module.get())); return std::move(module); } Status ParseHloString(absl::string_view str, HloModule* module) { TF_RET_CHECK(module->computation_count() == 0); HloParser parser(str); - if (!parser.Run(module)) { - return InvalidArgument("Syntax error:\n%s", parser.GetError()); - } + TF_RETURN_IF_ERROR(parser.Run(module)); return Status::OK(); } -StatusOr<std::unique_ptr<HloModule>> ParseHloOpToModule( - absl::string_view str, absl::string_view name) { - HloParser parser(str); - auto module = absl::make_unique<HloModule>(string(name), HloModuleConfig()); - TF_RETURN_IF_ERROR(parser.ParseSingleInstruction(module.get())); - return std::move(module); -} - StatusOr<HloSharding> ParseSharding(absl::string_view str) { HloParser parser(str); return parser.ParseShardingOnly(); diff --git a/tensorflow/compiler/xla/service/hlo_parser.h b/tensorflow/compiler/xla/service/hlo_parser.h index 97d6f0117e..81eeb9f13b 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.h +++ b/tensorflow/compiler/xla/service/hlo_parser.h @@ -40,12 +40,6 @@ StatusOr<std::unique_ptr<HloModule>> ParseHloString( // point to an empty module (no computations). Status ParseHloString(absl::string_view str, HloModule* module); -// Parses the text for a single HLO instruction into an HLO module with an -// entry computation that runs that instruction (with the same parameters) as -// its root instruction. -StatusOr<std::unique_ptr<HloModule>> ParseHloOpToModule( - absl::string_view str, absl::string_view name = "single_op"); - // Given a string in the HloModule::ToString() format, parses the string and // creates a HloModule with default config. StatusOr<std::unique_ptr<HloModule>> ParseHloString(absl::string_view str); diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index d10acf3814..b618510640 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -1835,7 +1835,7 @@ TEST(HloParserSingleOpTest, SingleOp) { const string text = "%multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, " "f32[2,4]{1,0} %x)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloOpToModule(text)); + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text)); const HloComputation* computation = module->entry_computation(); ASSERT_NE(computation, nullptr); EXPECT_THAT(computation->root_instruction(), @@ -1844,7 +1844,7 @@ TEST(HloParserSingleOpTest, SingleOp) { TEST(HloParserSingleOpTest, SingleOpNoShapeProducesError) { const string text = "multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)"; - StatusOr<std::unique_ptr<HloModule>> module = ParseHloOpToModule(text); + StatusOr<std::unique_ptr<HloModule>> module = ParseHloString(text); ASSERT_TRUE(!module.status().ok()); LOG(INFO) << "Status: " << module.status(); EXPECT_THAT(module.status().ToString(), @@ -1853,7 +1853,7 @@ TEST(HloParserSingleOpTest, SingleOpNoShapeProducesError) { TEST(HloParserSingleOpTest, SingleOpNoOperandShapesProducesError) { const string text = "%multiply = f32[2,4]{1,0} multiply(%broadcast, %x)"; - StatusOr<std::unique_ptr<HloModule>> module = ParseHloOpToModule(text); + StatusOr<std::unique_ptr<HloModule>> module = ParseHloString(text); ASSERT_TRUE(!module.status().ok()); LOG(INFO) << "Status: " << module.status(); EXPECT_THAT(module.status().ToString(), @@ -1863,7 +1863,7 @@ TEST(HloParserSingleOpTest, SingleOpNoOperandShapesProducesError) { TEST(HloParserSingleOpTest, SingleOpNoNames) { const string text = "%multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0}, f32[2,4]{1,0})"; - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloOpToModule(text)); + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text)); const HloComputation* computation = module->entry_computation(); ASSERT_NE(computation, nullptr); EXPECT_THAT(computation->root_instruction(), @@ -1872,7 +1872,7 @@ TEST(HloParserSingleOpTest, SingleOpNoNames) { TEST(HloParserSingleOpTest, CanonicalOp) { const string text = "f32[2,4]{1,0} multiply(f32[2,4]{1,0}, f32[2,4]{1,0})"; - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloOpToModule(text)); + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text)); const HloComputation* computation = module->entry_computation(); ASSERT_NE(computation, nullptr); EXPECT_THAT(computation->root_instruction(), @@ -1908,7 +1908,7 @@ TEST(HloParserSingleOpTest, CanonicalOpWithNested) { } })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloOpToModule(text)); + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text)); const HloComputation* computation = module->entry_computation(); ASSERT_NE(computation, nullptr); EXPECT_EQ( @@ -1926,7 +1926,7 @@ TEST(HloParserSingleOpTest, SingleOpWithNested) { ROOT %subtract = f32[3,2,1,1]{3,2,1,0} subtract(f32[3,2,1,1]{3,2,1,0} %param_0, f32[3,2,1,1]{3,2,1,0} %broadcast) })"; - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloOpToModule(text)); + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text)); const HloComputation* computation = module->entry_computation(); ASSERT_NE(computation, nullptr); EXPECT_THAT(computation->root_instruction(), @@ -1939,7 +1939,7 @@ TEST(HloParserSingleOpTest, SingleOpWithNested_DoesNotExist) { { result = f32[] add(f32[] x, f32[] y) })"; - auto status = ParseHloOpToModule(text).status(); + auto status = ParseHloString(text).status(); ASSERT_FALSE(status.ok()); EXPECT_THAT(status.error_message(), ::testing::HasSubstr("does not exist: x")); @@ -1951,7 +1951,7 @@ TEST(HloParserSingleOpTest, SingleOpWithNested_NoLhs) { { f32[] add(f32[] x, f32[] y) })"; - auto status = ParseHloOpToModule(text).status(); + auto status = ParseHloString(text).status(); ASSERT_FALSE(status.ok()); EXPECT_THAT(status.error_message(), ::testing::HasSubstr("expects name")); } @@ -1962,7 +1962,7 @@ TEST(HloParserSingleOpTest, SingleOpWithNested_NoOperandName) { { result = f32[] add(f32[], f32[]) })"; - auto status = ParseHloOpToModule(text).status(); + auto status = ParseHloString(text).status(); ASSERT_FALSE(status.ok()); EXPECT_THAT(status.error_message(), ::testing::HasSubstr("expects name")); } @@ -1970,7 +1970,7 @@ TEST(HloParserSingleOpTest, SingleOpWithNested_NoOperandName) { TEST(HloParserSingleOpTest, ConvolutionTrivialFeatureGroupCount) { const string text = R"(%convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f)"; - TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloOpToModule(text)); + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text)); const HloComputation* computation = module->entry_computation(); ASSERT_NE(computation, nullptr); EXPECT_THAT(computation->root_instruction(), |