diff options
author | David Majnemer <majnemer@google.com> | 2018-08-21 23:56:48 -0700 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-08-22 00:03:32 -0700 |
commit | e846c2bc7dbbb5acca2d82a15b822b1445cd1e0c (patch) | |
tree | d54c263042ed561418e4e589b254904ccfd24899 /tensorflow/compiler/xla/service/hlo_parser.cc | |
parent | 1b8eb8d0a58f5b53cbae31e24d34082bc228caa8 (diff) |
[XLA] Expose a way to control dot/conv precision
This adds a field to the proto so that we may serialize it.
On TPUs, we can simulate higher precision by splitting a float32 number into several bfloat16 numbers such that their sum closely approximates the original number.
A tensor contraction operation like convolution or a dot product can be computed by forming several partial products which approximate the correct answer to a closer margin.
PiperOrigin-RevId: 209720948
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_parser.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/hlo_parser.cc | 58 |
1 files changed, 58 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc index b4793998ec..ede55510d3 100644 --- a/tensorflow/compiler/xla/service/hlo_parser.cc +++ b/tensorflow/compiler/xla/service/hlo_parser.cc @@ -155,6 +155,7 @@ class HloParser { kFusionKind, kDistribution, kDomain, + kPrecisionList, }; struct AttrConfig { @@ -220,6 +221,7 @@ class HloParser { bool ParseWindowPad(std::vector<std::vector<tensorflow::int64>>* pad); bool ParseSliceRanges(SliceRanges* result); + bool ParsePrecisionList(std::vector<PrecisionConfigProto::Precision>* result); bool ParseInt64List(const TokKind start, const TokKind end, const TokKind delim, std::vector<tensorflow::int64>* result); @@ -238,6 +240,7 @@ class HloParser { bool ParseFftType(FftType* result); bool ParseFusionKind(HloInstruction::FusionKind* result); bool ParseRandomDistribution(RandomDistribution* result); + bool ParsePrecision(PrecisionConfigProto::Precision* result); bool ParseInt64(tensorflow::int64* result); bool ParseDouble(double* result); bool ParseBool(bool* result); @@ -502,6 +505,10 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, attrs["backend_config"] = {/*required=*/false, AttrTy::kString, &backend_config}; + optional<std::vector<PrecisionConfigProto::Precision>> operand_precision; + attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList, + &operand_precision}; + HloInstruction* instruction; switch (opcode) { case HloOpcode::kParameter: { @@ -1366,6 +1373,12 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder, if (backend_config) { instruction->set_raw_backend_config_string(std::move(*backend_config)); } + if (operand_precision) { + PrecisionConfigProto precision_config; + *precision_config.mutable_operand_precision() = {operand_precision->begin(), + operand_precision->end()}; + instruction->set_precision_config(precision_config); + } return AddInstruction(name, instruction, name_loc); } // NOLINT(readability/fn_size) @@ -2343,6 +2356,16 @@ bool HloParser::ParseAttributeHelper( case AttrTy::kDomain: { return ParseDomain(static_cast<DomainData*>(attr_out_ptr)); } + case AttrTy::kPrecisionList: { + std::vector<PrecisionConfigProto::Precision> result; + if (!ParsePrecisionList(&result)) { + return false; + } + static_cast<optional<std::vector<PrecisionConfigProto::Precision>>*>( + attr_out_ptr) + ->emplace(result); + return true; + } } }(); if (!success) { @@ -2615,6 +2638,24 @@ bool HloParser::ParseSliceRanges(SliceRanges* result) { return ParseToken(TokKind::kRbrace, "expects '}' to end ranges"); } +// precisionlist ::= start precision_elements end +// precision_elements +// ::= /*empty*/ +// ::= precision_val (delim precision_val)* +bool HloParser::ParsePrecisionList( + std::vector<PrecisionConfigProto::Precision>* result) { + auto parse_and_add_item = [&]() { + PrecisionConfigProto::Precision item; + if (!ParsePrecision(&item)) { + return false; + } + result->push_back(item); + return true; + }; + return ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma, + parse_and_add_item); +} + // int64list ::= start int64_elements end // int64_elements // ::= /*empty*/ @@ -2941,6 +2982,23 @@ bool HloParser::ParseRandomDistribution(RandomDistribution* result) { return true; } +bool HloParser::ParsePrecision(PrecisionConfigProto::Precision* result) { + VLOG(1) << "ParsePrecision"; + if (lexer_.GetKind() != TokKind::kIdent) { + return TokenError("expects random distribution"); + } + string val = lexer_.GetStrVal(); + auto status_or_result = StringToPrecision(val); + if (!status_or_result.ok()) { + return TokenError( + Printf("expects precision but sees: %s, error: %s", val.c_str(), + status_or_result.status().error_message().c_str())); + } + *result = status_or_result.ValueOrDie(); + lexer_.Lex(); + return true; +} + bool HloParser::ParseInt64(tensorflow::int64* result) { VLOG(1) << "ParseInt64"; if (lexer_.GetKind() != TokKind::kInt) { |