author     A. Unique TensorFlower <gardener@tensorflow.org>  2018-10-02 15:08:52 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>  2018-10-02 15:20:19 -0700
commit     bb84d5d5e309204110315f7d0ff8ca0dbb022dd2 (patch)
tree       351ff0255434d39238315db811581cde201380c2
parent     cfec3aa38db1d2b70045e7b89d82fae87c3fec02 (diff)
[XLA] Support parsing the canonical format of HLO text.
Also stop truncating operands in the canonical format.

PiperOrigin-RevId: 215466465
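Concretely, the parser now accepts HLO text in the canonical format emitted by HloPrintOptions::Canonical(), where operands carry shapes but no names, e.g. (taken verbatim from the new tests below):

    f32[2,4]{1,0} multiply(f32[2,4]{1,0}, f32[2,4]{1,0})

When parsing a single instruction, each unnamed operand is reified as a parameter of the entry computation.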
-rw-r--r--  tensorflow/compiler/xla/service/hlo_execution_profile.cc    5
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.cc           2
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.h           14
-rw-r--r--  tensorflow/compiler/xla/service/hlo_parser.cc              276
-rw-r--r--  tensorflow/compiler/xla/service/hlo_parser.h                 5
-rw-r--r--  tensorflow/compiler/xla/service/hlo_parser_test.cc         142
6 files changed, 338 insertions(+), 106 deletions(-)
diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.cc b/tensorflow/compiler/xla/service/hlo_execution_profile.cc
index de3d7a1677..ce4cad4235 100644
--- a/tensorflow/compiler/xla/service/hlo_execution_profile.cc
+++ b/tensorflow/compiler/xla/service/hlo_execution_profile.cc
@@ -90,8 +90,9 @@ std::unique_ptr<HloProfilePrinterData> CreateHloProfilePrinterData(
HloInstructionInfo* instruction_info =
computation_info->add_instruction_infos();
instruction_info->set_long_name(hlo->ToString());
- instruction_info->set_short_name(
- hlo->ToString(HloPrintOptions().set_compact_operands(true)));
+ instruction_info->set_short_name(hlo->ToString(
+ HloPrintOptions().set_compact_operands(true).set_print_operand_names(
+ false)));
instruction_info->set_category(hlo->ToCategory());
instruction_info->set_flop_count(cost_analysis.flop_count(*hlo));
instruction_info->set_transcendental_count(
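Note: before this commit, set_compact_operands(true) also implied omitting operand names; the commit splits that behavior into the separate print_operand_names option (see hlo_instruction.h below), so the profile's short name must now request both settings to keep its old rendering. A sketch of the intended short name, assuming default operand-shape printing (the exact output is determined by HloInstruction::ToString):

    %multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0}, f32[2,4]{1,0})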
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 5c16d6bb5e..8bddaa8c96 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -2034,7 +2034,7 @@ string HloInstruction::OperandsToStringWithCanonicalNameMap(
options.is_in_nested_computation()) {
str.push_back(PrintName(
canonical_name_map->LookupOrInsert(operand->name()), options));
- } else if (!options.compact_operands()) {
+ } else if (options.print_operand_names()) {
str.push_back(PrintName(operand->name(), options));
}
StrAppend(out, StrJoin(str, " "));
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 1bfdc88abc..9deed20e5d 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -80,6 +80,7 @@ class HloPrintOptions {
print_backend_config_(true),
compact_operands_(false),
print_operand_shape_(true),
+ print_operand_names_(true),
print_program_shape_(true),
print_percent_(true),
print_control_dependencies_(true),
@@ -107,6 +108,7 @@ class HloPrintOptions {
.set_print_metadata(false)
.set_print_backend_config(false)
.set_compact_operands(true)
+ .set_print_operand_names(false)
.set_print_operand_shape(true)
.set_print_program_shape(false)
.set_print_percent(false)
@@ -144,6 +146,12 @@ class HloPrintOptions {
return *this;
}
+ // If true, the operand names will be printed.
+ HloPrintOptions& set_print_operand_names(bool value) {
+ print_operand_names_ = value;
+ return *this;
+ }
+
// If true, program shape of hlo computations will be printed.
HloPrintOptions& set_print_program_shape(bool value) {
print_program_shape_ = value;
@@ -162,8 +170,8 @@ class HloPrintOptions {
return *this;
}
- // If true, only a part of operands will be printed out, and their names will
- // be omitted (note that in this case the text will not be parsable).
+ // If true, only a part of operands will be printed out (note that in this
+ // case the text will not be parsable).
HloPrintOptions& set_compact_operands(bool value) {
compact_operands_ = value;
return *this;
@@ -197,6 +205,7 @@ class HloPrintOptions {
bool print_backend_config() const { return print_backend_config_; }
bool compact_operands() const { return compact_operands_; }
bool print_operand_shape() const { return print_operand_shape_; }
+ bool print_operand_names() const { return print_operand_names_; }
bool print_program_shape() const { return print_program_shape_; }
bool print_percent() const { return print_percent_; }
bool print_control_dependencies() const {
@@ -215,6 +224,7 @@ class HloPrintOptions {
bool print_backend_config_;
bool compact_operands_;
bool print_operand_shape_;
+ bool print_operand_names_;
bool print_program_shape_;
bool print_percent_;
bool print_control_dependencies_;
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index 25b70740e3..5a125b4c08 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -80,17 +80,23 @@ class HloParser {
StatusOr<PaddingConfig> ParsePaddingConfigOnly();
// Stand-alone parsing utility for a single instruction worth of text.
- Status ParseSingleInstruction(HloComputation::Builder* builder,
- string* root_name);
+ Status ParseSingleInstruction(HloModule* module);
private:
- // Locates an instruction with the given name in the instruction_pool_ or
+ using InstrNameTable =
+ std::unordered_map<string, std::pair<HloInstruction*, LocTy>>;
+
+ // Returns the map from the instruction name to the instruction itself and its
+ // location in the current scope.
+ InstrNameTable& current_name_table() { return scoped_name_tables_.back(); }
+
+ // Locates an instruction with the given name in the current_name_table() or
// returns nullptr.
//
- // If the missing_instruction_hook_ is registered and a "shape" is provided,
- // the hook will be called and may satisfy the request for the given
- // instruction. This is useful when we reify parameters as they're resolved;
- // i.e. for ParseSingleInstruction.
+ // When the name is not found or is empty, and the create_missing_instruction_
+ // hook is registered and a "shape" is provided, the hook will be called to
+ // create an instruction. This is useful when we reify parameters as they're
+ // resolved; i.e. for ParseSingleInstruction.
std::pair<HloInstruction*, LocTy>* FindInstruction(
const string& name, const optional<Shape>& shape = nullopt);
@@ -98,9 +104,11 @@ class HloParser {
bool ParseHloModule(HloModule* module);
bool ParseComputations(HloModule* module);
bool ParseComputation(HloComputation** entry_computation);
- bool ParseInstructionList(HloComputation::Builder* builder,
- string* root_name);
+ bool ParseInstructionList(HloComputation** computation,
+ const string& computation_name);
bool ParseInstruction(HloComputation::Builder* builder, string* root_name);
+ bool ParseInstructionRhs(HloComputation::Builder* builder, const string& name,
+ LocTy name_loc);
bool ParseControlPredecessors(HloInstruction* instruction);
bool ParseLiteral(Literal* literal, const Shape& shape);
bool ParseTupleLiteral(Literal* literal, const Shape& shape);
@@ -281,23 +289,47 @@ class HloParser {
bool AddComputation(const string& name, HloComputation* computation,
LocTy name_loc);
- // The map from the instruction/computation name to the
- // instruction/computation itself and it's location. This does not own the
- // pointers.
- std::unordered_map<string, std::pair<HloInstruction*, LocTy>>
- instruction_pool_;
+ HloLexer lexer_;
+
+ // A stack for the instruction names. The top of the stack stores the
+ // instruction name table for the current scope.
+ //
+ // An instruction's name is unique within its scope (i.e. its parent
+ // computation), but it is not necessarily unique among all computations in
+ // the module. When there are multiple levels of nested computations, the
+ // same name can appear in both an outer and an inner computation, so we
+ // need a stack to make sure a name is only visible within its scope.
+ std::vector<InstrNameTable> scoped_name_tables_;
+
+ // A helper class which pushes and pops to an InstrNameTable stack via RAII.
+ class Scope {
+ public:
+ explicit Scope(std::vector<InstrNameTable>* scoped_name_tables)
+ : scoped_name_tables_(scoped_name_tables) {
+ scoped_name_tables_->emplace_back();
+ }
+ ~Scope() { scoped_name_tables_->pop_back(); }
+
+ private:
+ std::vector<InstrNameTable>* scoped_name_tables_;
+ };
+
+ // Map from the computation name to the computation itself and its location.
std::unordered_map<string, std::pair<HloComputation*, LocTy>>
computation_pool_;
- HloLexer lexer_;
std::vector<std::unique_ptr<HloComputation>> computations_;
std::vector<string> error_;
- // Function that gets invoked when we try to resolve an instruction
- // instruction_pool_ but fail to do so.
- std::function<std::pair<HloInstruction*, LocTy>*(string,
- const optional<Shape>&)>
- missing_instruction_hook_;
+ // When an operand name cannot be resolved, this function is called to create
+ // a parameter instruction with the given name and shape. It registers the
+ // name, instruction, and a placeholder location in the name table. It returns
+ // the newly-created instruction and the placeholder location. If `name` is
+ // empty, this should create the parameter with a generated name. This is
+ // supposed to be set and used only in ParseSingleInstruction.
+ std::function<std::pair<HloInstruction*, LocTy>*(const string& name,
+ const Shape& shape)>
+ create_missing_instruction_;
};
bool SplitToInt64s(absl::string_view s, char delim, std::vector<int64>* out) {
@@ -351,11 +383,21 @@ bool HloParser::Run(HloModule* module) {
std::pair<HloInstruction*, HloParser::LocTy>* HloParser::FindInstruction(
const string& name, const optional<Shape>& shape) {
- std::pair<HloInstruction*, LocTy>* instr =
- tensorflow::gtl::FindOrNull(instruction_pool_, name);
+ std::pair<HloInstruction*, LocTy>* instr = nullptr;
+ if (!name.empty()) {
+ instr = tensorflow::gtl::FindOrNull(current_name_table(), name);
+ }
+
// Potentially call the missing instruction hook.
- if (instr == nullptr && missing_instruction_hook_ != nullptr) {
- return missing_instruction_hook_(name, shape);
+ if (instr == nullptr && create_missing_instruction_ != nullptr &&
+ scoped_name_tables_.size() == 1) {
+ if (!shape.has_value()) {
+ Error(lexer_.GetLoc(),
+ "Operand had no shape in HLO text; cannot create parameter for "
+ "single-instruction module.");
+ return nullptr;
+ }
+ return create_missing_instruction_(name, *shape);
}
return instr;
}
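Note the scoped_name_tables_.size() == 1 guard: create_missing_instruction_ only fires in the outermost scope of a single-instruction parse. So the following is accepted, with the two unnamed f32[10] operands reified as parameters (example from the ParseOperands comment below):

    foo = add(f32[10], f32[10])

whereas an unresolved name inside a nested computation body, such as x in the SingleOpWithNested_DoesNotExist test below, remains an error.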
@@ -439,7 +481,6 @@ bool HloParser::ParseComputation(HloComputation** entry_computation) {
if (!ParseName(&name)) {
return false;
}
- auto builder = absl::make_unique<HloComputation::Builder>(name);
LocTy shape_loc = nullptr;
Shape shape;
@@ -447,40 +488,21 @@ bool HloParser::ParseComputation(HloComputation** entry_computation) {
return false;
}
- string root_name;
- if (!ParseInstructionList(builder.get(), &root_name)) {
+ HloComputation* computation = nullptr;
+ if (!ParseInstructionList(&computation, name)) {
return false;
}
- std::pair<HloInstruction*, LocTy>* root_node = FindInstruction(root_name);
- // This means some instruction was marked as ROOT but we didn't find it in the
- // pool, which should not happen.
- if (!root_name.empty() && root_node == nullptr) {
- LOG(FATAL) << "instruction " << root_name
- << " was marked as ROOT but the parser has not seen it before";
- }
-
- HloInstruction* root = root_node == nullptr ? nullptr : root_node->first;
- // Now root can be either an existing instruction or a nullptr. If it's a
- // nullptr, the implementation of Builder will set the last instruction as
- // root instruction.
- computations_.emplace_back(builder->Build(root));
- HloComputation* computation = computations_.back().get();
-
- if (!root) {
- root = computation->root_instruction();
- } else {
- CHECK_EQ(root, computation->root_instruction());
- }
-
// If param_list_to_shape was present, check compatibility.
- if (shape_loc != nullptr && !ShapeUtil::Compatible(root->shape(), shape)) {
+ if (shape_loc != nullptr &&
+ !ShapeUtil::Compatible(computation->root_instruction()->shape(), shape)) {
return Error(
shape_loc,
- StrCat("Shape of computation ", name, ", ",
- ShapeUtil::HumanString(shape),
- ", is not compatible with that of its root instruction ",
- root_name, ", ", ShapeUtil::HumanString(root->shape())));
+ StrCat(
+ "Shape of computation ", name, ", ", ShapeUtil::HumanString(shape),
+ ", is not compatible with that of its root instruction ",
+ computation->root_instruction()->name(), ", ",
+ ShapeUtil::HumanString(computation->root_instruction()->shape())));
}
if (is_entry_computation) {
@@ -489,43 +511,62 @@ bool HloParser::ParseComputation(HloComputation** entry_computation) {
}
*entry_computation = computation;
}
- instruction_pool_.clear();
return AddComputation(name, computation, name_loc);
}
// instruction_list ::= '{' instruction_list1 '}'
// instruction_list1 ::= (instruction)+
-bool HloParser::ParseInstructionList(HloComputation::Builder* builder,
- string* root_name) {
+bool HloParser::ParseInstructionList(HloComputation** computation,
+ const string& computation_name) {
+ Scope scope(&scoped_name_tables_);
+ HloComputation::Builder builder(computation_name);
if (!ParseToken(TokKind::kLbrace,
"expects '{' at the beginning of instruction list.")) {
return false;
}
+ string root_name;
do {
- if (!ParseInstruction(builder, root_name)) {
+ if (!ParseInstruction(&builder, &root_name)) {
return false;
}
} while (lexer_.GetKind() != TokKind::kRbrace);
- return ParseToken(TokKind::kRbrace,
- "expects '}' at the end of instruction list.");
+ if (!ParseToken(TokKind::kRbrace,
+ "expects '}' at the end of instruction list.")) {
+ return false;
+ }
+ HloInstruction* root = nullptr;
+ if (!root_name.empty()) {
+ std::pair<HloInstruction*, LocTy>* root_node =
+ tensorflow::gtl::FindOrNull(current_name_table(), root_name);
+
+ // This means some instruction was marked as ROOT but we didn't find it in
+ // the current name table, which should not happen.
+ if (root_node == nullptr) {
+ LOG(FATAL) << "instruction " << root_name
+ << " was marked as ROOT but the parser has not seen it before";
+ }
+ root = root_node->first;
+ }
+
+ // Now root can be either an existing instruction or a nullptr. If it's a
+ // nullptr, the implementation of Builder will set the last instruction as
+ // the root instruction.
+ computations_.emplace_back(builder.Build(root));
+ *computation = computations_.back().get();
+ return true;
}
// instruction ::= ('ROOT')? name '=' shape opcode operands (attribute)*
bool HloParser::ParseInstruction(HloComputation::Builder* builder,
string* root_name) {
string name;
- Shape shape;
- HloOpcode opcode;
- std::vector<HloInstruction*> operands;
-
LocTy maybe_root_loc = lexer_.GetLoc();
bool is_root = EatIfPresent(TokKind::kw_ROOT);
const LocTy name_loc = lexer_.GetLoc();
if (!ParseName(&name) ||
- !ParseToken(TokKind::kEqual, "expects '=' in instruction") ||
- !ParseShape(&shape) || !ParseOpcode(&opcode)) {
+ !ParseToken(TokKind::kEqual, "expects '=' in instruction")) {
return false;
}
@@ -536,6 +577,19 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
*root_name = name;
}
+ return ParseInstructionRhs(builder, name, name_loc);
+}
+
+bool HloParser::ParseInstructionRhs(HloComputation::Builder* builder,
+ const string& name, LocTy name_loc) {
+ Shape shape;
+ HloOpcode opcode;
+ std::vector<HloInstruction*> operands;
+
+ if (!ParseShape(&shape) || !ParseOpcode(&opcode)) {
+ return false;
+ }
+
// Add optional attributes.
std::unordered_map<string, AttrConfig> attrs;
optional<OpSharding> sharding;
@@ -2146,7 +2200,20 @@ bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands) {
}
}
if (!ParseName(&name)) {
- return false;
+ // When parsing a single instruction (as opposed to a whole module), an
+ // HLO may have one or more operands with a shape but no name:
+ //
+ // foo = add(f32[10], f32[10])
+ //
+ // create_missing_instruction_ is always non-null when parsing a single
+ // instruction, and is responsible for creating kParameter instructions
+ // for these operands.
+ if (shape.has_value() && create_missing_instruction_ != nullptr &&
+ scoped_name_tables_.size() == 1) {
+ name = "";
+ } else {
+ return false;
+ }
}
std::pair<HloInstruction*, LocTy>* instruction =
FindInstruction(name, shape);
@@ -2299,9 +2366,17 @@ bool HloParser::ParseAttributeHelper(
return true;
}
case AttrTy::kHloComputation: {
- HloComputation* result;
- if (!ParseComputationName(&result)) {
- return false;
+ HloComputation* result = nullptr;
+ if (lexer_.GetKind() == TokKind::kLbrace) {
+ // This means it is a nested computation.
+ if (!ParseInstructionList(&result, /*computation_name=*/"_")) {
+ return false;
+ }
+ } else {
+ // This means it is a computation name.
+ if (!ParseComputationName(&result)) {
+ return false;
+ }
}
static_cast<optional<HloComputation*>*>(attr_out_ptr)->emplace(result);
return true;
@@ -3134,7 +3209,7 @@ bool HloParser::EatIfPresent(TokKind kind) {
bool HloParser::AddInstruction(const string& name, HloInstruction* instruction,
LocTy name_loc) {
- auto result = instruction_pool_.insert({name, {instruction, name_loc}});
+ auto result = current_name_table().insert({name, {instruction, name_loc}});
if (!result.second) {
Error(name_loc, StrCat("instruction already exists: ", name));
return Error(/*loc=*/result.first->second.second,
@@ -3204,36 +3279,51 @@ StatusOr<PaddingConfig> HloParser::ParsePaddingConfigOnly() {
return padding_config;
}
-Status HloParser::ParseSingleInstruction(HloComputation::Builder* builder,
- string* root_name) {
- TF_RET_CHECK(missing_instruction_hook_ == nullptr);
+Status HloParser::ParseSingleInstruction(HloModule* module) {
+ TF_RET_CHECK(create_missing_instruction_ == nullptr);
+ TF_RET_CHECK(scoped_name_tables_.empty());
+ HloComputation::Builder builder(module->name());
// The missing instruction hook we register creates the shaped instruction on
// the fly as a parameter and returns it.
int64 parameter_count = 0;
- missing_instruction_hook_ =
- [this, builder, &parameter_count](
- string name,
- const optional<Shape>& shape) -> std::pair<HloInstruction*, LocTy>* {
- if (!shape.has_value()) {
- Error(lexer_.GetLoc(),
- StrCat("Operand ", name,
- " had no shape in HLO text; cannot create parameter for "
- "single-instruction module."));
- return nullptr;
- }
- HloInstruction* parameter = builder->AddInstruction(
- HloInstruction::CreateParameter(parameter_count++, *shape, name));
- instruction_pool_[name] = {parameter, lexer_.GetLoc()};
- return tensorflow::gtl::FindOrNull(instruction_pool_, name);
+ create_missing_instruction_ =
+ [this, &builder, &parameter_count](
+ const string& name,
+ const Shape& shape) -> std::pair<HloInstruction*, LocTy>* {
+ string new_name = name.empty() ? StrCat("_", parameter_count) : name;
+ HloInstruction* parameter = builder.AddInstruction(
+ HloInstruction::CreateParameter(parameter_count++, shape, new_name));
+ current_name_table()[new_name] = {parameter, lexer_.GetLoc()};
+ return tensorflow::gtl::FindOrNull(current_name_table(), new_name);
};
// Prime the lexer.
lexer_.Lex();
// Parse the instruction with the registered hook.
- if (!ParseInstruction(builder, root_name)) {
- return InvalidArgument("Syntax error:\n%s", GetError());
+ Scope scope(&scoped_name_tables_);
+ if (CanBeShape()) {
+ // This means that the instruction's left-hand side is probably omitted,
+ // e.g.
+ //
+ // f32[10] fusion(...), calls={...}
+ if (!ParseInstructionRhs(&builder, module->name(), lexer_.GetLoc())) {
+ return InvalidArgument("Syntax error:\n%s", GetError());
+ }
+ } else {
+ // This means that the instruction's left-hand side might exist, e.g.
+ //
+ // foo = f32[10] fusion(...), calls={...}
+ string root_name;
+ if (!ParseInstruction(&builder, &root_name)) {
+ return InvalidArgument("Syntax error:\n%s", GetError());
+ }
+ }
+
+ module->AddEntryComputation(builder.Build());
+ for (auto& comp : computations_) {
+ module->AddEmbeddedComputation(std::move(comp));
}
return Status::OK();
}
@@ -3271,12 +3361,8 @@ Status ParseHloString(absl::string_view str, HloModule* module) {
StatusOr<std::unique_ptr<HloModule>> ParseHloOpToModule(
absl::string_view str, absl::string_view name) {
HloParser parser(str);
- auto builder = absl::make_unique<HloComputation::Builder>(string(name));
- string root_name;
- TF_RETURN_IF_ERROR(parser.ParseSingleInstruction(builder.get(), &root_name));
- std::unique_ptr<HloComputation> computation = builder->Build();
auto module = absl::make_unique<HloModule>(string(name), HloModuleConfig());
- module->AddEntryComputation(std::move(computation));
+ TF_RETURN_IF_ERROR(parser.ParseSingleInstruction(module.get()));
return std::move(module);
}
diff --git a/tensorflow/compiler/xla/service/hlo_parser.h b/tensorflow/compiler/xla/service/hlo_parser.h
index 3696035514..97d6f0117e 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.h
+++ b/tensorflow/compiler/xla/service/hlo_parser.h
@@ -40,8 +40,9 @@ StatusOr<std::unique_ptr<HloModule>> ParseHloString(
// point to an empty module (no computations).
Status ParseHloString(absl::string_view str, HloModule* module);
-// Parses the text for a single HLO operation into an HLO module with a function
-// that runs that operation (with the same parameters) as its entry computation.
+// Parses the text for a single HLO instruction into an HLO module with an
+// entry computation that runs that instruction (with the same parameters) as
+// its root instruction.
StatusOr<std::unique_ptr<HloModule>> ParseHloOpToModule(
absl::string_view str, absl::string_view name = "single_op");
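A minimal usage sketch based on this declaration (status handling elided; the HLO string is from the new tests):

    StatusOr<std::unique_ptr<HloModule>> module = ParseHloOpToModule(
        "f32[2,4]{1,0} multiply(f32[2,4]{1,0}, f32[2,4]{1,0})");
    // On success, the entry computation multiplies two reified f32[2,4]
    // parameters, with the multiply as its root instruction.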
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index dd4ee780f0..d10acf3814 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -1763,6 +1763,25 @@ ENTRY entry {
"was parsing 8:39: error: instruction does not exist: aparam");
}
+TEST_F(HloParserTest, SameNameDiffComputations) {
+ const string original = R"(HloModule same_names:
+add {
+ p0 = f32[] parameter(0)
+ p1 = f32[] parameter(1)
+ ROOT result = f32[] add(p0, p1)
+}
+
+ENTRY ReduceR3ToR2 {
+ p0 = f32[8,16,256]{2,1,0} parameter(0)
+ p1 = f32[] constant(0)
+ ROOT result = f32[8,16]{1,0} reduce(p0, p1), dimensions={2}, to_apply=add
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(original));
+ ASSERT_NE(module->entry_computation(), nullptr);
+ EXPECT_THAT(module->entry_computation()->root_instruction(), op::Reduce());
+}
+
TEST_F(HloParserTest, ParseSharding) {
const string original = "{maximal device=42}";
TF_ASSERT_OK_AND_ASSIGN(HloSharding sharding, ParseSharding(original));
@@ -1823,14 +1842,129 @@ TEST(HloParserSingleOpTest, SingleOp) {
op::Multiply(op::Parameter(0), op::Parameter(1)));
}
-TEST(HloParserSingleOpTest, SingleOpNoShapesProducesError) {
+TEST(HloParserSingleOpTest, SingleOpNoShapeProducesError) {
+ const string text = "multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)";
+ StatusOr<std::unique_ptr<HloModule>> module = ParseHloOpToModule(text);
+ ASSERT_TRUE(!module.status().ok());
+ LOG(INFO) << "Status: " << module.status();
+ EXPECT_THAT(module.status().ToString(),
+ ::testing::HasSubstr("expects '=' in instruction"));
+}
+
+TEST(HloParserSingleOpTest, SingleOpNoOperandShapesProducesError) {
const string text = "%multiply = f32[2,4]{1,0} multiply(%broadcast, %x)";
StatusOr<std::unique_ptr<HloModule>> module = ParseHloOpToModule(text);
ASSERT_TRUE(!module.status().ok());
LOG(INFO) << "Status: " << module.status();
- EXPECT_THAT(
- module.status().ToString(),
- ::testing::HasSubstr("Operand broadcast had no shape in HLO text"));
+ EXPECT_THAT(module.status().ToString(),
+ ::testing::HasSubstr("Operand had no shape in HLO text"));
+}
+
+TEST(HloParserSingleOpTest, SingleOpNoNames) {
+ const string text =
+ "%multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0}, f32[2,4]{1,0})";
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloOpToModule(text));
+ const HloComputation* computation = module->entry_computation();
+ ASSERT_NE(computation, nullptr);
+ EXPECT_THAT(computation->root_instruction(),
+ op::Multiply(op::Parameter(0), op::Parameter(1)));
+}
+
+TEST(HloParserSingleOpTest, CanonicalOp) {
+ const string text = "f32[2,4]{1,0} multiply(f32[2,4]{1,0}, f32[2,4]{1,0})";
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloOpToModule(text));
+ const HloComputation* computation = module->entry_computation();
+ ASSERT_NE(computation, nullptr);
+ EXPECT_THAT(computation->root_instruction(),
+ op::Multiply(op::Parameter(0), op::Parameter(1)));
+ EXPECT_EQ(
+ computation->root_instruction()->ToString(HloPrintOptions::Canonical()),
+ text);
+}
+
+TEST(HloParserSingleOpTest, CanonicalOpWithNested) {
+ const string text =
+ R"(f32[5,20]{1,0} while(f32[5,10]{1,0}), condition=
+{
+ tmp_0 = f32[5,10]{1,0} parameter(0)
+ tmp_1 = f32[20,10]{1,0} parameter(1)
+ ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls=
+ {
+ tmp_0 = f32[5,10]{1,0} parameter(0)
+ tmp_1 = f32[20,10]{1,0} parameter(1)
+ tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
+ ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ }
+}, body=
+{
+ tmp_0 = f32[5,10]{1,0} parameter(0)
+ tmp_1 = f32[20,10]{1,0} parameter(1)
+ ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls=
+ {
+ tmp_0 = f32[5,10]{1,0} parameter(0)
+ tmp_1 = f32[20,10]{1,0} parameter(1)
+ tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
+ ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ }
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloOpToModule(text));
+ const HloComputation* computation = module->entry_computation();
+ ASSERT_NE(computation, nullptr);
+ EXPECT_EQ(
+ computation->root_instruction()->ToString(HloPrintOptions::Canonical()),
+ text);
+}
+
+TEST(HloParserSingleOpTest, SingleOpWithNested) {
+ const string text =
+ R"(%fusion = f32[3,2,1,1]{3,2,1,0} fusion(f32[3,2,1,1]{3,2,1,0} %p0, f32[2]{0} %p1), kind=kLoop, calls=
+{
+ %param_0 = f32[3,2,1,1]{3,2,1,0} parameter(0)
+ %param_1 = f32[2]{0} parameter(1)
+ %broadcast = f32[3,2,1,1]{3,2,1,0} broadcast(f32[2]{0} %param_1), dimensions={1}
+ ROOT %subtract = f32[3,2,1,1]{3,2,1,0} subtract(f32[3,2,1,1]{3,2,1,0} %param_0, f32[3,2,1,1]{3,2,1,0} %broadcast)
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloOpToModule(text));
+ const HloComputation* computation = module->entry_computation();
+ ASSERT_NE(computation, nullptr);
+ EXPECT_THAT(computation->root_instruction(),
+ op::Fusion(op::Parameter(0), op::Parameter(1)));
+}
+
+TEST(HloParserSingleOpTest, SingleOpWithNested_DoesNotExist) {
+ const string text =
+ R"(reduce = f32[] reduce(f32[10], f32[]), dimensions={1}, to_apply=
+{
+ result = f32[] add(f32[] x, f32[] y)
+})";
+ auto status = ParseHloOpToModule(text).status();
+ ASSERT_FALSE(status.ok());
+ EXPECT_THAT(status.error_message(),
+ ::testing::HasSubstr("does not exist: x"));
+}
+
+TEST(HloParserSingleOpTest, SingleOpWithNested_NoLhs) {
+ const string text =
+ R"(reduce = f32[] reduce(f32[10], f32[]), dimensions={1}, to_apply=
+{
+ f32[] add(f32[] x, f32[] y)
+})";
+ auto status = ParseHloOpToModule(text).status();
+ ASSERT_FALSE(status.ok());
+ EXPECT_THAT(status.error_message(), ::testing::HasSubstr("expects name"));
+}
+
+TEST(HloParserSingleOpTest, SingleOpWithNested_NoOperandName) {
+ const string text =
+ R"(reduce = f32[] reduce(f32[10], f32[]), dimensions={1}, to_apply=
+{
+ result = f32[] add(f32[], f32[])
+})";
+ auto status = ParseHloOpToModule(text).status();
+ ASSERT_FALSE(status.ok());
+ EXPECT_THAT(status.error_message(), ::testing::HasSubstr("expects name"));
}
TEST(HloParserSingleOpTest, ConvolutionTrivialFeatureGroupCount) {