aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_parser.cc
diff options
context:
space:
mode:
authorGravatar Chris Leary <leary@google.com>2018-08-21 17:41:59 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-21 17:45:23 -0700
commit274bebe3a69ff0af96c721cb9b7cfbc9cd87679d (patch)
tree72d9ce1bc954ae8dc7cf55c22beeac1e730eaeec /tensorflow/compiler/xla/service/hlo_parser.cc
parent9ed82b760567d8b4543045603546299e0dd2ae8a (diff)
[XLA] Parse a single HLO instruction's text into a module.
Converts "free variable" operands into parameters. PiperOrigin-RevId: 209691261
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_parser.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser.cc96
1 files changed, 88 insertions, 8 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index 44180a881e..b4793998ec 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -33,6 +33,7 @@ namespace xla {
namespace {
+using ::absl::nullopt;
using ::absl::optional;
using ::tensorflow::StringPiece;
using ::tensorflow::str_util::Join;
@@ -66,7 +67,21 @@ class HloParser {
StatusOr<Window> ParseWindowOnly();
StatusOr<ConvolutionDimensionNumbers> ParseConvolutionDimensionNumbersOnly();
+ // Stand-alone parsing utility for a single instruction worth of text.
+ Status ParseSingleInstruction(HloComputation::Builder* builder,
+ string* root_name);
+
private:
+ // Locates an instruction with the given name in the instruction_pool_ or
+ // returns nullptr.
+ //
+ // If the missing_instruction_hook_ is registered and a "shape" is provided,
+ // the hook will be called and may satisfy the request for the given
+ // instruction. This is useful when we reify parameters as they're resolved;
+ // i.e. for ParseSingleInstruction.
+ std::pair<HloInstruction*, LocTy>* FindInstruction(
+ const string& name, const optional<Shape>& shape = nullopt);
+
// ParseXXX returns false if an error occurred.
bool ParseHloModule();
bool ParseComputations();
@@ -267,6 +282,12 @@ class HloParser {
std::vector<std::unique_ptr<HloComputation>> computations_;
const HloModuleConfig config_;
std::vector<string> error_;
+
+ // Function that gets invoked when we try to resolve an instruction
+ // instruction_pool_ but fail to do so.
+ std::function<std::pair<HloInstruction*, LocTy>*(string,
+ const optional<Shape>&)>
+ missing_instruction_hook_;
};
bool HloParser::Error(LocTy loc, StringPiece msg) {
@@ -293,6 +314,17 @@ bool HloParser::Run() {
return ParseHloModule();
}
+std::pair<HloInstruction*, HloParser::LocTy>* HloParser::FindInstruction(
+ const string& name, const optional<Shape>& shape) {
+ std::pair<HloInstruction*, LocTy>* instr =
+ tensorflow::gtl::FindOrNull(instruction_pool_, name);
+ // Potentially call the missing instruction hook.
+ if (instr == nullptr && missing_instruction_hook_ != nullptr) {
+ return missing_instruction_hook_(name, shape);
+ }
+ return instr;
+}
+
// ::= 'HloModule' name computations
bool HloParser::ParseHloModule() {
if (lexer_.GetKind() != TokKind::kw_HloModule) {
@@ -372,8 +404,7 @@ bool HloParser::ParseComputation(HloComputation** entry_computation) {
return false;
}
- std::pair<HloInstruction*, LocTy>* root_node =
- tensorflow::gtl::FindOrNull(instruction_pool_, root_name);
+ std::pair<HloInstruction*, LocTy>* root_node = FindInstruction(root_name);
// This means some instruction was marked as ROOT but we didn't find it in the
// pool, which should not happen.
if (!root_name.empty() && root_node == nullptr) {
@@ -1525,8 +1556,7 @@ bool HloParser::ParseInstructionNames(
if (!ParseName(&name)) {
return Error(loc, "expects a instruction name");
}
- std::pair<HloInstruction*, LocTy>* instr =
- tensorflow::gtl::FindOrNull(instruction_pool_, name);
+ std::pair<HloInstruction*, LocTy>* instr = FindInstruction(name);
if (!instr) {
return TokenError(
Printf("instruction '%s' is not defined", name.c_str()));
@@ -2009,6 +2039,7 @@ bool HloParser::ParseSparseLiteralHelper(std::unique_ptr<Literal>* literal,
// ::= operand (, operand)*
// operand ::= (shape)? name
bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands) {
+ CHECK(operands != nullptr);
if (!ParseToken(TokKind::kLparen,
"expects '(' at the beginning of operands")) {
return false;
@@ -2019,9 +2050,10 @@ bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands) {
do {
LocTy loc = lexer_.GetLoc();
string name;
+ optional<Shape> shape;
if (CanBeShape()) {
- Shape shape;
- if (!ParseShape(&shape)) {
+ shape.emplace();
+ if (!ParseShape(&shape.value())) {
return false;
}
}
@@ -2029,8 +2061,8 @@ bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands) {
return false;
}
std::pair<HloInstruction*, LocTy>* instruction =
- tensorflow::gtl::FindOrNull(instruction_pool_, name);
- if (!instruction) {
+ FindInstruction(name, shape);
+ if (instruction == nullptr) {
return Error(loc, StrCat("instruction does not exist: ", name));
}
operands->push_back(instruction->first);
@@ -2041,6 +2073,7 @@ bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands) {
bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands,
const int expected_size) {
+ CHECK(operands != nullptr);
LocTy loc = lexer_.GetLoc();
if (!ParseOperands(operands)) {
return false;
@@ -3029,6 +3062,40 @@ HloParser::ParseConvolutionDimensionNumbersOnly() {
return dnums;
}
+Status HloParser::ParseSingleInstruction(HloComputation::Builder* builder,
+ string* root_name) {
+ TF_RET_CHECK(missing_instruction_hook_ == nullptr);
+
+ // The missing instruction hook we register creates the shaped instruction on
+ // the fly as a parameter and returns it.
+ int64 parameter_count = 0;
+ missing_instruction_hook_ =
+ [this, builder, &parameter_count](
+ string name,
+ const optional<Shape>& shape) -> std::pair<HloInstruction*, LocTy>* {
+ if (!shape.has_value()) {
+ Error(lexer_.GetLoc(),
+ StrCat("Operand ", name,
+ " had no shape in HLO text; cannot create parameter for "
+ "single-instruction module."));
+ return nullptr;
+ }
+ HloInstruction* parameter = builder->AddInstruction(
+ HloInstruction::CreateParameter(parameter_count++, *shape, name));
+ instruction_pool_[name] = {parameter, lexer_.GetLoc()};
+ return tensorflow::gtl::FindOrNull(instruction_pool_, name);
+ };
+
+ // Prime the lexer.
+ lexer_.Lex();
+
+ // Parse the instruction with the registered hook.
+ if (!ParseInstruction(builder, root_name)) {
+ return InvalidArgument("Syntax error:\n%s", GetError().c_str());
+ }
+ return Status::OK();
+}
+
} // namespace
StatusOr<std::unique_ptr<HloModule>> ParseHloString(
@@ -3046,6 +3113,19 @@ StatusOr<std::unique_ptr<HloModule>> ParseHloString(
return ParseHloString(str, config);
}
+StatusOr<std::unique_ptr<HloModule>> ParseHloOpToModule(
+ tensorflow::StringPiece str, tensorflow::StringPiece name) {
+ HloModuleConfig config;
+ HloParser parser(str, config);
+ auto builder = absl::make_unique<HloComputation::Builder>(name.ToString());
+ string root_name;
+ TF_RETURN_IF_ERROR(parser.ParseSingleInstruction(builder.get(), &root_name));
+ std::unique_ptr<HloComputation> computation = builder->Build();
+ auto module = absl::make_unique<HloModule>(name.ToString(), config);
+ module->AddEntryComputation(std::move(computation));
+ return std::move(module);
+}
+
StatusOr<HloSharding> ParseSharding(tensorflow::StringPiece str) {
HloModuleConfig config;
HloParser parser(str, config);