aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_parser.cc
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2018-10-02 15:08:52 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-02 15:20:19 -0700
commitbb84d5d5e309204110315f7d0ff8ca0dbb022dd2 (patch)
tree351ff0255434d39238315db811581cde201380c2 /tensorflow/compiler/xla/service/hlo_parser.cc
parentcfec3aa38db1d2b70045e7b89d82fae87c3fec02 (diff)
[XLA] Support parsing the canonical format of HLO text.
Also stop truncating operands in the canonical format. PiperOrigin-RevId: 215466465
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_parser.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser.cc276
1 files changed, 181 insertions, 95 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index 25b70740e3..5a125b4c08 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -80,17 +80,23 @@ class HloParser {
StatusOr<PaddingConfig> ParsePaddingConfigOnly();
// Stand-alone parsing utility for a single instruction worth of text.
- Status ParseSingleInstruction(HloComputation::Builder* builder,
- string* root_name);
+ Status ParseSingleInstruction(HloModule* module);
private:
- // Locates an instruction with the given name in the instruction_pool_ or
+ using InstrNameTable =
+ std::unordered_map<string, std::pair<HloInstruction*, LocTy>>;
+
+ // Returns the map from the instruction name to the instruction itself and its
+ // location in the current scope.
+ InstrNameTable& current_name_table() { return scoped_name_tables_.back(); }
+
+ // Locates an instruction with the given name in the current_name_table() or
// returns nullptr.
//
- // If the missing_instruction_hook_ is registered and a "shape" is provided,
- // the hook will be called and may satisfy the request for the given
- // instruction. This is useful when we reify parameters as they're resolved;
- // i.e. for ParseSingleInstruction.
+ // When the name is not found or name is empty, if create_missing_instruction_
+ // hook is registered and a "shape" is provided, the hook will be called to
+ // create an instruction. This is useful when we reify parameters as they're
+ // resolved; i.e. for ParseSingleInstruction.
std::pair<HloInstruction*, LocTy>* FindInstruction(
const string& name, const optional<Shape>& shape = nullopt);
@@ -98,9 +104,11 @@ class HloParser {
bool ParseHloModule(HloModule* module);
bool ParseComputations(HloModule* module);
bool ParseComputation(HloComputation** entry_computation);
- bool ParseInstructionList(HloComputation::Builder* builder,
- string* root_name);
+ bool ParseInstructionList(HloComputation** computation,
+ const string& computation_name);
bool ParseInstruction(HloComputation::Builder* builder, string* root_name);
+ bool ParseInstruciontRhs(HloComputation::Builder* builder, const string& name,
+ LocTy name_loc);
bool ParseControlPredecessors(HloInstruction* instruction);
bool ParseLiteral(Literal* literal, const Shape& shape);
bool ParseTupleLiteral(Literal* literal, const Shape& shape);
@@ -281,23 +289,47 @@ class HloParser {
bool AddComputation(const string& name, HloComputation* computation,
LocTy name_loc);
- // The map from the instruction/computation name to the
- // instruction/computation itself and it's location. This does not own the
- // pointers.
- std::unordered_map<string, std::pair<HloInstruction*, LocTy>>
- instruction_pool_;
+ HloLexer lexer_;
+
+ // A stack for the instruction names. The top of the stack stores the
+ // instruction name table for the current scope.
+ //
+ // A instruction's name is unique among its scope (i.e. its parent
+ // computation), but it's not necessarily unique among all computations in the
+ // module. When there are multiple levels of nested computations, the same
+ // name could appear in both an outer computation and an inner computation. So
+ // we need a stack to make sure a name is only visible within its scope,
+ std::vector<InstrNameTable> scoped_name_tables_;
+
+ // A helper class which pushes and pops to an InstrNameTable stack via RAII.
+ class Scope {
+ public:
+ explicit Scope(std::vector<InstrNameTable>* scoped_name_tables)
+ : scoped_name_tables_(scoped_name_tables) {
+ scoped_name_tables_->emplace_back();
+ }
+ ~Scope() { scoped_name_tables_->pop_back(); }
+
+ private:
+ std::vector<InstrNameTable>* scoped_name_tables_;
+ };
+
+ // Map from the computation name to the computation itself and its location.
std::unordered_map<string, std::pair<HloComputation*, LocTy>>
computation_pool_;
- HloLexer lexer_;
std::vector<std::unique_ptr<HloComputation>> computations_;
std::vector<string> error_;
- // Function that gets invoked when we try to resolve an instruction
- // instruction_pool_ but fail to do so.
- std::function<std::pair<HloInstruction*, LocTy>*(string,
- const optional<Shape>&)>
- missing_instruction_hook_;
+ // When an operand name cannot be resolved, this function is called to create
+ // a parameter instruction with the given name and shape. It registers the
+ // name, instruction, and a placeholder location in the name table. It returns
+ // the newly-created instruction and the placeholder location. If `name` is
+ // empty, this should create the parameter with a generated name. This is
+ // supposed to be set and used only in ParseSingleInstruction.
+ std::function<std::pair<HloInstruction*, LocTy>*(const string& name,
+ const Shape& shape)>
+ create_missing_instruction_;
};
bool SplitToInt64s(absl::string_view s, char delim, std::vector<int64>* out) {
@@ -351,11 +383,21 @@ bool HloParser::Run(HloModule* module) {
std::pair<HloInstruction*, HloParser::LocTy>* HloParser::FindInstruction(
const string& name, const optional<Shape>& shape) {
- std::pair<HloInstruction*, LocTy>* instr =
- tensorflow::gtl::FindOrNull(instruction_pool_, name);
+ std::pair<HloInstruction*, LocTy>* instr = nullptr;
+ if (!name.empty()) {
+ instr = tensorflow::gtl::FindOrNull(current_name_table(), name);
+ }
+
// Potentially call the missing instruction hook.
- if (instr == nullptr && missing_instruction_hook_ != nullptr) {
- return missing_instruction_hook_(name, shape);
+ if (instr == nullptr && create_missing_instruction_ != nullptr &&
+ scoped_name_tables_.size() == 1) {
+ if (!shape.has_value()) {
+ Error(lexer_.GetLoc(),
+ "Operand had no shape in HLO text; cannot create parameter for "
+ "single-instruction module.");
+ return nullptr;
+ }
+ return create_missing_instruction_(name, *shape);
}
return instr;
}
@@ -439,7 +481,6 @@ bool HloParser::ParseComputation(HloComputation** entry_computation) {
if (!ParseName(&name)) {
return false;
}
- auto builder = absl::make_unique<HloComputation::Builder>(name);
LocTy shape_loc = nullptr;
Shape shape;
@@ -447,40 +488,21 @@ bool HloParser::ParseComputation(HloComputation** entry_computation) {
return false;
}
- string root_name;
- if (!ParseInstructionList(builder.get(), &root_name)) {
+ HloComputation* computation = nullptr;
+ if (!ParseInstructionList(&computation, name)) {
return false;
}
- std::pair<HloInstruction*, LocTy>* root_node = FindInstruction(root_name);
- // This means some instruction was marked as ROOT but we didn't find it in the
- // pool, which should not happen.
- if (!root_name.empty() && root_node == nullptr) {
- LOG(FATAL) << "instruction " << root_name
- << " was marked as ROOT but the parser has not seen it before";
- }
-
- HloInstruction* root = root_node == nullptr ? nullptr : root_node->first;
- // Now root can be either an existing instruction or a nullptr. If it's a
- // nullptr, the implementation of Builder will set the last instruction as
- // root instruction.
- computations_.emplace_back(builder->Build(root));
- HloComputation* computation = computations_.back().get();
-
- if (!root) {
- root = computation->root_instruction();
- } else {
- CHECK_EQ(root, computation->root_instruction());
- }
-
// If param_list_to_shape was present, check compatibility.
- if (shape_loc != nullptr && !ShapeUtil::Compatible(root->shape(), shape)) {
+ if (shape_loc != nullptr &&
+ !ShapeUtil::Compatible(computation->root_instruction()->shape(), shape)) {
return Error(
shape_loc,
- StrCat("Shape of computation ", name, ", ",
- ShapeUtil::HumanString(shape),
- ", is not compatible with that of its root instruction ",
- root_name, ", ", ShapeUtil::HumanString(root->shape())));
+ StrCat(
+ "Shape of computation ", name, ", ", ShapeUtil::HumanString(shape),
+ ", is not compatible with that of its root instruction ",
+ computation->root_instruction()->name(), ", ",
+ ShapeUtil::HumanString(computation->root_instruction()->shape())));
}
if (is_entry_computation) {
@@ -489,43 +511,62 @@ bool HloParser::ParseComputation(HloComputation** entry_computation) {
}
*entry_computation = computation;
}
- instruction_pool_.clear();
return AddComputation(name, computation, name_loc);
}
// instruction_list ::= '{' instruction_list1 '}'
// instruction_list1 ::= (instruction)+
-bool HloParser::ParseInstructionList(HloComputation::Builder* builder,
- string* root_name) {
+bool HloParser::ParseInstructionList(HloComputation** computation,
+ const string& computation_name) {
+ Scope scope(&scoped_name_tables_);
+ HloComputation::Builder builder(computation_name);
if (!ParseToken(TokKind::kLbrace,
"expects '{' at the beginning of instruction list.")) {
return false;
}
+ string root_name;
do {
- if (!ParseInstruction(builder, root_name)) {
+ if (!ParseInstruction(&builder, &root_name)) {
return false;
}
} while (lexer_.GetKind() != TokKind::kRbrace);
- return ParseToken(TokKind::kRbrace,
- "expects '}' at the end of instruction list.");
+ if (!ParseToken(TokKind::kRbrace,
+ "expects '}' at the end of instruction list.")) {
+ return false;
+ }
+ HloInstruction* root = nullptr;
+ if (!root_name.empty()) {
+ std::pair<HloInstruction*, LocTy>* root_node =
+ tensorflow::gtl::FindOrNull(current_name_table(), root_name);
+
+ // This means some instruction was marked as ROOT but we didn't find it in
+ // the pool, which should not happen.
+ if (root_node == nullptr) {
+ LOG(FATAL) << "instruction " << root_name
+ << " was marked as ROOT but the parser has not seen it before";
+ }
+ root = root_node->first;
+ }
+
+ // Now root can be either an existing instruction or a nullptr. If it's a
+ // nullptr, the implementation of Builder will set the last instruction as
+ // the root instruction.
+ computations_.emplace_back(builder.Build(root));
+ *computation = computations_.back().get();
+ return true;
}
// instruction ::= ('ROOT')? name '=' shape opcode operands (attribute)*
bool HloParser::ParseInstruction(HloComputation::Builder* builder,
string* root_name) {
string name;
- Shape shape;
- HloOpcode opcode;
- std::vector<HloInstruction*> operands;
-
LocTy maybe_root_loc = lexer_.GetLoc();
bool is_root = EatIfPresent(TokKind::kw_ROOT);
const LocTy name_loc = lexer_.GetLoc();
if (!ParseName(&name) ||
- !ParseToken(TokKind::kEqual, "expects '=' in instruction") ||
- !ParseShape(&shape) || !ParseOpcode(&opcode)) {
+ !ParseToken(TokKind::kEqual, "expects '=' in instruction")) {
return false;
}
@@ -536,6 +577,19 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
*root_name = name;
}
+ return ParseInstruciontRhs(builder, name, name_loc);
+}
+
+bool HloParser::ParseInstruciontRhs(HloComputation::Builder* builder,
+ const string& name, LocTy name_loc) {
+ Shape shape;
+ HloOpcode opcode;
+ std::vector<HloInstruction*> operands;
+
+ if (!ParseShape(&shape) || !ParseOpcode(&opcode)) {
+ return false;
+ }
+
// Add optional attributes.
std::unordered_map<string, AttrConfig> attrs;
optional<OpSharding> sharding;
@@ -2146,7 +2200,20 @@ bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands) {
}
}
if (!ParseName(&name)) {
- return false;
+ // When parsing a single instruction (as opposed to a whole module), an
+ // HLO may have one or more operands with a shape but no name:
+ //
+ // foo = add(f32[10], f32[10])
+ //
+ // create_missing_instruction_ is always non-null when parsing a single
+ // instruction, and is responsible for creating kParameter instructions
+ // for these operands.
+ if (shape.has_value() && create_missing_instruction_ != nullptr &&
+ scoped_name_tables_.size() == 1) {
+ name = "";
+ } else {
+ return false;
+ }
}
std::pair<HloInstruction*, LocTy>* instruction =
FindInstruction(name, shape);
@@ -2299,9 +2366,17 @@ bool HloParser::ParseAttributeHelper(
return true;
}
case AttrTy::kHloComputation: {
- HloComputation* result;
- if (!ParseComputationName(&result)) {
- return false;
+ HloComputation* result = nullptr;
+ if (lexer_.GetKind() == TokKind::kLbrace) {
+ // This means it is a nested computation.
+ if (!ParseInstructionList(&result, /*computation_name=*/"_")) {
+ return false;
+ }
+ } else {
+ // This means it is a computation name.
+ if (!ParseComputationName(&result)) {
+ return false;
+ }
}
static_cast<optional<HloComputation*>*>(attr_out_ptr)->emplace(result);
return true;
@@ -3134,7 +3209,7 @@ bool HloParser::EatIfPresent(TokKind kind) {
bool HloParser::AddInstruction(const string& name, HloInstruction* instruction,
LocTy name_loc) {
- auto result = instruction_pool_.insert({name, {instruction, name_loc}});
+ auto result = current_name_table().insert({name, {instruction, name_loc}});
if (!result.second) {
Error(name_loc, StrCat("instruction already exists: ", name));
return Error(/*loc=*/result.first->second.second,
@@ -3204,36 +3279,51 @@ StatusOr<PaddingConfig> HloParser::ParsePaddingConfigOnly() {
return padding_config;
}
-Status HloParser::ParseSingleInstruction(HloComputation::Builder* builder,
- string* root_name) {
- TF_RET_CHECK(missing_instruction_hook_ == nullptr);
+Status HloParser::ParseSingleInstruction(HloModule* module) {
+ TF_RET_CHECK(create_missing_instruction_ == nullptr);
+ TF_RET_CHECK(scoped_name_tables_.empty());
+ HloComputation::Builder builder(module->name());
// The missing instruction hook we register creates the shaped instruction on
// the fly as a parameter and returns it.
int64 parameter_count = 0;
- missing_instruction_hook_ =
- [this, builder, &parameter_count](
- string name,
- const optional<Shape>& shape) -> std::pair<HloInstruction*, LocTy>* {
- if (!shape.has_value()) {
- Error(lexer_.GetLoc(),
- StrCat("Operand ", name,
- " had no shape in HLO text; cannot create parameter for "
- "single-instruction module."));
- return nullptr;
- }
- HloInstruction* parameter = builder->AddInstruction(
- HloInstruction::CreateParameter(parameter_count++, *shape, name));
- instruction_pool_[name] = {parameter, lexer_.GetLoc()};
- return tensorflow::gtl::FindOrNull(instruction_pool_, name);
+ create_missing_instruction_ =
+ [this, &builder, &parameter_count](
+ const string& name,
+ const Shape& shape) -> std::pair<HloInstruction*, LocTy>* {
+ string new_name = name.empty() ? StrCat("_", parameter_count) : name;
+ HloInstruction* parameter = builder.AddInstruction(
+ HloInstruction::CreateParameter(parameter_count++, shape, new_name));
+ current_name_table()[new_name] = {parameter, lexer_.GetLoc()};
+ return tensorflow::gtl::FindOrNull(current_name_table(), new_name);
};
// Prime the lexer.
lexer_.Lex();
// Parse the instruction with the registered hook.
- if (!ParseInstruction(builder, root_name)) {
- return InvalidArgument("Syntax error:\n%s", GetError());
+ Scope scope(&scoped_name_tables_);
+ if (CanBeShape()) {
+ // This means that the instruction's left-hand side is probably omitted,
+ // e.g.
+ //
+ // f32[10] fusion(...), calls={...}
+ if (!ParseInstruciontRhs(&builder, module->name(), lexer_.GetLoc())) {
+ return InvalidArgument("Syntax error:\n%s", GetError());
+ }
+ } else {
+ // This means that the instruction's left-hand side might exist, e.g.
+ //
+ // foo = f32[10] fusion(...), calls={...}
+ string root_name;
+ if (!ParseInstruction(&builder, &root_name)) {
+ return InvalidArgument("Syntax error:\n%s", GetError());
+ }
+ }
+
+ module->AddEntryComputation(builder.Build());
+ for (auto& comp : computations_) {
+ module->AddEmbeddedComputation(std::move(comp));
}
return Status::OK();
}
@@ -3271,12 +3361,8 @@ Status ParseHloString(absl::string_view str, HloModule* module) {
StatusOr<std::unique_ptr<HloModule>> ParseHloOpToModule(
absl::string_view str, absl::string_view name) {
HloParser parser(str);
- auto builder = absl::make_unique<HloComputation::Builder>(string(name));
- string root_name;
- TF_RETURN_IF_ERROR(parser.ParseSingleInstruction(builder.get(), &root_name));
- std::unique_ptr<HloComputation> computation = builder->Build();
auto module = absl::make_unique<HloModule>(string(name), HloModuleConfig());
- module->AddEntryComputation(std::move(computation));
+ TF_RETURN_IF_ERROR(parser.ParseSingleInstruction(module.get()));
return std::move(module);
}