aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/transpose_folding.cc
diff options
context:
space:
mode:
authorGravatar David Majnemer <majnemer@google.com>2018-08-21 23:56:48 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-08-22 00:03:32 -0700
commite846c2bc7dbbb5acca2d82a15b822b1445cd1e0c (patch)
treed54c263042ed561418e4e589b254904ccfd24899 /tensorflow/compiler/xla/service/transpose_folding.cc
parent1b8eb8d0a58f5b53cbae31e24d34082bc228caa8 (diff)
[XLA] Expose a way to control dot/conv precision
This adds a field to the proto so that we may serialize it. On TPUs, we can simulate higher precision by splitting a float32 number into several bfloat16 numbers such that their sum closely approximates the original number. A tensor contraction operation like convolution or a dot product can be computed by forming several partial products which approximate the correct answer to a closer margin. PiperOrigin-RevId: 209720948
Diffstat (limited to 'tensorflow/compiler/xla/service/transpose_folding.cc')
-rw-r--r--tensorflow/compiler/xla/service/transpose_folding.cc2
1 files changed, 2 insertions, 0 deletions
diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc
index 49e1f87319..530f40e4b2 100644
--- a/tensorflow/compiler/xla/service/transpose_folding.cc
+++ b/tensorflow/compiler/xla/service/transpose_folding.cc
@@ -109,6 +109,7 @@ Status FoldTransposeIntoDot(InstructionOperandsPair pair) {
std::unique_ptr<HloInstruction> new_dot = HloInstruction::CreateDot(
dot->shape(), new_lhs, new_rhs, new_dim_numbers);
+ new_dot->set_precision_config(dot->precision_config());
return dot->parent()->ReplaceWithNewInstruction(dot, std::move(new_dot));
}
@@ -178,6 +179,7 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) {
auto new_conv = HloInstruction::CreateConvolve(
convolution.shape(), new_lhs, new_rhs, convolution.window(), new_dnums);
+ new_conv->set_precision_config(convolution.precision_config());
TF_CHECK_OK(convolution.parent()->ReplaceWithNewInstruction(
&convolution, std::move(new_conv)));