aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/hlo_instruction.h
diff options
context:
space:
mode:
authorGravatar David Majnemer <majnemer@google.com>2018-08-21 23:56:48 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-22 00:03:32 -0700
commite846c2bc7dbbb5acca2d82a15b822b1445cd1e0c (patch)
treed54c263042ed561418e4e589b254904ccfd24899 /tensorflow/compiler/xla/service/hlo_instruction.h
parent1b8eb8d0a58f5b53cbae31e24d34082bc228caa8 (diff)
[XLA] Expose a way to control dot/conv precision
This adds a field to the proto so that we may serialize it. On TPUs, we can simulate higher precision by splitting a float32 number into several bfloat16 numbers such that their sum closely approximates the original number. A tensor contraction operation like convolution or a dot product can be computed by forming several partial products which approximate the correct answer to a closer margin. PiperOrigin-RevId: 209720948
Diffstat (limited to 'tensorflow/compiler/xla/service/hlo_instruction.h')
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.h31
1 files changed, 30 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 69397a4b37..21710bd31d 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -102,6 +102,7 @@ class HloPrintOptions {
return HloPrintOptions()
.set_print_subcomputation_mode(PrintSubcomputationMode::kFullBodies)
.set_print_metadata(false)
+ .set_print_backend_config(false)
.set_compact_operands(true)
.set_print_operand_shape(true)
.set_print_program_shape(false)
@@ -183,7 +184,7 @@ class HloPrintOptions {
return print_subcomputation_mode_;
}
bool print_metadata() const { return print_metadata_; }
- bool print_backend_config() const { return print_metadata_; }
+ bool print_backend_config() const { return print_backend_config_; }
bool compact_operands() const { return compact_operands_; }
bool print_operand_shape() const { return print_operand_shape_; }
bool print_program_shape() const { return print_program_shape_; }
@@ -858,6 +859,11 @@ class HloInstruction {
return false;
}
+ if (!ContainersEqual(precision_config_.operand_precision(),
+ other.precision_config_.operand_precision())) {
+ return false;
+ }
+
return IdenticalSlowPath(other, eq_computations);
}
@@ -1105,6 +1111,9 @@ class HloInstruction {
// Returns the dump string of the dot dimension numbers.
string DotDimensionNumbersToString() const;
+ // Returns the dump string of the precision configuration.
+ string PrecisionConfigToString() const;
+
// Clones the HLO instruction. The clone will have the same opcode, shape, and
// operands. After creation the clone has no uses. "this" (the instruction
// cloned from) is not changed. Suffix is the string to append to the name of
@@ -1248,6 +1257,20 @@ class HloInstruction {
static StatusOr<string> BackendConfigToRawString(
const tensorflow::protobuf::Message& proto);
+ // Returns the information used to tell the implementation information about
+ // what sort of precision is requested. The meaning of the field is backend
+ // specific. At the moment, it is only supported for kConvolution and kDot.
+ // Transformations on one kDot or kConvolution to another will preserve this
+ // information. Transformations to other HLOs will not preserve this
+ // information but it is presumed that the alternate lowering is strictly
+ // superior.
+ const PrecisionConfigProto& precision_config() const {
+ return precision_config_;
+ }
+ void set_precision_config(const PrecisionConfigProto& precision_config) {
+ precision_config_ = precision_config;
+ }
+
// Sets the debug metadata for this instruction.
void set_metadata(const OpMetadata& metadata) { metadata_ = metadata; }
const OpMetadata& metadata() const { return metadata_; }
@@ -1653,6 +1676,10 @@ class HloInstruction {
// HLO. See the documentation on backend_config().
string backend_config_;
+ // Information used to communicate to the implementation about the algorithm
+ // used to produce results. See the documentation on precision_config().
+ PrecisionConfigProto precision_config_;
+
// String identifier for instruction.
string name_;
@@ -1675,10 +1702,12 @@ StatusOr<HloInstruction::FusionKind> StringToFusionKind(
string PaddingConfigToString(const PaddingConfig& padding);
string OpMetadataToString(const OpMetadata& metadata);
string RandomDistributionToString(const RandomDistribution& distribution);
+string PrecisionToString(const PrecisionConfigProto::Precision& precision);
string ConvolutionDimensionNumbersToString(
const ConvolutionDimensionNumbers& dnums);
StatusOr<RandomDistribution> StringToRandomDistribution(const string& name);
+StatusOr<PrecisionConfigProto::Precision> StringToPrecision(const string& name);
std::ostream& operator<<(std::ostream& os, HloInstruction::FusionKind kind);