diff options
author | 2017-11-10 16:46:27 -0800 | |
---|---|---|
committer | 2017-11-10 17:15:57 -0800 | |
commit | 61aebf140e12e2ad834dc94a83f23fc574c79340 (patch) | |
tree | 1e88cca51524a349e82a47b03c55d823a18a5902 | |
parent | 08114b6093f7c461483e2f466af49ed55689708c (diff) |
Hlo parser: support metadata.
Also give metadata its own format.
PiperOrigin-RevId: 175356154
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.cc | 21 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.h | 2 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tools/parser/hlo_lexer.cc | 24 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tools/parser/hlo_lexer.h | 2 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tools/parser/hlo_parser.cc | 344 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc | 16 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tools/parser/hlo_token.h | 1 |
7 files changed, 299 insertions, 111 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index 674d3e3836..1e83c69b50 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -1826,7 +1826,7 @@ string HloInstruction::ToString(bool compact_operands, bool include_metadata, if (include_metadata && (!metadata_.op_type().empty() || !metadata_.op_name().empty() || !metadata_.source_file().empty())) { - StrAppend(&result, " # metadata=", metadata_.ShortDebugString()); + StrAppend(&result, ", metadata={", xla::OpMetadataToString(metadata_), "}"); } return result; } @@ -2910,6 +2910,25 @@ string PaddingConfigToString(const PaddingConfig& padding) { }); } +string OpMetadataToString(const OpMetadata& metadata) { + std::vector<string> result; + using tensorflow::str_util::CEscape; + if (!metadata.op_type().empty()) { + result.push_back(StrCat("op_type=\"", CEscape(metadata.op_type()), "\"")); + } + if (!metadata.op_name().empty()) { + result.push_back(StrCat("op_name=\"", CEscape(metadata.op_name()), "\"")); + } + if (!metadata.source_file().empty()) { + result.push_back( + StrCat("source_file=\"", CEscape(metadata.source_file()), "\"")); + } + if (metadata.source_line() != 0) { + result.push_back(StrCat("source_line=", metadata.source_line())); + } + return Join(result, " "); +} + std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) { return os << ToString(kind); } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index d174f05aa6..438d8bb35b 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -1239,7 +1239,9 @@ string ToString(HloInstruction::FusionKind kind); StatusOr<HloInstruction::FusionKind> StringToFusionKind( const string& kind_name); +// Custom stringification functions for protos that live inside HloInstruction. 
string PaddingConfigToString(const PaddingConfig& padding); +string OpMetadataToString(const OpMetadata& metadata); std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind); diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc b/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc index b5befbf58b..098879155a 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc @@ -22,6 +22,7 @@ limitations under the License. #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/optional.h" #include "tensorflow/core/lib/strings/numbers.h" +#include "tensorflow/core/lib/strings/str_util.h" #include "tensorflow/core/platform/regexp.h" namespace xla { @@ -145,6 +146,8 @@ TokKind HloLexer::LexToken() { return TokKind::kRparen; case '/': return LexComment(); + case '"': + return LexString(); } } } @@ -340,6 +343,25 @@ TokKind HloLexer::LexComment() { return TokKind::kError; } +// Lexes quoted string with escaping characters. If matched, the quoted string +// will be unescaped and stored to str_val_. +TokKind HloLexer::LexString() { + auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); + static LazyRE2 escaping_pattern = {R"("([^"\\]|\\.)*")"}; + if (RE2::Consume(&consumable, *escaping_pattern)) { + current_ptr_ = consumable.begin(); + StringPiece raw = + StringPieceFromPointers(token_start_ + 1, current_ptr_ - 1); + string error; + if (!tensorflow::str_util::CUnescape(raw, &str_val_, &error)) { + LOG(ERROR) << "Failed unescaping string: " << raw << ". 
error: " << error; + return TokKind::kError; + } + return TokKind::kString; + } + return TokKind::kError; +} + string TokKindToString(TokKind kind) { switch (kind) { case TokKind::kEof: @@ -398,6 +420,8 @@ string TokKindToString(TokKind kind) { return "kDxD"; case TokKind::kPad: return "kPad"; + case TokKind::kString: + return "kString"; case TokKind::kShape: return "kShape"; case TokKind::kOpcode: diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.h b/tensorflow/compiler/xla/tools/parser/hlo_lexer.h index 79c4f271a1..2236c26619 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_lexer.h +++ b/tensorflow/compiler/xla/tools/parser/hlo_lexer.h @@ -46,6 +46,7 @@ class HloLexer { case TokKind::kDimLabels: case TokKind::kDxD: case TokKind::kPad: + case TokKind::kString: return str_val_; default: LOG(FATAL) << "This token does not have string value"; @@ -98,6 +99,7 @@ class HloLexer { TokKind LexConstant(); TokKind LexNumberOrPattern(); TokKind LexComment(); + TokKind LexString(); const tensorflow::StringPiece buf_; const char* current_ptr_; diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc index fed0492a54..ac7d9ff482 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc @@ -91,7 +91,9 @@ class HloParser { // Types of attributes. enum class AttrTy { kInt64, + kInt32, kFloat, + kString, kBracedInt64List, kHloComputation, kWindow, @@ -100,6 +102,7 @@ class HloParser { kInstructionList, kSliceRanges, kPaddingConfig, + kMetadata, }; struct AttrConfig { @@ -108,6 +111,8 @@ class HloParser { void* result; // where to store the parsed result. }; + // attributes ::= (',' attribute)* + // // Parses attributes given names and configs of the attributes. Each parsed // result is passed back through the result pointer in corresponding // AttrConfig. 
Note that the result pointer must point to a optional<T> typed @@ -121,7 +126,7 @@ class HloParser { // attrs["foo"] = {/*required=*/false, AttrTy::kInt64, &foo}; // optional<Window> bar; // attrs["bar"] = {/*required=*/true, AttrTy::kWindow, &bar}; - // if (!ParseAttribute(attrs)) { + // if (!ParseAttributes(attrs)) { // return false; // Do not use 'foo' 'bar' if failed. // } // // Do something with 'bar'. @@ -129,6 +134,18 @@ class HloParser { // bool ParseAttributes(const std::unordered_map<string, AttrConfig>& attrs); + // sub_attributes ::= '{' (','? attribute)* '}' + // + // Usage is the same as ParseAttributes. See immediately above. + bool ParseSubAttributes(const std::unordered_map<string, AttrConfig>& attrs); + + // Parses one attribute. If it has already been seen, return error. Returns + // true and adds to seen_attrs on success. + // + // Do not call this except in ParseAttributes or ParseSubAttributes. + bool ParseAttributeHelper(const std::unordered_map<string, AttrConfig>& attrs, + std::unordered_set<string>* seen_attrs); + // Parses a name and finds the corresponding hlo computation. bool ParseComputationName(HloComputation** value); // Parses a list of names and finds the corresponding hlo instructions. 
@@ -136,6 +153,7 @@ class HloParser { bool ParseWindow(Window* window); bool ParseConvolutionDimensionNumbers(ConvolutionDimensionNumbers* dnums); bool ParsePaddingConfig(PaddingConfig* padding); + bool ParseMetadata(OpMetadata* metadata); bool ParseSharding(OpSharding* sharding); bool ParseSingleSharding(OpSharding* sharding, bool lbrace_pre_lexed); @@ -151,6 +169,7 @@ class HloParser { bool ParseParamList(); bool ParseName(string* result); bool ParseAttributeName(string* result); + bool ParseString(string* result); bool ParseShape(Shape* result); bool ParseOpcode(HloOpcode* result); bool ParseInt64(int64* result); @@ -303,6 +322,8 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, optional<std::vector<HloInstruction*>> predecessors; attrs["control-predecessors"] = {/*required=*/false, AttrTy::kInstructionList, &predecessors}; + optional<OpMetadata> metadata; + attrs["metadata"] = {/*required=*/false, AttrTy::kMetadata, &metadata}; HloInstruction* instruction; switch (opcode) { @@ -766,6 +787,9 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, } } } + if (metadata) { + instruction->set_metadata(*metadata); + } return AddInstruction(name, instruction); } @@ -1284,129 +1308,194 @@ bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands, return true; } +// sub_attributes ::= '{' (','? attribute)* '}' +bool HloParser::ParseSubAttributes( + const std::unordered_map<string, AttrConfig>& attrs) { + if (!ParseToken(TokKind::kLbrace, "expects '{' to start sub attributes")) { + return false; + } + std::unordered_set<string> seen_attrs; + if (lexer_.GetKind() == TokKind::kRbrace) { + // empty + } else { + do { + EatIfPresent(TokKind::kComma); + if (!ParseAttributeHelper(attrs, &seen_attrs)) { + return false; + } + } while (lexer_.GetKind() != TokKind::kRbrace); + } + // Check that all required attrs were seen. 
+ for (const auto& attr_it : attrs) { + if (attr_it.second.required && + seen_attrs.find(attr_it.first) == seen_attrs.end()) { + return TokenError(Printf("sub-attribute %s is expected but not seen", + attr_it.first.c_str())); + } + } + return ParseToken(TokKind::kRbrace, "expects '}' to end sub attributes"); +} + +// attributes ::= (',' attribute)* bool HloParser::ParseAttributes( const std::unordered_map<string, AttrConfig>& attrs) { std::unordered_set<string> seen_attrs; while (EatIfPresent(TokKind::kComma)) { - string name; - if (!ParseAttributeName(&name)) { - return TokenError("error parsing attributes"); - } - VLOG(1) << "Parsing attribute " << name; - if (!seen_attrs.insert(name).second) { - return TokenError(Printf("attribute %s already exists", name.c_str())); - } - auto attr_it = attrs.find(name); - if (attr_it == attrs.end()) { - return TokenError(Printf("unexpected attribute %s", name.c_str())); - } - AttrTy attr_type = attr_it->second.attr_type; - void* attr_out_ptr = attr_it->second.result; - bool success = [&] { - switch (attr_type) { - case AttrTy::kInt64: { - int64 result; - if (!ParseInt64(&result)) { - return false; - } - static_cast<optional<int64>*>(attr_out_ptr)->emplace(result); - return true; + if (!ParseAttributeHelper(attrs, &seen_attrs)) { + return false; + } + } + // Check that all required attrs were seen. 
+ for (const auto& attr_it : attrs) { + if (attr_it.second.required && + seen_attrs.find(attr_it.first) == seen_attrs.end()) { + return TokenError(Printf("attribute %s is expected but not seen", + attr_it.first.c_str())); + } + } + return true; +} + +bool HloParser::ParseAttributeHelper( + const std::unordered_map<string, AttrConfig>& attrs, + std::unordered_set<string>* seen_attrs) { + string name; + if (!ParseAttributeName(&name)) { + return TokenError("error parsing attributes"); + } + VLOG(1) << "Parsing attribute " << name; + if (!seen_attrs->insert(name).second) { + return TokenError(Printf("attribute %s already exists", name.c_str())); + } + auto attr_it = attrs.find(name); + if (attr_it == attrs.end()) { + return TokenError(Printf("unexpected attribute %s", name.c_str())); + } + AttrTy attr_type = attr_it->second.attr_type; + void* attr_out_ptr = attr_it->second.result; + bool success = [&] { + switch (attr_type) { + case AttrTy::kInt64: { + int64 result; + if (!ParseInt64(&result)) { + return false; } - case AttrTy::kFloat: { - double result; - if (!ParseDouble(&result)) { - return false; - } - if (result > std::numeric_limits<float>::max() || - result < std::numeric_limits<float>::lowest()) { - return TokenError("value out of range for float"); - } - static_cast<optional<float>*>(attr_out_ptr) - ->emplace(static_cast<float>(result)); - return true; + static_cast<optional<int64>*>(attr_out_ptr)->emplace(result); + return true; + } + case AttrTy::kInt32: { + int64 result; + if (!ParseInt64(&result)) { + return false; } - case AttrTy::kHloComputation: { - HloComputation* result; - if (!ParseComputationName(&result)) { - return false; - } - static_cast<optional<HloComputation*>*>(attr_out_ptr) - ->emplace(result); - return true; + if (result != static_cast<int32>(result)) { + return TokenError("value out of range for int32"); } - case AttrTy::kWindow: { - Window result; - if (!ParseWindow(&result)) { - return false; - } - 
static_cast<optional<Window>*>(attr_out_ptr)->emplace(result); - return true; + static_cast<optional<int32>*>(attr_out_ptr) + ->emplace(static_cast<int32>(result)); + return true; + } + case AttrTy::kFloat: { + double result; + if (!ParseDouble(&result)) { + return false; } - case AttrTy::kConvolutionDimensionNumbers: { - ConvolutionDimensionNumbers result; - if (!ParseConvolutionDimensionNumbers(&result)) { - return false; - } - static_cast<optional<ConvolutionDimensionNumbers>*>(attr_out_ptr) - ->emplace(result); - return true; + if (result > std::numeric_limits<float>::max() || + result < std::numeric_limits<float>::lowest()) { + return TokenError("value out of range for float"); } - case AttrTy::kSharding: { - OpSharding sharding; - if (!ParseSharding(&sharding)) { - return false; - } - static_cast<optional<OpSharding>*>(attr_out_ptr)->emplace(sharding); - return true; + static_cast<optional<float>*>(attr_out_ptr) + ->emplace(static_cast<float>(result)); + return true; + } + case AttrTy::kHloComputation: { + HloComputation* result; + if (!ParseComputationName(&result)) { + return false; } - case AttrTy::kInstructionList: { - std::vector<HloInstruction*> result; - if (!ParseInstructionNames(&result)) { - return false; - } - static_cast<optional<std::vector<HloInstruction*>>*>(attr_out_ptr) - ->emplace(result); - return true; + static_cast<optional<HloComputation*>*>(attr_out_ptr)->emplace(result); + return true; + } + case AttrTy::kWindow: { + Window result; + if (!ParseWindow(&result)) { + return false; } - case AttrTy::kBracedInt64List: { - std::vector<int64> result; - if (!ParseInt64List(TokKind::kLbrace, TokKind::kRbrace, - TokKind::kComma, &result)) { - return false; - } - static_cast<optional<std::vector<int64>>*>(attr_out_ptr) - ->emplace(result); - return true; + static_cast<optional<Window>*>(attr_out_ptr)->emplace(result); + return true; + } + case AttrTy::kConvolutionDimensionNumbers: { + ConvolutionDimensionNumbers result; + if 
(!ParseConvolutionDimensionNumbers(&result)) { + return false; } - case AttrTy::kSliceRanges: { - SliceRanges result; - if (!ParseSliceRanges(&result)) { - return false; - } - static_cast<optional<SliceRanges>*>(attr_out_ptr)->emplace(result); - return true; + static_cast<optional<ConvolutionDimensionNumbers>*>(attr_out_ptr) + ->emplace(result); + return true; + } + case AttrTy::kSharding: { + OpSharding sharding; + if (!ParseSharding(&sharding)) { + return false; } - case AttrTy::kPaddingConfig: { - PaddingConfig result; - if (!ParsePaddingConfig(&result)) { - return false; - } - static_cast<optional<PaddingConfig>*>(attr_out_ptr)->emplace(result); - return true; + static_cast<optional<OpSharding>*>(attr_out_ptr)->emplace(sharding); + return true; + } + case AttrTy::kInstructionList: { + std::vector<HloInstruction*> result; + if (!ParseInstructionNames(&result)) { + return false; } + static_cast<optional<std::vector<HloInstruction*>>*>(attr_out_ptr) + ->emplace(result); + return true; + } + case AttrTy::kBracedInt64List: { + std::vector<int64> result; + if (!ParseInt64List(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma, + &result)) { + return false; + } + static_cast<optional<std::vector<int64>>*>(attr_out_ptr) + ->emplace(result); + return true; + } + case AttrTy::kSliceRanges: { + SliceRanges result; + if (!ParseSliceRanges(&result)) { + return false; + } + static_cast<optional<SliceRanges>*>(attr_out_ptr)->emplace(result); + return true; + } + case AttrTy::kPaddingConfig: { + PaddingConfig result; + if (!ParsePaddingConfig(&result)) { + return false; + } + static_cast<optional<PaddingConfig>*>(attr_out_ptr)->emplace(result); + return true; + } + case AttrTy::kString: { + string result; + if (!ParseString(&result)) { + return false; + } + static_cast<optional<string>*>(attr_out_ptr)->emplace(result); + return true; + } + case AttrTy::kMetadata: { + OpMetadata result; + if (!ParseMetadata(&result)) { + return false; + } + 
static_cast<optional<OpMetadata>*>(attr_out_ptr)->emplace(result); + return true; } - }(); - if (!success) { - return TokenError(Printf("error parsing attribute %s", name.c_str())); - } - } - // Check that all required attrs were seen. - for (const auto& attr_it : attrs) { - if (attr_it.second.required && - seen_attrs.find(attr_it.first) == seen_attrs.end()) { - return TokenError(Printf("attribute %s is expected but not seen", - attr_it.first.c_str())); } + }(); + if (!success) { + return TokenError(Printf("error parsing attribute %s", name.c_str())); } return true; } @@ -1763,6 +1852,16 @@ bool HloParser::ParseAttributeName(string* result) { return true; } +bool HloParser::ParseString(string* result) { + VLOG(1) << "ParseString"; + if (lexer_.GetKind() != TokKind::kString) { + return TokenError("expects string"); + } + *result = lexer_.GetStrVal(); + lexer_.Lex(); + return true; +} + bool HloParser::ParseDxD(const string& name, std::vector<int64>* result) { if (!result->empty()) { return TokenError( @@ -1839,6 +1938,35 @@ bool HloParser::ParsePaddingConfig(PaddingConfig* padding) { return true; } +// '{' metadata_string '}' +bool HloParser::ParseMetadata(OpMetadata* metadata) { + std::unordered_map<string, AttrConfig> attrs; + optional<string> op_type; + optional<string> op_name; + optional<string> source_file; + optional<int32> source_line; + attrs["op_type"] = {/*required=*/false, AttrTy::kString, &op_type}; + attrs["op_name"] = {/*required=*/false, AttrTy::kString, &op_name}; + attrs["source_file"] = {/*required=*/false, AttrTy::kString, &source_file}; + attrs["source_line"] = {/*required=*/false, AttrTy::kInt32, &source_line}; + if (!ParseSubAttributes(attrs)) { + return false; + } + if (op_type) { + metadata->set_op_type(*op_type); + } + if (op_name) { + metadata->set_op_name(*op_name); + } + if (source_file) { + metadata->set_source_file(*source_file); + } + if (source_line) { + metadata->set_source_line(*source_line); + } + return true; +} + bool 
HloParser::ParseOpcode(HloOpcode* result) { VLOG(1) << "ParseOpcode"; if (lexer_.GetKind() != TokKind::kOpcode) { diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc index d19c6e1877..bed912d921 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc @@ -65,7 +65,7 @@ ENTRY %axpy.v5 (alpha: f32[], x: f32[2,4], y: f32[2,4]) -> f32[2,4] { R"(HloModule constant_pred_module: ENTRY %constant_pred () -> pred[] { - ROOT %constant = pred[] constant(true) + ROOT %constant = pred[] constant(true), metadata={op_type="const" op_name="\"it\'s not a problem\n" source_file="path/to/test.cc" source_line=68} } )" @@ -83,7 +83,8 @@ ENTRY %constant_s32 () -> s32[] { }, // f32 constant, but the value is not a decimal { -"ConstantF32", R"(HloModule ConstantF32_module: +"ConstantF32", +R"(HloModule ConstantF32_module: ENTRY %ConstantF32.v4 () -> f32[] { ROOT %constant = f32[] constant(42) @@ -841,6 +842,17 @@ ENTRY %Convolve1D1Window_0.v3 (input: f32[1,2,1], filter: f32[1,1,1]) -> f32[1,2 "expects padding_low and padding_high separated by '_'"); } +TEST_F(HloParserTest, CommaBetweenSubAttributes) { + const string original = R"(HloModule test_comma_module: + +ENTRY %test_comma.v4 () -> f32[] { + ROOT %constant = f32[] constant(-4.2), metadata={source_line=5, op_type="::const"} +} + +)"; + TF_EXPECT_OK(Parse(original).status()); +} + } // namespace } // namespace tools } // namespace xla diff --git a/tensorflow/compiler/xla/tools/parser/hlo_token.h b/tensorflow/compiler/xla/tools/parser/hlo_token.h index 9afd2fac23..78a72837ca 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_token.h +++ b/tensorflow/compiler/xla/tools/parser/hlo_token.h @@ -60,6 +60,7 @@ enum class TokKind { kDimLabels, // [0-9bf]+_[0-9io]+->[0-9bf]+ kDxD, // [0-9]+(x[0-9]+)+ kPad, // [0-9]+_[0-9]+(_[0-9]+)?(x[0-9]+_[0-9]+(_[0-9]+)?)* + kString, // "abcd\"\n" kShape, // 
f32[2,3]{1,0} kOpcode, // add kInt, // 42 |