aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_parser.cc
diff options
context:
space:
mode:
authorGravatar David Majnemer <majnemer@google.com>2018-08-21 23:56:48 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-22 00:03:32 -0700
commite846c2bc7dbbb5acca2d82a15b822b1445cd1e0c (patch)
treed54c263042ed561418e4e589b254904ccfd24899 /tensorflow/compiler/xla/service/hlo_parser.cc
parent1b8eb8d0a58f5b53cbae31e24d34082bc228caa8 (diff)
[XLA] Expose a way to control dot/conv precision
This adds a field to the proto so that we may serialize it. On TPUs, we can simulate higher precision by splitting a float32 number into several bfloat16 numbers such that their sum closely approximates the original number. A tensor contraction operation like convolution or a dot product can be computed by forming several partial products which approximate the correct answer to a closer margin. PiperOrigin-RevId: 209720948
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_parser.cc')
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser.cc58
1 files changed, 58 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index b4793998ec..ede55510d3 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -155,6 +155,7 @@ class HloParser {
kFusionKind,
kDistribution,
kDomain,
+ kPrecisionList,
};
struct AttrConfig {
@@ -220,6 +221,7 @@ class HloParser {
bool ParseWindowPad(std::vector<std::vector<tensorflow::int64>>* pad);
bool ParseSliceRanges(SliceRanges* result);
+ bool ParsePrecisionList(std::vector<PrecisionConfigProto::Precision>* result);
bool ParseInt64List(const TokKind start, const TokKind end,
const TokKind delim,
std::vector<tensorflow::int64>* result);
@@ -238,6 +240,7 @@ class HloParser {
bool ParseFftType(FftType* result);
bool ParseFusionKind(HloInstruction::FusionKind* result);
bool ParseRandomDistribution(RandomDistribution* result);
+ bool ParsePrecision(PrecisionConfigProto::Precision* result);
bool ParseInt64(tensorflow::int64* result);
bool ParseDouble(double* result);
bool ParseBool(bool* result);
@@ -502,6 +505,10 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
attrs["backend_config"] = {/*required=*/false, AttrTy::kString,
&backend_config};
+ optional<std::vector<PrecisionConfigProto::Precision>> operand_precision;
+ attrs["operand_precision"] = {/*required=*/false, AttrTy::kPrecisionList,
+ &operand_precision};
+
HloInstruction* instruction;
switch (opcode) {
case HloOpcode::kParameter: {
@@ -1366,6 +1373,12 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
if (backend_config) {
instruction->set_raw_backend_config_string(std::move(*backend_config));
}
+ if (operand_precision) {
+ PrecisionConfigProto precision_config;
+ *precision_config.mutable_operand_precision() = {operand_precision->begin(),
+ operand_precision->end()};
+ instruction->set_precision_config(precision_config);
+ }
return AddInstruction(name, instruction, name_loc);
} // NOLINT(readability/fn_size)
@@ -2343,6 +2356,16 @@ bool HloParser::ParseAttributeHelper(
case AttrTy::kDomain: {
return ParseDomain(static_cast<DomainData*>(attr_out_ptr));
}
+ case AttrTy::kPrecisionList: {
+ std::vector<PrecisionConfigProto::Precision> result;
+ if (!ParsePrecisionList(&result)) {
+ return false;
+ }
+ static_cast<optional<std::vector<PrecisionConfigProto::Precision>>*>(
+ attr_out_ptr)
+ ->emplace(result);
+ return true;
+ }
}
}();
if (!success) {
@@ -2615,6 +2638,24 @@ bool HloParser::ParseSliceRanges(SliceRanges* result) {
return ParseToken(TokKind::kRbrace, "expects '}' to end ranges");
}
+// precisionlist ::= start precision_elements end
+// precision_elements
+// ::= /*empty*/
+// ::= precision_val (delim precision_val)*
+bool HloParser::ParsePrecisionList(
+ std::vector<PrecisionConfigProto::Precision>* result) {
+ auto parse_and_add_item = [&]() {
+ PrecisionConfigProto::Precision item;
+ if (!ParsePrecision(&item)) {
+ return false;
+ }
+ result->push_back(item);
+ return true;
+ };
+ return ParseList(TokKind::kLbrace, TokKind::kRbrace, TokKind::kComma,
+ parse_and_add_item);
+}
+
// int64list ::= start int64_elements end
// int64_elements
// ::= /*empty*/
@@ -2941,6 +2982,23 @@ bool HloParser::ParseRandomDistribution(RandomDistribution* result) {
return true;
}
+bool HloParser::ParsePrecision(PrecisionConfigProto::Precision* result) {
+ VLOG(1) << "ParsePrecision";
+ if (lexer_.GetKind() != TokKind::kIdent) {
+ return TokenError("expects random distribution");
+ }
+ string val = lexer_.GetStrVal();
+ auto status_or_result = StringToPrecision(val);
+ if (!status_or_result.ok()) {
+ return TokenError(
+ Printf("expects precision but sees: %s, error: %s", val.c_str(),
+ status_or_result.status().error_message().c_str()));
+ }
+ *result = status_or_result.ValueOrDie();
+ lexer_.Lex();
+ return true;
+}
+
bool HloParser::ParseInt64(tensorflow::int64* result) {
VLOG(1) << "ParseInt64";
if (lexer_.GetKind() != TokKind::kInt) {