diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2018-10-02 15:08:52 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-10-02 15:20:19 -0700 |
commit | bb84d5d5e309204110315f7d0ff8ca0dbb022dd2 (patch) | |
tree | 351ff0255434d39238315db811581cde201380c2 | |
parent | cfec3aa38db1d2b70045e7b89d82fae87c3fec02 (diff) |
[XLA] Support parsing the canonical format of HLO text.
Also stop truncating operands in the canonical format.
PiperOrigin-RevId: 215466465
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_execution_profile.cc | 5 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.cc | 2 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.h | 14 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_parser.cc | 276 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_parser.h | 5 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_parser_test.cc | 142 |
6 files changed, 338 insertions, 106 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.cc b/tensorflow/compiler/xla/service/hlo_execution_profile.cc index de3d7a1677..ce4cad4235 100644 --- a/tensorflow/compiler/xla/service/hlo_execution_profile.cc +++ b/tensorflow/compiler/xla/service/hlo_execution_profile.cc @@ -90,8 +90,9 @@ std::unique_ptr<HloProfilePrinterData> CreateHloProfilePrinterData( HloInstructionInfo* instruction_info = computation_info->add_instruction_infos(); instruction_info->set_long_name(hlo->ToString()); - instruction_info->set_short_name( - hlo->ToString(HloPrintOptions().set_compact_operands(true))); + instruction_info->set_short_name(hlo->ToString( + HloPrintOptions().set_compact_operands(true).set_print_operand_names( + false))); instruction_info->set_category(hlo->ToCategory()); instruction_info->set_flop_count(cost_analysis.flop_count(*hlo)); instruction_info->set_transcendental_count( diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 5c16d6bb5e..8bddaa8c96 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -2034,7 +2034,7 @@ string HloInstruction::OperandsToStringWithCanonicalNameMap( options.is_in_nested_computation()) { str.push_back(PrintName( canonical_name_map->LookupOrInsert(operand->name()), options)); - } else if (!options.compact_operands()) { + } else if (options.print_operand_names()) { str.push_back(PrintName(operand->name(), options)); } StrAppend(out, StrJoin(str, " ")); diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 1bfdc88abc..9deed20e5d 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -80,6 +80,7 @@ class HloPrintOptions { print_backend_config_(true), compact_operands_(false), print_operand_shape_(true), + print_operand_names_(true), print_program_shape_(true), print_percent_(true), print_control_dependencies_(true), @@ -107,6 +108,7 @@ class HloPrintOptions { .set_print_metadata(false) .set_print_backend_config(false) .set_compact_operands(true) + .set_print_operand_names(false) .set_print_operand_shape(true) .set_print_program_shape(false) .set_print_percent(false) @@ -144,6 +146,12 @@ class HloPrintOptions { return *this; } + // If true, the operand names will be printed. + HloPrintOptions& set_print_operand_names(bool value) { + print_operand_names_ = value; + return *this; + } + // If true, program shape of hlo computations will be printed. HloPrintOptions& set_print_program_shape(bool value) { print_program_shape_ = value; @@ -162,8 +170,8 @@ class HloPrintOptions { return *this; } - // If true, only a part of operands will be printed out, and their names will - // be omitted (note that in this case the text will not be parsable). + // If true, only a part of operands will be printed out (note that in this + // case the text will not be parsable). HloPrintOptions& set_compact_operands(bool value) { compact_operands_ = value; return *this; @@ -197,6 +205,7 @@ class HloPrintOptions { bool print_backend_config() const { return print_backend_config_; } bool compact_operands() const { return compact_operands_; } bool print_operand_shape() const { return print_operand_shape_; } + bool print_operand_names() const { return print_operand_names_; } bool print_program_shape() const { return print_program_shape_; } bool print_percent() const { return print_percent_; } bool print_control_dependencies() const { @@ -215,6 +224,7 @@ class HloPrintOptions { bool print_backend_config_; bool compact_operands_; bool print_operand_shape_; + bool print_operand_names_; bool print_program_shape_; bool print_percent_; bool print_control_dependencies_; diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index 25b70740e3..5a125b4c08 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -80,17 +80,23 @@ class HloParser { StatusOr<PaddingConfig> ParsePaddingConfigOnly(); // Stand-alone parsing utility for a single instruction worth of text. - Status ParseSingleInstruction(HloComputation::Builder* builder, - string* root_name); + Status ParseSingleInstruction(HloModule* module); private: - // Locates an instruction with the given name in the instruction_pool_ or + using InstrNameTable = + std::unordered_map<string, std::pair<HloInstruction*, LocTy>>; + + // Returns the map from the instruction name to the instruction itself and its + // location in the current scope. + InstrNameTable& current_name_table() { return scoped_name_tables_.back(); } + + // Locates an instruction with the given name in the current_name_table() or // returns nullptr. // - // If the missing_instruction_hook_ is registered and a "shape" is provided, - // the hook will be called and may satisfy the request for the given - // instruction. This is useful when we reify parameters as they're resolved; - // i.e. for ParseSingleInstruction. + // When the name is not found or name is empty, if create_missing_instruction_ + // hook is registered and a "shape" is provided, the hook will be called to + // create an instruction. This is useful when we reify parameters as they're + // resolved; i.e. for ParseSingleInstruction. std::pair<HloInstruction*, LocTy>* FindInstruction( const string& name, const optional<Shape>& shape = nullopt); @@ -98,9 +104,11 @@ class HloParser { bool ParseHloModule(HloModule* module); bool ParseComputations(HloModule* module); bool ParseComputation(HloComputation** entry_computation); - bool ParseInstructionList(HloComputation::Builder* builder, - string* root_name); + bool ParseInstructionList(HloComputation** computation, + const string& computation_name); bool ParseInstruction(HloComputation::Builder* builder, string* root_name); + bool ParseInstruciontRhs(HloComputation::Builder* builder, const string& name, + LocTy name_loc); bool ParseControlPredecessors(HloInstruction* instruction); bool ParseLiteral(Literal* literal, const Shape& shape); bool ParseTupleLiteral(Literal* literal, const Shape& shape); @@ -281,23 +289,47 @@ class HloParser { bool AddComputation(const string& name, HloComputation* computation, LocTy name_loc); - // The map from the instruction/computation name to the - // instruction/computation itself and it's location. This does not own the - // pointers. - std::unordered_map<string, std::pair<HloInstruction*, LocTy>> - instruction_pool_; + HloLexer lexer_; + + // A stack for the instruction names. The top of the stack stores the + // instruction name table for the current scope. + // + // A instruction's name is unique among its scope (i.e. its parent + // computation), but it's not necessarily unique among all computations in the + // module. When there are multiple levels of nested computations, the same + // name could appear in both an outer computation and an inner computation. So + // we need a stack to make sure a name is only visible within its scope, + std::vector<InstrNameTable> scoped_name_tables_; + + // A helper class which pushes and pops to an InstrNameTable stack via RAII. + class Scope { + public: + explicit Scope(std::vector<InstrNameTable>* scoped_name_tables) + : scoped_name_tables_(scoped_name_tables) { + scoped_name_tables_->emplace_back(); + } + ~Scope() { scoped_name_tables_->pop_back(); } + + private: + std::vector<InstrNameTable>* scoped_name_tables_; + }; + + // Map from the computation name to the computation itself and its location. std::unordered_map<string, std::pair<HloComputation*, LocTy>> computation_pool_; - HloLexer lexer_; std::vector<std::unique_ptr<HloComputation>> computations_; std::vector<string> error_; - // Function that gets invoked when we try to resolve an instruction - // instruction_pool_ but fail to do so. - std::function<std::pair<HloInstruction*, LocTy>*(string, - const optional<Shape>&)> - missing_instruction_hook_; + // When an operand name cannot be resolved, this function is called to create + // a parameter instruction with the given name and shape. It registers the + // name, instruction, and a placeholder location in the name table. It returns + // the newly-created instruction and the placeholder location. If `name` is + // empty, this should create the parameter with a generated name. This is + // supposed to be set and used only in ParseSingleInstruction. + std::function<std::pair<HloInstruction*, LocTy>*(const string& name, + const Shape& shape)> + create_missing_instruction_; }; bool SplitToInt64s(absl::string_view s, char delim, std::vector<int64>* out) { @@ -351,11 +383,21 @@ bool HloParser::Run(HloModule* module) { std::pair<HloInstruction*, HloParser::LocTy>* HloParser::FindInstruction( const string& name, const optional<Shape>& shape) { - std::pair<HloInstruction*, LocTy>* instr = - tensorflow::gtl::FindOrNull(instruction_pool_, name); + std::pair<HloInstruction*, LocTy>* instr = nullptr; + if (!name.empty()) { + instr = tensorflow::gtl::FindOrNull(current_name_table(), name); + } + // Potentially call the missing instruction hook. - if (instr == nullptr && missing_instruction_hook_ != nullptr) { - return missing_instruction_hook_(name, shape); + if (instr == nullptr && create_missing_instruction_ != nullptr && + scoped_name_tables_.size() == 1) { + if (!shape.has_value()) { + Error(lexer_.GetLoc(), + "Operand had no shape in HLO text; cannot create parameter for " + "single-instruction module."); + return nullptr; + } + return create_missing_instruction_(name, *shape); } return instr; } @@ -439,7 +481,6 @@ bool HloParser::ParseComputation(HloComputation** entry_computation) { if (!ParseName(&name)) { return false; } - auto builder = absl::make_unique<HloComputation::Builder>(name); LocTy shape_loc = nullptr; Shape shape; @@ -447,40 +488,21 @@ bool HloParser::ParseComputation(HloComputation** entry_computation) { return false; } - string root_name; - if (!ParseInstructionList(builder.get(), &root_name)) { + HloComputation* computation = nullptr; + if (!ParseInstructionList(&computation, name)) { return false; } - std::pair<HloInstruction*, LocTy>* root_node = FindInstruction(root_name); - // This means some instruction was marked as ROOT but we didn't find it in the - // pool, which should not happen. - if (!root_name.empty() && root_node == nullptr) { - LOG(FATAL) << "instruction " << root_name - << " was marked as ROOT but the parser has not seen it before"; - } - - HloInstruction* root = root_node == nullptr ? nullptr : root_node->first; - // Now root can be either an existing instruction or a nullptr. If it's a - // nullptr, the implementation of Builder will set the last instruction as - // root instruction. - computations_.emplace_back(builder->Build(root)); - HloComputation* computation = computations_.back().get(); - - if (!root) { - root = computation->root_instruction(); - } else { - CHECK_EQ(root, computation->root_instruction()); - } - // If param_list_to_shape was present, check compatibility. - if (shape_loc != nullptr && !ShapeUtil::Compatible(root->shape(), shape)) { + if (shape_loc != nullptr && + !ShapeUtil::Compatible(computation->root_instruction()->shape(), shape)) { return Error( shape_loc, - StrCat("Shape of computation ", name, ", ", - ShapeUtil::HumanString(shape), - ", is not compatible with that of its root instruction ", - root_name, ", ", ShapeUtil::HumanString(root->shape()))); + StrCat( + "Shape of computation ", name, ", ", ShapeUtil::HumanString(shape), + ", is not compatible with that of its root instruction ", + computation->root_instruction()->name(), ", ", + ShapeUtil::HumanString(computation->root_instruction()->shape()))); } if (is_entry_computation) { @@ -489,43 +511,62 @@ bool HloParser::ParseComputation(HloComputation** entry_computation) { } *entry_computation = computation; } - instruction_pool_.clear(); return AddComputation(name, computation, name_loc); } // instruction_list ::= '{' instruction_list1 '}' // instruction_list1 ::= (instruction)+ -bool HloParser::ParseInstructionList(HloComputation::Builder* builder, - string* root_name) { +bool HloParser::ParseInstructionList(HloComputation** computation, + const string& computation_name) { + Scope scope(&scoped_name_tables_); + HloComputation::Builder builder(computation_name); if (!ParseToken(TokKind::kLbrace, "expects '{' at the beginning of instruction list.")) { return false; } + string root_name; do { - if (!ParseInstruction(builder, root_name)) { + if (!ParseInstruction(&builder, &root_name)) { return false; } } while (lexer_.GetKind() != TokKind::kRbrace); - return ParseToken(TokKind::kRbrace, - "expects '}' at the end of instruction list."); + if (!ParseToken(TokKind::kRbrace, + "expects '}' at the end of instruction list.")) { + return false; + } + HloInstruction* root = nullptr; + if (!root_name.empty()) { + std::pair<HloInstruction*, LocTy>* root_node = + tensorflow::gtl::FindOrNull(current_name_table(), root_name); + + // This means some instruction was marked as ROOT but we didn't find it in + // the pool, which should not happen. + if (root_node == nullptr) { + LOG(FATAL) << "instruction " << root_name + << " was marked as ROOT but the parser has not seen it before"; + } + root = root_node->first; + } + + // Now root can be either an existing instruction or a nullptr. If it's a + // nullptr, the implementation of Builder will set the last instruction as + // the root instruction. + computations_.emplace_back(builder.Build(root)); + *computation = computations_.back().get(); + return true; } // instruction ::= ('ROOT')? name '=' shape opcode operands (attribute)* bool HloParser::ParseInstruction(HloComputation::Builder* builder, string* root_name) { string name; - Shape shape; - HloOpcode opcode; - std::vector<HloInstruction*> operands; - LocTy maybe_root_loc = lexer_.GetLoc(); bool is_root = EatIfPresent(TokKind::kw_ROOT); const LocTy name_loc = lexer_.GetLoc(); if (!ParseName(&name) || - !ParseToken(TokKind::kEqual, "expects '=' in instruction") || - !ParseShape(&shape) || !ParseOpcode(&opcode)) { + !ParseToken(TokKind::kEqual, "expects '=' in instruction")) { return false; } @@ -536,6 +577,19 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, *root_name = name; } + return ParseInstruciontRhs(builder, name, name_loc); +} + +bool HloParser::ParseInstruciontRhs(HloComputation::Builder* builder, + const string& name, LocTy name_loc) { + Shape shape; + HloOpcode opcode; + std::vector<HloInstruction*> operands; + + if (!ParseShape(&shape) || !ParseOpcode(&opcode)) { + return false; + } + // Add optional attributes. std::unordered_map<string, AttrConfig> attrs; optional<OpSharding> sharding; @@ -2146,7 +2200,20 @@ bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands) { } } if (!ParseName(&name)) { - return false; + // When parsing a single instruction (as opposed to a whole module), an + // HLO may have one or more operands with a shape but no name: + // + // foo = add(f32[10], f32[10]) + // + // create_missing_instruction_ is always non-null when parsing a single + // instruction, and is responsible for creating kParameter instructions + // for these operands. + if (shape.has_value() && create_missing_instruction_ != nullptr && + scoped_name_tables_.size() == 1) { + name = ""; + } else { + return false; + } } std::pair<HloInstruction*, LocTy>* instruction = FindInstruction(name, shape); @@ -2299,9 +2366,17 @@ bool HloParser::ParseAttributeHelper( return true; } case AttrTy::kHloComputation: { - HloComputation* result; - if (!ParseComputationName(&result)) { - return false; + HloComputation* result = nullptr; + if (lexer_.GetKind() == TokKind::kLbrace) { + // This means it is a nested computation. + if (!ParseInstructionList(&result, /*computation_name=*/"_")) { + return false; + } + } else { + // This means it is a computation name. + if (!ParseComputationName(&result)) { + return false; + } } static_cast<optional<HloComputation*>*>(attr_out_ptr)->emplace(result); return true; @@ -3134,7 +3209,7 @@ bool HloParser::EatIfPresent(TokKind kind) { bool HloParser::AddInstruction(const string& name, HloInstruction* instruction, LocTy name_loc) { - auto result = instruction_pool_.insert({name, {instruction, name_loc}}); + auto result = current_name_table().insert({name, {instruction, name_loc}}); if (!result.second) { Error(name_loc, StrCat("instruction already exists: ", name)); return Error(/*loc=*/result.first->second.second, @@ -3204,36 +3279,51 @@ StatusOr<PaddingConfig> HloParser::ParsePaddingConfigOnly() { return padding_config; } -Status HloParser::ParseSingleInstruction(HloComputation::Builder* builder, - string* root_name) { - TF_RET_CHECK(missing_instruction_hook_ == nullptr); +Status HloParser::ParseSingleInstruction(HloModule* module) { + TF_RET_CHECK(create_missing_instruction_ == nullptr); + TF_RET_CHECK(scoped_name_tables_.empty()); + HloComputation::Builder builder(module->name()); // The missing instruction hook we register creates the shaped instruction on // the fly as a parameter and returns it. int64 parameter_count = 0; - missing_instruction_hook_ = - [this, builder, ¶meter_count]( - string name, - const optional<Shape>& shape) -> std::pair<HloInstruction*, LocTy>* { - if (!shape.has_value()) { - Error(lexer_.GetLoc(), - StrCat("Operand ", name, - " had no shape in HLO text; cannot create parameter for " - "single-instruction module.")); - return nullptr; - } - HloInstruction* parameter = builder->AddInstruction( - HloInstruction::CreateParameter(parameter_count++, *shape, name)); - instruction_pool_[name] = {parameter, lexer_.GetLoc()}; - return tensorflow::gtl::FindOrNull(instruction_pool_, name); + create_missing_instruction_ = + [this, &builder, ¶meter_count]( + const string& name, + const Shape& shape) -> std::pair<HloInstruction*, LocTy>* { + string new_name = name.empty() ? StrCat("_", parameter_count) : name; + HloInstruction* parameter = builder.AddInstruction( + HloInstruction::CreateParameter(parameter_count++, shape, new_name)); + current_name_table()[new_name] = {parameter, lexer_.GetLoc()}; + return tensorflow::gtl::FindOrNull(current_name_table(), new_name); }; // Prime the lexer. lexer_.Lex(); // Parse the instruction with the registered hook. - if (!ParseInstruction(builder, root_name)) { - return InvalidArgument("Syntax error:\n%s", GetError()); + Scope scope(&scoped_name_tables_); + if (CanBeShape()) { + // This means that the instruction's left-hand side is probably omitted, + // e.g. + // + // f32[10] fusion(...), calls={...} + if (!ParseInstruciontRhs(&builder, module->name(), lexer_.GetLoc())) { + return InvalidArgument("Syntax error:\n%s", GetError()); + } + } else { + // This means that the instruction's left-hand side might exist, e.g. + // + // foo = f32[10] fusion(...), calls={...} + string root_name; + if (!ParseInstruction(&builder, &root_name)) { + return InvalidArgument("Syntax error:\n%s", GetError()); + } + } + + module->AddEntryComputation(builder.Build()); + for (auto& comp : computations_) { + module->AddEmbeddedComputation(std::move(comp)); } return Status::OK(); } @@ -3271,12 +3361,8 @@ Status ParseHloString(absl::string_view str, HloModule* module) { StatusOr<std::unique_ptr<HloModule>> ParseHloOpToModule( absl::string_view str, absl::string_view name) { HloParser parser(str); - auto builder = absl::make_unique<HloComputation::Builder>(string(name)); - string root_name; - TF_RETURN_IF_ERROR(parser.ParseSingleInstruction(builder.get(), &root_name)); - std::unique_ptr<HloComputation> computation = builder->Build(); auto module = absl::make_unique<HloModule>(string(name), HloModuleConfig()); - module->AddEntryComputation(std::move(computation)); + TF_RETURN_IF_ERROR(parser.ParseSingleInstruction(module.get())); return std::move(module); } diff --git a/tensorflow/compiler/xla/service/hlo_parser.h b/tensorflow/compiler/xla/service/hlo_parser.h index 3696035514..97d6f0117e 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.h +++ b/tensorflow/compiler/xla/service/hlo_parser.h @@ -40,8 +40,9 @@ StatusOr<std::unique_ptr<HloModule>> ParseHloString( // point to an empty module (no computations). Status ParseHloString(absl::string_view str, HloModule* module); -// Parses the text for a single HLO operation into an HLO module with a function -// that runs that operation (with the same parameters) as its entry computation. +// Parses the text for a single HLO instruction into an HLO module with an +// entry computation that runs that instruction (with the same parameters) as +// its root instruction. StatusOr<std::unique_ptr<HloModule>> ParseHloOpToModule( absl::string_view str, absl::string_view name = "single_op"); diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc index dd4ee780f0..d10acf3814 100644 --- a/tensorflow/compiler/xla/service/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc @@ -1763,6 +1763,25 @@ ENTRY entry { "was parsing 8:39: error: instruction does not exist: aparam"); } +TEST_F(HloParserTest, SameNameDiffComputations) { + const string original = R"(HloModule same_names: +add { + p0 = f32[] parameter(0) + p1 = f32[] parameter(1) + ROOT result = f32[] add(p0, p1) +} + +ENTRY ReduceR3ToR2 { + p0 = f32[8,16,256]{2,1,0} parameter(0) + p1 = f32[] constant(0) + ROOT result = f32[8,16]{1,0} reduce(p0, p1), dimensions={2}, to_apply=add +} +)"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(original)); + ASSERT_NE(module->entry_computation(), nullptr); + EXPECT_THAT(module->entry_computation()->root_instruction(), op::Reduce()); +} + TEST_F(HloParserTest, ParseSharding) { const string original = "{maximal device=42}"; TF_ASSERT_OK_AND_ASSIGN(HloSharding sharding, ParseSharding(original)); @@ -1823,14 +1842,129 @@ TEST(HloParserSingleOpTest, SingleOp) { op::Multiply(op::Parameter(0), op::Parameter(1))); } -TEST(HloParserSingleOpTest, SingleOpNoShapesProducesError) { +TEST(HloParserSingleOpTest, SingleOpNoShapeProducesError) { + const string text = "multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)"; + StatusOr<std::unique_ptr<HloModule>> module = ParseHloOpToModule(text); + ASSERT_TRUE(!module.status().ok()); + LOG(INFO) << "Status: " << module.status(); + EXPECT_THAT(module.status().ToString(), + ::testing::HasSubstr("expects '=' in instruction")); +} + +TEST(HloParserSingleOpTest, SingleOpNoOperandShapesProducesError) { const string text = "%multiply = f32[2,4]{1,0} multiply(%broadcast, %x)"; StatusOr<std::unique_ptr<HloModule>> module = ParseHloOpToModule(text); ASSERT_TRUE(!module.status().ok()); LOG(INFO) << "Status: " << module.status(); - EXPECT_THAT( - module.status().ToString(), - ::testing::HasSubstr("Operand broadcast had no shape in HLO text")); + EXPECT_THAT(module.status().ToString(), + ::testing::HasSubstr("Operand had no shape in HLO text")); +} + +TEST(HloParserSingleOpTest, SingleOpNoNames) { + const string text = + "%multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0}, f32[2,4]{1,0})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloOpToModule(text)); + const HloComputation* computation = module->entry_computation(); + ASSERT_NE(computation, nullptr); + EXPECT_THAT(computation->root_instruction(), + op::Multiply(op::Parameter(0), op::Parameter(1))); +} + +TEST(HloParserSingleOpTest, CanonicalOp) { + const string text = "f32[2,4]{1,0} multiply(f32[2,4]{1,0}, f32[2,4]{1,0})"; + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloOpToModule(text)); + const HloComputation* computation = module->entry_computation(); + ASSERT_NE(computation, nullptr); + EXPECT_THAT(computation->root_instruction(), + op::Multiply(op::Parameter(0), op::Parameter(1))); + EXPECT_EQ( + computation->root_instruction()->ToString(HloPrintOptions::Canonical()), + text); +} + +TEST(HloParserSingleOpTest, CanonicalOpWithNested) { + const string text = + R"(f32[5,20]{1,0} while(f32[5,10]{1,0}), condition= +{ + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls= + { + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0} + ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0} + } +}, body= +{ + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls= + { + tmp_0 = f32[5,10]{1,0} parameter(0) + tmp_1 = f32[20,10]{1,0} parameter(1) + tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0} + ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0} + } +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloOpToModule(text)); + const HloComputation* computation = module->entry_computation(); + ASSERT_NE(computation, nullptr); + EXPECT_EQ( + computation->root_instruction()->ToString(HloPrintOptions::Canonical()), + text); +} + +TEST(HloParserSingleOpTest, SingleOpWithNested) { + const string text = + R"(%fusion = f32[3,2,1,1]{3,2,1,0} fusion(f32[3,2,1,1]{3,2,1,0} %p0, f32[2]{0} %p1), kind=kLoop, calls= +{ + %param_0 = f32[3,2,1,1]{3,2,1,0} parameter(0) + %param_1 = f32[2]{0} parameter(1) + %broadcast = f32[3,2,1,1]{3,2,1,0} broadcast(f32[2]{0} %param_1), dimensions={1} + ROOT %subtract = f32[3,2,1,1]{3,2,1,0} subtract(f32[3,2,1,1]{3,2,1,0} %param_0, f32[3,2,1,1]{3,2,1,0} %broadcast) +})"; + + TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloOpToModule(text)); + const HloComputation* computation = module->entry_computation(); + ASSERT_NE(computation, nullptr); + EXPECT_THAT(computation->root_instruction(), + op::Fusion(op::Parameter(0), op::Parameter(1))); +} + +TEST(HloParserSingleOpTest, SingleOpWithNested_DoesNotExist) { + const string text = + R"(reduce = f32[] reduce(f32[10], f32[]), dimensions={1}, to_apply= +{ + result = f32[] add(f32[] x, f32[] y) +})"; + auto status = ParseHloOpToModule(text).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), + ::testing::HasSubstr("does not exist: x")); +} + +TEST(HloParserSingleOpTest, SingleOpWithNested_NoLhs) { + const string text = + R"(reduce = f32[] reduce(f32[10], f32[]), dimensions={1}, to_apply= +{ + f32[] add(f32[] x, f32[] y) +})"; + auto status = ParseHloOpToModule(text).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), ::testing::HasSubstr("expects name")); +} + +TEST(HloParserSingleOpTest, SingleOpWithNested_NoOperandName) { + const string text = + R"(reduce = f32[] reduce(f32[10], f32[]), dimensions={1}, to_apply= +{ + result = f32[] add(f32[], f32[]) +})"; + auto status = ParseHloOpToModule(text).status(); + ASSERT_FALSE(status.ok()); + EXPECT_THAT(status.error_message(), ::testing::HasSubstr("expects name")); } TEST(HloParserSingleOpTest, ConvolutionTrivialFeatureGroupCount) { |