aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-10-02 19:28:27 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-02 19:32:18 -0700
commitfa61b939bec50d731b86f40c79054503d629e29b (patch)
tree17a6bcedbe4878fc81014a0c3a2f77579ecb6241 /tensorflow/compiler
parent8dc7bc7764150253c03a666eee84fc48f867d6a2 (diff)
[XLA] Merge the single instruction parsing and the full module parsing in one function.
PiperOrigin-RevId: 215501702
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser.cc66
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser.h6
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser_test.cc22
3 files changed, 45 insertions, 49 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index 5a125b4c08..0440f1b54f 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -68,7 +68,7 @@ class HloParser {
// Runs the parser and constructs the resulting HLO in the given (empty)
// HloModule. Returns false if an error occurred.
- bool Run(HloModule* module);
+ Status Run(HloModule* module);
// Returns the error information.
string GetError() const { return StrJoin(error_, "\n"); }
@@ -79,9 +79,6 @@ class HloParser {
StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbersOnly();
StatusOr<PaddingConfig> ParsePaddingConfigOnly();
- // Stand-alone parsing utility for a single instruction worth of text.
- Status ParseSingleInstruction(HloModule* module);
-
private:
using InstrNameTable =
std::unordered_map<string, std::pair<HloInstruction*, LocTy>>;
@@ -100,8 +97,12 @@ class HloParser {
std::pair<HloInstruction*, LocTy>* FindInstruction(
const string& name, const optional<Shape>& shape = nullopt);
+ // Parse a single instruction worth of text.
+ bool ParseSingleInstruction(HloModule* module);
+
// ParseXXX returns false if an error occurred.
bool ParseHloModule(HloModule* module);
+
bool ParseComputations(HloModule* module);
bool ParseComputation(HloComputation** entry_computation);
bool ParseInstructionList(HloComputation** computation,
@@ -376,9 +377,25 @@ bool HloParser::TokenError(absl::string_view msg) {
return Error(lexer_.GetLoc(), msg);
}
-bool HloParser::Run(HloModule* module) {
+Status HloParser::Run(HloModule* module) {
lexer_.Lex();
- return ParseHloModule(module);
+ if (lexer_.GetKind() == TokKind::kw_HloModule) {
+ // This means that the text contains a full HLO module.
+ if (!ParseHloModule(module)) {
+ return InvalidArgument(
+ "Syntax error when trying to parse the text as a HloModule:\n%s",
+ GetError());
+ }
+ return Status::OK();
+ }
+ // This means that the text is a single HLO instruction.
+ if (!ParseSingleInstruction(module)) {
+ return InvalidArgument(
+ "Syntax error when trying to parse the text as single "
+ "HloInstruction:\n%s",
+ GetError());
+ }
+ return Status::OK();
}
std::pair<HloInstruction*, HloParser::LocTy>* HloParser::FindInstruction(
@@ -3279,9 +3296,11 @@ StatusOr<PaddingConfig> HloParser::ParsePaddingConfigOnly() {
return padding_config;
}
-Status HloParser::ParseSingleInstruction(HloModule* module) {
- TF_RET_CHECK(create_missing_instruction_ == nullptr);
- TF_RET_CHECK(scoped_name_tables_.empty());
+bool HloParser::ParseSingleInstruction(HloModule* module) {
+ if (create_missing_instruction_ != nullptr || !scoped_name_tables_.empty()) {
+ LOG(FATAL) << "Parser state is not clean. Please do not call any other "
+ "methods before calling ParseSingleInstruction.";
+ }
HloComputation::Builder builder(module->name());
// The missing instruction hook we register creates the shaped instruction on
@@ -3298,9 +3317,6 @@ Status HloParser::ParseSingleInstruction(HloModule* module) {
return tensorflow::gtl::FindOrNull(current_name_table(), new_name);
};
- // Prime the lexer.
- lexer_.Lex();
-
// Parse the instruction with the registered hook.
Scope scope(&scoped_name_tables_);
if (CanBeShape()) {
@@ -3309,7 +3325,7 @@ Status HloParser::ParseSingleInstruction(HloModule* module) {
//
// f32[10] fusion(...), calls={...}
if (!ParseInstruciontRhs(&builder, module->name(), lexer_.GetLoc())) {
- return InvalidArgument("Syntax error:\n%s", GetError());
+ return false;
}
} else {
// This means that the instruction's left-hand side might exist, e.g.
@@ -3317,7 +3333,7 @@ Status HloParser::ParseSingleInstruction(HloModule* module) {
// foo = f32[10] fusion(...), calls={...}
string root_name;
if (!ParseInstruction(&builder, &root_name)) {
- return InvalidArgument("Syntax error:\n%s", GetError());
+ return false;
}
}
@@ -3325,7 +3341,7 @@ Status HloParser::ParseSingleInstruction(HloModule* module) {
for (auto& comp : computations_) {
module->AddEmbeddedComputation(std::move(comp));
}
- return Status::OK();
+ return true;
}
} // namespace
@@ -3334,38 +3350,24 @@ StatusOr<std::unique_ptr<HloModule>> ParseHloString(
absl::string_view str, const HloModuleConfig& config) {
auto module = absl::make_unique<HloModule>(/*name=*/"", config);
HloParser parser(str);
- if (!parser.Run(module.get())) {
- return InvalidArgument("Syntax error:\n%s", parser.GetError());
- }
+ TF_RETURN_IF_ERROR(parser.Run(module.get()));
return std::move(module);
}
StatusOr<std::unique_ptr<HloModule>> ParseHloString(absl::string_view str) {
auto module = absl::make_unique<HloModule>(/*name=*/"", HloModuleConfig());
HloParser parser(str);
- if (!parser.Run(module.get())) {
- return InvalidArgument("Syntax error:\n%s", parser.GetError());
- }
+ TF_RETURN_IF_ERROR(parser.Run(module.get()));
return std::move(module);
}
Status ParseHloString(absl::string_view str, HloModule* module) {
TF_RET_CHECK(module->computation_count() == 0);
HloParser parser(str);
- if (!parser.Run(module)) {
- return InvalidArgument("Syntax error:\n%s", parser.GetError());
- }
+ TF_RETURN_IF_ERROR(parser.Run(module));
return Status::OK();
}
-StatusOr<std::unique_ptr<HloModule>> ParseHloOpToModule(
- absl::string_view str, absl::string_view name) {
- HloParser parser(str);
- auto module = absl::make_unique<HloModule>(string(name), HloModuleConfig());
- TF_RETURN_IF_ERROR(parser.ParseSingleInstruction(module.get()));
- return std::move(module);
-}
-
StatusOr<HloSharding> ParseSharding(absl::string_view str) {
HloParser parser(str);
return parser.ParseShardingOnly();
diff --git a/tensorflow/compiler/xla/service/hlo_parser.h b/tensorflow/compiler/xla/service/hlo_parser.h
index 97d6f0117e..81eeb9f13b 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.h
+++ b/tensorflow/compiler/xla/service/hlo_parser.h
@@ -40,12 +40,6 @@ StatusOr<std::unique_ptr<HloModule>> ParseHloString(
// point to an empty module (no computations).
Status ParseHloString(absl::string_view str, HloModule* module);
-// Parses the text for a single HLO instruction into an HLO module with an
-// entry computation that runs that instruction (with the same parameters) as
-// its root instruction.
-StatusOr<std::unique_ptr<HloModule>> ParseHloOpToModule(
- absl::string_view str, absl::string_view name = "single_op");
-
// Given a string in the HloModule::ToString() format, parses the string and
// creates a HloModule with default config.
StatusOr<std::unique_ptr<HloModule>> ParseHloString(absl::string_view str);
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index d10acf3814..b618510640 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -1835,7 +1835,7 @@ TEST(HloParserSingleOpTest, SingleOp) {
const string text =
"%multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0} %broadcast, "
"f32[2,4]{1,0} %x)";
- TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloOpToModule(text));
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text));
const HloComputation* computation = module->entry_computation();
ASSERT_NE(computation, nullptr);
EXPECT_THAT(computation->root_instruction(),
@@ -1844,7 +1844,7 @@ TEST(HloParserSingleOpTest, SingleOp) {
TEST(HloParserSingleOpTest, SingleOpNoShapeProducesError) {
const string text = "multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)";
- StatusOr<std::unique_ptr<HloModule>> module = ParseHloOpToModule(text);
+ StatusOr<std::unique_ptr<HloModule>> module = ParseHloString(text);
ASSERT_TRUE(!module.status().ok());
LOG(INFO) << "Status: " << module.status();
EXPECT_THAT(module.status().ToString(),
@@ -1853,7 +1853,7 @@ TEST(HloParserSingleOpTest, SingleOpNoShapeProducesError) {
TEST(HloParserSingleOpTest, SingleOpNoOperandShapesProducesError) {
const string text = "%multiply = f32[2,4]{1,0} multiply(%broadcast, %x)";
- StatusOr<std::unique_ptr<HloModule>> module = ParseHloOpToModule(text);
+ StatusOr<std::unique_ptr<HloModule>> module = ParseHloString(text);
ASSERT_TRUE(!module.status().ok());
LOG(INFO) << "Status: " << module.status();
EXPECT_THAT(module.status().ToString(),
@@ -1863,7 +1863,7 @@ TEST(HloParserSingleOpTest, SingleOpNoOperandShapesProducesError) {
TEST(HloParserSingleOpTest, SingleOpNoNames) {
const string text =
"%multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0}, f32[2,4]{1,0})";
- TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloOpToModule(text));
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text));
const HloComputation* computation = module->entry_computation();
ASSERT_NE(computation, nullptr);
EXPECT_THAT(computation->root_instruction(),
@@ -1872,7 +1872,7 @@ TEST(HloParserSingleOpTest, SingleOpNoNames) {
TEST(HloParserSingleOpTest, CanonicalOp) {
const string text = "f32[2,4]{1,0} multiply(f32[2,4]{1,0}, f32[2,4]{1,0})";
- TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloOpToModule(text));
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text));
const HloComputation* computation = module->entry_computation();
ASSERT_NE(computation, nullptr);
EXPECT_THAT(computation->root_instruction(),
@@ -1908,7 +1908,7 @@ TEST(HloParserSingleOpTest, CanonicalOpWithNested) {
}
})";
- TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloOpToModule(text));
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text));
const HloComputation* computation = module->entry_computation();
ASSERT_NE(computation, nullptr);
EXPECT_EQ(
@@ -1926,7 +1926,7 @@ TEST(HloParserSingleOpTest, SingleOpWithNested) {
ROOT %subtract = f32[3,2,1,1]{3,2,1,0} subtract(f32[3,2,1,1]{3,2,1,0} %param_0, f32[3,2,1,1]{3,2,1,0} %broadcast)
})";
- TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloOpToModule(text));
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text));
const HloComputation* computation = module->entry_computation();
ASSERT_NE(computation, nullptr);
EXPECT_THAT(computation->root_instruction(),
@@ -1939,7 +1939,7 @@ TEST(HloParserSingleOpTest, SingleOpWithNested_DoesNotExist) {
{
result = f32[] add(f32[] x, f32[] y)
})";
- auto status = ParseHloOpToModule(text).status();
+ auto status = ParseHloString(text).status();
ASSERT_FALSE(status.ok());
EXPECT_THAT(status.error_message(),
::testing::HasSubstr("does not exist: x"));
@@ -1951,7 +1951,7 @@ TEST(HloParserSingleOpTest, SingleOpWithNested_NoLhs) {
{
f32[] add(f32[] x, f32[] y)
})";
- auto status = ParseHloOpToModule(text).status();
+ auto status = ParseHloString(text).status();
ASSERT_FALSE(status.ok());
EXPECT_THAT(status.error_message(), ::testing::HasSubstr("expects name"));
}
@@ -1962,7 +1962,7 @@ TEST(HloParserSingleOpTest, SingleOpWithNested_NoOperandName) {
{
result = f32[] add(f32[], f32[])
})";
- auto status = ParseHloOpToModule(text).status();
+ auto status = ParseHloString(text).status();
ASSERT_FALSE(status.ok());
EXPECT_THAT(status.error_message(), ::testing::HasSubstr("expects name"));
}
@@ -1970,7 +1970,7 @@ TEST(HloParserSingleOpTest, SingleOpWithNested_NoOperandName) {
TEST(HloParserSingleOpTest, ConvolutionTrivialFeatureGroupCount) {
const string text =
R"(%convolution = f32[1,2,1]{2,0,1} convolution(f32[1,2,1]{2,0,1} %copy, f32[1,1,1]{2,1,0} %filter), window={size=1}, dim_labels=b0f_0io->b0f)";
- TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloOpToModule(text));
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(text));
const HloComputation* computation = module->entry_computation();
ASSERT_NE(computation, nullptr);
EXPECT_THAT(computation->root_instruction(),