diff options
author | 2017-11-30 21:08:05 -0800 | |
---|---|---|
committer | 2017-11-30 21:11:33 -0800 | |
commit | 6eec9c2ea33f3b86012cb0ea2aeb9e49e65bc716 (patch) | |
tree | d02d0e5149055a01318278876967ef3d04796de6 | |
parent | 1ec61fafe13e5edce6e45d5a67e960efb9df618a (diff) |
[XLA] Hlo parser: support rng and reduce-precision. Also simplify the lexer by regarding several things as identifier.
PiperOrigin-RevId: 177548483
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.cc | 30 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_instruction.h | 5 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tools/parser/BUILD | 2 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tools/parser/hlo_lexer.cc | 32 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tools/parser/hlo_lexer.h | 14 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tools/parser/hlo_parser.cc | 81 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc | 25 | ||||
-rw-r--r-- | tensorflow/compiler/xla/tools/parser/hlo_token.h | 6 |
8 files changed, 149 insertions, 46 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc index b4bac18bcd..45825c7c76 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.cc +++ b/tensorflow/compiler/xla/service/hlo_instruction.cc @@ -2060,6 +2060,14 @@ std::vector<string> HloInstruction::ExtraAttributesToString() const { extra.push_back( StrCat("outfeed_config=\"", CEscape(outfeed_config_), "\"")); } + if (opcode() == HloOpcode::kRng) { + extra.push_back( + StrCat("distribution=", RandomDistributionToString(distribution_))); + } + if (opcode() == HloOpcode::kReducePrecision) { + extra.push_back(StrCat("exponent_bits=", exponent_bits_)); + extra.push_back(StrCat("mantissa_bits=", mantissa_bits_)); + } return extra; } @@ -3029,6 +3037,28 @@ string OpMetadataToString(const OpMetadata& metadata) { return Join(result, " "); } +string RandomDistributionToString(const RandomDistribution& distribution) { + return tensorflow::str_util::Lowercase(RandomDistribution_Name(distribution)); +} + +StatusOr<RandomDistribution> StringToRandomDistribution(const string& name) { + static std::unordered_map<string, RandomDistribution>* map = [] { + static auto* map = new std::unordered_map<string, RandomDistribution>; + for (int i = 0; i < RandomDistribution_ARRAYSIZE; i++) { + if (RandomDistribution_IsValid(i)) { + auto value = static_cast<RandomDistribution>(i); + (*map)[RandomDistributionToString(value)] = value; + } + } + return map; + }(); + auto found = map->find(tensorflow::str_util::Lowercase(name)); + if (found == map->end()) { + return InvalidArgument("Unknown distribution"); + } + return found->second; +} + std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind) { return os << ToString(kind); } diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h index 768c027a42..088902e2a7 100644 --- a/tensorflow/compiler/xla/service/hlo_instruction.h +++ b/tensorflow/compiler/xla/service/hlo_instruction.h @@ -1285,9 +1285,12 @@ string ToString(HloInstruction::FusionKind kind); StatusOr<HloInstruction::FusionKind> StringToFusionKind( const string& kind_name); -// Custom stringification functions for protos that live inside HloInstruction. +// Custom (de)stringification functions for protos that live inside +// HloInstruction. string PaddingConfigToString(const PaddingConfig& padding); string OpMetadataToString(const OpMetadata& metadata); +string RandomDistributionToString(const RandomDistribution& distribution); +StatusOr<RandomDistribution> StringToRandomDistribution(const string& name); std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind); diff --git a/tensorflow/compiler/xla/tools/parser/BUILD b/tensorflow/compiler/xla/tools/parser/BUILD index ce936af6c3..97aacf6b39 100644 --- a/tensorflow/compiler/xla/tools/parser/BUILD +++ b/tensorflow/compiler/xla/tools/parser/BUILD @@ -34,9 +34,9 @@ cc_library( deps = [ "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:statusor", + "//tensorflow/compiler/xla:types", "//tensorflow/compiler/xla:util", "//tensorflow/compiler/xla:xla_data_proto", - "//tensorflow/compiler/xla/service:hlo", "//tensorflow/core:lib", "//tensorflow/core:regexp_internal", ], diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc b/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc index 56744440db..04247594ed 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_lexer.cc @@ -17,7 +17,6 @@ limitations under the License. #include <unordered_map> -#include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/statusor.h" #include "tensorflow/compiler/xla/util.h" @@ -153,15 +152,15 @@ TokKind HloLexer::LexToken() { } } -// Lex a shape, name, keyword, opcode, attribute name, or the dim labels -// pattern. +// Lex a shape, name, keyword, attribute name, the dim labels pattern, and +// other identifiers. // // shape ::= ([a-zA-Z0-9_]*[0-9]*)\[([0-9,]*)\](?:\s*{([0-9,]*)})? // name ::= [a-zA-Z_][a-zA-Z0-9_.-]*: // keyword ::= HloModule, ENTRY, ... -// opcode ::= add, greater-than, ... // attribute_name ::= condition, body, dimensions, ... // dim_labels_pattern ::= [0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,} +// identifiers ::= other cases that match [a-zA-Z_][a-zA-Z0-9_.-]* TokKind HloLexer::LexIdentifier() { { auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); @@ -220,20 +219,6 @@ TokKind HloLexer::LexIdentifier() { #undef KEYWORD - // See if this is an opcode. - auto opcode = StringToHloOpcode(identifier.ToString()); - if (opcode.ok()) { - opcode_val_ = opcode.ValueOrDie(); - return TokKind::kOpcode; - } - - // See if this is an fusion kind. - auto kind = xla::StringToFusionKind(identifier.ToString()); - if (kind.ok()) { - fusion_kind_val_ = kind.ValueOrDie(); - return TokKind::kFusionKind; - } - { auto consumable = RegexpStringPieceFromPointers(token_start_, buf_.end()); static LazyRE2 dim_labels_pattern = { @@ -244,8 +229,9 @@ TokKind HloLexer::LexIdentifier() { return TokKind::kDimLabels; } } - current_ptr_ = token_start_ + 1; - return TokKind::kError; + + str_val_ = identifier.ToString(); + return TokKind::kIdent; } // Lex names after a % character. @@ -428,14 +414,12 @@ string TokKindToString(TokKind kind) { return "kDxD"; case TokKind::kPad: return "kPad"; + case TokKind::kIdent: + return "kIdent"; case TokKind::kString: return "kString"; case TokKind::kShape: return "kShape"; - case TokKind::kOpcode: - return "kOpcode"; - case TokKind::kFusionKind: - return "kFusionKind"; case TokKind::kInt: return "kInt"; case TokKind::kDecimal: diff --git a/tensorflow/compiler/xla/tools/parser/hlo_lexer.h b/tensorflow/compiler/xla/tools/parser/hlo_lexer.h index 5c9d1bf391..9daf6a11d3 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_lexer.h +++ b/tensorflow/compiler/xla/tools/parser/hlo_lexer.h @@ -18,9 +18,8 @@ limitations under the License. #include <string> -#include "tensorflow/compiler/xla/service/hlo_instruction.h" -#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/tools/parser/hlo_token.h" +#include "tensorflow/compiler/xla/types.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/core/stringpiece.h" #include "tensorflow/core/platform/logging.h" @@ -48,6 +47,7 @@ class HloLexer { case TokKind::kDxD: case TokKind::kPad: case TokKind::kString: + case TokKind::kIdent: return str_val_; default: LOG(FATAL) << "This token does not have string value"; @@ -57,14 +57,6 @@ class HloLexer { CHECK(GetKind() == TokKind::kShape); return shape_val_; } - HloOpcode GetOpcodeVal() const { - CHECK(GetKind() == TokKind::kOpcode); - return opcode_val_; - } - HloInstruction::FusionKind GetFusionKindVal() const { - CHECK(GetKind() == TokKind::kFusionKind); - return fusion_kind_val_; - } int64 GetInt64Val() const { CHECK(GetKind() == TokKind::kInt); return int64_val_; @@ -114,8 +106,6 @@ class HloLexer { TokKind current_kind_; string str_val_; Shape shape_val_; - HloOpcode opcode_val_; - HloInstruction::FusionKind fusion_kind_val_; int64 int64_val_; double decimal_val_; }; diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc index 47979ec6f3..ddc1e69951 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc @@ -16,6 +16,7 @@ limitations under the License. #include "tensorflow/compiler/xla/tools/parser/hlo_parser.h" #include "tensorflow/compiler/xla/literal_util.h" +#include "tensorflow/compiler/xla/service/hlo_opcode.h" #include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/util.h" #include "tensorflow/core/lib/gtl/map_util.h" @@ -104,6 +105,7 @@ class HloParser { kPaddingConfig, kMetadata, kFusionKind, + kDistribution, }; struct AttrConfig { @@ -174,6 +176,7 @@ class HloParser { bool ParseShape(Shape* result); bool ParseOpcode(HloOpcode* result); bool ParseFusionKind(HloInstruction::FusionKind* result); + bool ParseRandomDistribution(RandomDistribution* result); bool ParseInt64(int64* result); bool ParseDouble(double* result); bool ParseBool(bool* result); @@ -816,10 +819,36 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, shape, operands[0], config ? *config : "")); break; } + case HloOpcode::kRng: { + optional<RandomDistribution> distribution; + attrs["distribution"] = {/*required=*/true, AttrTy::kDistribution, + &distribution}; + if (!ParseOperands(&operands) || !ParseAttributes(attrs)) { + return false; + } + instruction = builder->AddInstruction( + HloInstruction::CreateRng(shape, *distribution, operands)); + break; + } + case HloOpcode::kReducePrecision: { + optional<int64> exponent_bits; + optional<int64> mantissa_bits; + attrs["exponent_bits"] = {/*required=*/true, AttrTy::kInt64, + &exponent_bits}; + attrs["mantissa_bits"] = {/*required=*/true, AttrTy::kInt64, + &mantissa_bits}; + if (!ParseOperands(&operands, /*expected_size=*/1) || + !ParseAttributes(attrs)) { + return false; + } + instruction = + builder->AddInstruction(HloInstruction::CreateReducePrecision( + shape, operands[0], static_cast<int>(*exponent_bits), + static_cast<int>(*mantissa_bits))); + break; + } case HloOpcode::kConditional: case HloOpcode::kCustomCall: - case HloOpcode::kReducePrecision: - case HloOpcode::kRng: case HloOpcode::kTrace: return TokenError(StrCat("parsing not yet implemented for op: ", HloOpcodeString(opcode))); @@ -1548,6 +1577,15 @@ bool HloParser::ParseAttributeHelper( static_cast<optional<OpMetadata>*>(attr_out_ptr)->emplace(result); return true; } + case AttrTy::kDistribution: { + RandomDistribution result; + if (!ParseRandomDistribution(&result)) { + return false; + } + static_cast<optional<RandomDistribution>*>(attr_out_ptr) + ->emplace(result); + return true; + } } }(); if (!success) { @@ -2024,20 +2062,51 @@ bool HloParser::ParseMetadata(OpMetadata* metadata) { bool HloParser::ParseOpcode(HloOpcode* result) { VLOG(1) << "ParseOpcode"; - if (lexer_.GetKind() != TokKind::kOpcode) { + if (lexer_.GetKind() != TokKind::kIdent) { return TokenError("expects opcode"); } - *result = lexer_.GetOpcodeVal(); + string val = lexer_.GetStrVal(); + auto status_or_result = StringToHloOpcode(val); + if (!status_or_result.ok()) { + return TokenError( + Printf("expects opcode but sees: %s, error: %s", val.c_str(), + status_or_result.status().error_message().c_str())); + } + *result = status_or_result.ValueOrDie(); lexer_.Lex(); return true; } bool HloParser::ParseFusionKind(HloInstruction::FusionKind* result) { VLOG(1) << "ParseFusionKind"; - if (lexer_.GetKind() != TokKind::kFusionKind) { + if (lexer_.GetKind() != TokKind::kIdent) { return TokenError("expects fusion kind"); } - *result = lexer_.GetFusionKindVal(); + string val = lexer_.GetStrVal(); + auto status_or_result = StringToFusionKind(val); + if (!status_or_result.ok()) { + return TokenError( + Printf("expects fusion kind but sees: %s, error: %s", val.c_str(), + status_or_result.status().error_message().c_str())); + } + *result = status_or_result.ValueOrDie(); + lexer_.Lex(); + return true; +} + +bool HloParser::ParseRandomDistribution(RandomDistribution* result) { + VLOG(1) << "ParseRandomDistribution"; + if (lexer_.GetKind() != TokKind::kIdent) { + return TokenError("expects random distribution"); + } + string val = lexer_.GetStrVal(); + auto status_or_result = StringToRandomDistribution(val); + if (!status_or_result.ok()) { + return TokenError( + Printf("expects random distribution but sees: %s, error: %s", + val.c_str(), status_or_result.status().error_message().c_str())); + } + *result = status_or_result.ValueOrDie(); lexer_.Lex(); return true; } diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc index 90cdb87a1e..69d48d65bc 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc +++ b/tensorflow/compiler/xla/tools/parser/hlo_parser_test.cc @@ -655,6 +655,31 @@ ENTRY %InfeedToOutfeed () -> (u32[3], pred[]) { } )" +}, +// Rng +{ +"Rng", +R"(HloModule rng_module: + +ENTRY %Rng () -> f32[8] { + %constant = f32[] constant(0) + %constant.1 = f32[] constant(1) + ROOT %rng = f32[8]{0} rng(f32[] %constant, f32[] %constant.1), distribution=rng_uniform +} + +)" +}, +// Reduce precision +{ +"ReducePrevison", +R"(HloModule reduce_precision: + +ENTRY %ReducePrecision () -> f32[1] { + %constant = f32[1]{0} constant({3.14159}) + ROOT %reduce-precision = f32[1]{0} reduce-precision(f32[1]{0} %constant), exponent_bits=8, mantissa_bits=10 +} + +)" } }); // clang-format on diff --git a/tensorflow/compiler/xla/tools/parser/hlo_token.h b/tensorflow/compiler/xla/tools/parser/hlo_token.h index 07e48804d0..7928bee5c2 100644 --- a/tensorflow/compiler/xla/tools/parser/hlo_token.h +++ b/tensorflow/compiler/xla/tools/parser/hlo_token.h @@ -18,6 +18,9 @@ limitations under the License. #include <string> +#include "tensorflow/compiler/xla/types.h" +#include "tensorflow/core/platform/types.h" + namespace xla { namespace tools { @@ -60,10 +63,9 @@ enum class TokKind { kDimLabels, // [0-9bf]{2,}_[0-9io]{2,}->[0-9bf]{2,} kDxD, // [0-9]+(x[0-9]+)+ kPad, // [0-9]+_[0-9]+(_[0-9]+)?(x[0-9]+_[0-9]+(_[0-9]+)?)* + kIdent, // other identifiers kString, // "abcd\"\n" kShape, // f32[2,3]{1,0} - kOpcode, // add - kFusionKind, // kLoop, kOutput, ... kInt, // 42 kDecimal, // 4.2 }; |