aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler
diff options
context:
space:
mode:
authorGravatar Yuanzhong Xu <yuanzx@google.com>2018-02-22 12:27:37 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-22 12:38:15 -0800
commit30727a6b673ff64ea8b5ad8754dee598b829a4aa (patch)
treeec7394f39c1b83d77588e4fc003a52d34f1048fd /tensorflow/compiler
parent78916e73383da9860ccdf07018892acb558249d7 (diff)
[XLA] HLO BF16 propagation pass.
Using BFloat16Support provided by the backend to determine what precision is needed for each HloInstruction. If the implementation of some HLOs already reduces input precision to BF16, this pass can enable BF16 on more ops without affecting the result. PiperOrigin-RevId: 186656378
Diffstat (limited to 'tensorflow/compiler')
-rw-r--r--tensorflow/compiler/xla/service/BUILD32
-rw-r--r--tensorflow/compiler/xla/service/bfloat16_propagation.cc334
-rw-r--r--tensorflow/compiler/xla/service/bfloat16_propagation.h119
-rw-r--r--tensorflow/compiler/xla/service/bfloat16_propagation_test.cc335
-rw-r--r--tensorflow/compiler/xla/service/bfloat16_support.h2
5 files changed, 821 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 4a076ac090..37ca1b893a 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -119,6 +119,38 @@ tf_cc_test(
)
cc_library(
+ name = "bfloat16_propagation",
+ srcs = ["bfloat16_propagation.cc"],
+ hdrs = ["bfloat16_propagation.h"],
+ deps = [
+ ":bfloat16_support",
+ ":hlo",
+ ":hlo_dataflow_analysis",
+ ":hlo_pass",
+ "//tensorflow/compiler/xla:shape_tree",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:util",
+ "//tensorflow/core:lib",
+ ],
+)
+
+tf_cc_test(
+ name = "bfloat16_propagation_test",
+ srcs = ["bfloat16_propagation_test.cc"],
+ deps = [
+ ":bfloat16_propagation",
+ ":bfloat16_support",
+ ":hlo",
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:test",
+ "//tensorflow/compiler/xla:test_helpers",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/tests:hlo_test_base",
+ "//tensorflow/compiler/xla/tests:xla_internal_test_main", # fixdeps: keep
+ ],
+)
+
+cc_library(
name = "shape_inference",
srcs = ["shape_inference.cc"],
hdrs = ["shape_inference.h"],
diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc
new file mode 100644
index 0000000000..9246cb25d2
--- /dev/null
+++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc
@@ -0,0 +1,334 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/bfloat16_propagation.h"
+
+#include "tensorflow/compiler/xla/map_util.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/shape_tree.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace xla {
+
+BFloat16Propagation::BFloat16Propagation(
+ const BFloat16Support* bfloat16_support)
+ : bfloat16_support_(bfloat16_support) {}
+
+void BFloat16Propagation::DetermineAndMutateFusionComputationPrecision(
+ HloInstruction* fusion) {
+ CHECK_EQ(fusion->opcode(), HloOpcode::kFusion);
+ if (!bfloat16_support_->SupportsMixedPrecisions(*fusion)) {
+ return;
+ }
+
+ // We are depending on the fusion node itself having already been analyzed
+ // for whether it can output BF16 and this has been adjusted in the output
+ // shape, and now we're looking to update the interior of the fusion node to
+ // match the new output shape, as well as recursively process the whole fusion
+ // node even if the output shape was not modified.
+ auto root = fusion->fused_instructions_computation()->root_instruction();
+
+ // Adjust root's element types according to the fusion's output shape.
+ ShapeUtil::ForEachMutableSubshape(
+ root->mutable_shape(), [&](Shape* subshape, const ShapeIndex& index) {
+ if (subshape->element_type() != F32) {
+ return;
+ }
+ if (ShapeUtil::GetSubshape(fusion->shape(), index).element_type() ==
+ BF16) {
+ subshape->set_element_type(BF16);
+ changed_ = true;
+ VLOG(2) << "Fused root " << root->ToString() << " at shape index "
+ << index << " changed to BF16 precision for fusion "
+ << fusion->ToString();
+ }
+ });
+
+ // Propagate BF16 in the fusion computation.
+ auto insts =
+ fusion->fused_instructions_computation()->MakeInstructionPostOrder();
+ for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
+ DetermineAndMutateInstructionPrecision(*inst_it, /*skip_parameters=*/false);
+ }
+}
+
+void BFloat16Propagation::AdjustFusionParameters(HloInstruction* fusion) {
+ CHECK_EQ(fusion->fused_parameters().size(), fusion->operand_count());
+ for (int64 i = 0; i < fusion->operand_count(); ++i) {
+ auto parameter = fusion->fused_parameter(i);
+ ShapeUtil::ForEachMutableSubshape(
+ parameter->mutable_shape(),
+ [&](Shape* subshape, const ShapeIndex& index) {
+ if (!ShapeUtil::IsLeafIndex(parameter->shape(), index)) {
+ return;
+ }
+ PrimitiveType operand_type =
+ ShapeUtil::GetSubshape(fusion->operand(i)->shape(), index)
+ .element_type();
+ if (subshape->element_type() == operand_type) {
+ return;
+ }
+ CHECK(operand_type == F32 || operand_type == BF16);
+ subshape->set_element_type(operand_type);
+ changed_ = true;
+ VLOG(2) << "Fused parameter " << parameter->ToString()
+ << " at shape index " << index
+ << " adjusted to match operand in fusion "
+ << fusion->ToString();
+ });
+ }
+}
+
+bool BFloat16Propagation::AllUsersConsumeBF16(const HloInstruction& hlo,
+ const ShapeIndex& index) const {
+ auto value_set = dataflow_->GetValueSet(&hlo, index);
+ for (const HloValue* value : value_set.values()) {
+ if (ContainsKey(values_that_must_be_kept_as_f32_, value)) {
+ return false;
+ }
+ if (value->shape().element_type() == BF16) {
+ continue;
+ }
+ for (const HloUse& use : value->uses()) {
+ if (use.instruction->opcode() == HloOpcode::kFusion) {
+ auto fused_parameter =
+ use.instruction->fused_parameter(use.operand_number);
+ if (ShapeUtil::GetSubshape(fused_parameter->shape(), use.operand_index)
+ .element_type() != BF16) {
+ return false;
+ }
+ continue;
+ }
+ if (bfloat16_support_->EffectiveOperandPrecisionIsBF16(
+ *use.instruction, use.operand_number)) {
+ continue;
+ }
+ // If the op propagates precision and it outputs a BF16, then it's OK to
+ // supply BF16 also as the input. In the backward mutation pass, the users
+ // shapes should have already been processed.
+ PrimitiveType user_output_type = PRIMITIVE_TYPE_INVALID;
+ if (use.instruction->opcode() == HloOpcode::kTuple ||
+ (use.instruction->opcode() == HloOpcode::kCrossReplicaSum &&
+ ShapeUtil::IsTuple(use.instruction->shape()))) {
+ user_output_type = ShapeUtil::GetSubshape(
+ ShapeUtil::GetSubshape(use.instruction->shape(),
+ {use.operand_number}),
+ use.operand_index)
+ .element_type();
+ } else {
+ user_output_type = use.instruction->shape().element_type();
+ }
+ if (bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision(
+ *use.instruction, use.operand_number) &&
+ user_output_type == BF16) {
+ continue;
+ }
+ return false;
+ }
+ }
+ return true;
+}
+
+void BFloat16Propagation::DetermineAndMutateInstructionPrecision(
+ HloInstruction* hlo, bool skip_parameters) {
+ // We handle any fusion computation after the instruction is handled, because
+ // we need to know a fusion's output shape before propagating inside its fused
+ // computation.
+ auto cleaner = tensorflow::gtl::MakeCleanup([this, hlo] {
+ if (hlo->opcode() == HloOpcode::kFusion) {
+ DetermineAndMutateFusionComputationPrecision(hlo);
+ }
+ });
+
+ // Do not change precision for instructions related to entry and exit of a
+ // computation, and control flow, because this pass might break the interfaces
+ // or assumptions for them.
+ if (hlo->opcode() == HloOpcode::kInfeed || //
+ hlo->opcode() == HloOpcode::kOutfeed || //
+ hlo->opcode() == HloOpcode::kConstant || //
+ hlo->opcode() == HloOpcode::kCustomCall || //
+ hlo->opcode() == HloOpcode::kCall || //
+ hlo->opcode() == HloOpcode::kWhile || //
+ hlo->opcode() == HloOpcode::kConditional || //
+ (hlo->opcode() == HloOpcode::kParameter && skip_parameters)) {
+ return;
+ }
+
+ // Prevent root instructions from having their output modified by recording
+ // all F32 output values as needing to stay as F32.
+ CHECK(hlo->parent() != nullptr);
+ if (hlo == hlo->parent()->root_instruction()) {
+ if (!hlo->parent()->IsFusionComputation()) {
+ ShapeUtil::ForEachSubshape(hlo->shape(), [&](const Shape& subshape,
+ const ShapeIndex& index) {
+ if (subshape.element_type() != F32) {
+ return;
+ }
+ for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) {
+ // Since we use HloValues from the dataflow analysis, this can also
+ // affect HLO instructions beyond the root, e.g., if the root is a
+ // Tuple HLO, then its operands are also affected.
+ values_that_must_be_kept_as_f32_.insert(value);
+ }
+ });
+ }
+ return;
+ }
+
+ if (!ContainsKey(consider_using_bfloat16_, hlo)) {
+ return;
+ }
+
+ if (!bfloat16_support_->SupportsBF16Output(*hlo)) {
+ return;
+ }
+
+ ShapeUtil::ForEachMutableSubshape(
+ hlo->mutable_shape(),
+ [hlo, this](Shape* subshape, const ShapeIndex& index) {
+ if (subshape->element_type() == F32 &&
+ AllUsersConsumeBF16(*hlo, index)) {
+ subshape->set_element_type(BF16);
+ changed_ = true;
+ VLOG(2) << "HloInstruction output at shape index " << index
+ << " changed to BF16 precision: " << hlo->ToString();
+ }
+ });
+}
+
+bool BFloat16Propagation::InstructionIsCandidateForBF16Output(
+ HloInstruction* hlo) {
+ if (!bfloat16_support_->SupportsMixedPrecisions(*hlo) &&
+ hlo->opcode() != HloOpcode::kTuple &&
+ hlo->opcode() != HloOpcode::kGetTupleElement &&
+ hlo->shape().element_type() != BF16) {
+ for (int64 i = 0; i < hlo->operand_count(); ++i) {
+ if (!bfloat16_support_->EffectiveOperandPrecisionIsOutputPrecision(*hlo,
+ i) ||
+ !ContainsKey(consider_using_bfloat16_, hlo->operand(i))) {
+ return false;
+ }
+ }
+ }
+ return true;
+}
+
+// The algorithm first does a forward pass (parameters to root) to determine a
+// set of instructions to consider using bfloat16, then does a backward pass to
+// determine the precisions of those instructions according to the need of
+// their users.
+StatusOr<bool> BFloat16Propagation::Run(HloModule* module) {
+ TF_ASSIGN_OR_RETURN(dataflow_, HloDataflowAnalysis::Run(*module));
+
+ std::list<HloComputation*> computations_topological_order =
+ module->MakeComputationPostOrder();
+ // The first step is a forward pass (parameters to root), where we determine
+ // the potential candidate instructions to use bfloat16 in the outputs that
+ // are not likely to cause overhead from extra explicit conversions. This is
+ // done forwardly because we determine whether an HLO is a candidate partially
+ // based on whether its operands are candidates.
+ for (auto computation : computations_topological_order) {
+ for (auto inst : computation->MakeInstructionPostOrder()) {
+ if (InstructionIsCandidateForBF16Output(inst)) {
+ consider_using_bfloat16_.insert(inst);
+ }
+ }
+ }
+
+ // The second step is a backward pass (root to parameters), where we modify
+ // the precisions of the instructions identified in the first step when
+ // feasible. This is done backwardly because we determine the precision of an
+ // HLO's output based on how it is later used.
+ //
+ // The precision of an instruction is determined by its users, so we do the
+ // propagation in reverse topological order.
+ for (auto comp_it = computations_topological_order.rbegin();
+ comp_it != computations_topological_order.rend(); ++comp_it) {
+ if ((*comp_it)->IsFusionComputation()) {
+ // Fusion computations are handled when visiting the fusion instruction.
+ continue;
+ }
+ auto insts = (*comp_it)->MakeInstructionPostOrder();
+ for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
+ DetermineAndMutateInstructionPrecision(*inst_it,
+ /*skip_parameters=*/true);
+ }
+ }
+
+ if (!changed_) {
+ return false;
+ }
+
+ // It's possible that an instruction does not define a buffer, but the
+ // defining instruction's shape has changed. So we need to adjust the output
+ // shapes of instructions according to the HLO values they refer to.
+ for (auto comp_it = computations_topological_order.rbegin();
+ comp_it != computations_topological_order.rend(); ++comp_it) {
+ auto insts = (*comp_it)->MakeInstructionPostOrder();
+ // Do the adjustment on each instruction in the computation in reverse
+ // topological order.
+ for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
+ auto hlo = *inst_it;
+ auto adjust_buffer = [this, hlo](Shape* subshape,
+ const ShapeIndex& index) {
+ if (subshape->element_type() != F32 &&
+ subshape->element_type() != BF16) {
+ return;
+ }
+ PrimitiveType type = BF16;
+ for (const auto* value : dataflow_->GetValueSet(hlo, index).values()) {
+ if (value->shape().element_type() == BF16) {
+ continue;
+ }
+ CHECK_EQ(value->shape().element_type(), F32);
+ type = F32;
+ break;
+ }
+ // It's possible that a user has been changed from BF16 to F32
+ // during this final adjustment pass, so we need to check
+ // AllUsersConsumeBF16() again.
+ if (type == BF16 && !AllUsersConsumeBF16(*hlo, index)) {
+ type = F32;
+ }
+ if (type == F32) {
+ for (const auto* value :
+ dataflow_->GetValueSet(hlo, index).values()) {
+ // We rely on the fact that this adjustment works in reverse
+ // topological order. Adding the value to
+ // values_that_must_be_kept_as_f32_ will ensure the correctness
+ // of the adjustment for HLOs that will be processed later.
+ values_that_must_be_kept_as_f32_.insert(value);
+ }
+ }
+ subshape->set_element_type(type);
+ };
+ ShapeUtil::ForEachMutableSubshape(hlo->mutable_shape(), adjust_buffer);
+ }
+ // Now adjust parameters of fusions inside this computation.
+ for (auto inst_it = insts.rbegin(); inst_it != insts.rend(); ++inst_it) {
+ auto hlo = *inst_it;
+ if (hlo->opcode() == HloOpcode::kFusion) {
+ AdjustFusionParameters(hlo);
+ }
+ }
+ }
+ return true;
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.h b/tensorflow/compiler/xla/service/bfloat16_propagation.h
new file mode 100644
index 0000000000..aa81dde3b0
--- /dev/null
+++ b/tensorflow/compiler/xla/service/bfloat16_propagation.h
@@ -0,0 +1,119 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_PROPAGATION_H_
+#define TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_PROPAGATION_H_
+
+#include <memory>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "tensorflow/compiler/xla/service/bfloat16_support.h"
+#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
+
+namespace xla {
+
+// HLO pass which reduces the precision of some HLO instructions to BF16
+// according to the backend-specific BFloat16Support rule provided by the
+// caller.
+//
+// This pass can be used to reduce instruction precision without affecting the
+// numerical accuracy of the module, i.e., the final output of the module would
+// be bitwise identical to that without this pass; this is possible if the
+// backend already reduces precision to BF16 on some HLO instructions.
+//
+// This pass will not modify the signature of any non-fusion computation.
+//
+// !!! WARNING !!! This pass can introduce mixed precision in individual HLOs,
+// which has two issues:
+//
+// 1) It does not guarantee to respect the passed-in BFloat16Support
+// specification in terms of mixed precision, so the backend may not support an
+// HLO that has mixed precision produced by this pass. To address this issue,
+// run BFloat16Normalization with the same BFloat16Support after this pass.
+//
+// 2) In general, mixed precision may break the assumptions of some other HLO
+// passes even if the specific backend supports the individual HLOs. Such
+// assumptions include that there are no HLOs using mixed precision, or that the
+// precision of an HLO's output is determined by its inputs. It should be used
+// at the end of the HLO optimization pipeline but before
+// BFloat16ConversionFolding. If other passes are needed after this pass, run
+// BFloat16MixedPrecisionRemoval first to undo some of the changes made by this
+// pass.
+class BFloat16Propagation : public HloPassInterface {
+ public:
+ explicit BFloat16Propagation(const BFloat16Support* bfloat16_support);
+
+ ~BFloat16Propagation() override = default;
+
+ tensorflow::StringPiece name() const override {
+ return "bfloat16-propagation";
+ }
+
+ // Runs the pass on the given module. Returns whether the module was changed
+ // (precision reductions were added).
+ StatusOr<bool> Run(HloModule* module) override;
+
+ private:
+ // ***************************
+ // Function called and state produced by the forward analysis pass (from
+ // parameters to root) that determines the candidate HLOs to use BF16 outputs.
+
+ // Determines whether we should consider changing the precision of the given
+ // instruction in the forward pass.
+ bool InstructionIsCandidateForBF16Output(HloInstruction* hlo);
+
+ // The set of instructions to consider using bfloat16, computed in the forward
+ // pass.
+ tensorflow::gtl::FlatSet<const HloInstruction*> consider_using_bfloat16_;
+
+ // ***************************
+ // Functions called and state produced by the backward mutation pass (from
+ // root to parameters).
+
+ // Determines the precision for the given instruction in the mutation pass.
+ void DetermineAndMutateInstructionPrecision(HloInstruction* hlo,
+ bool skip_parameters);
+
+ // Special handling in the mutation pass for fusion computations.
+ void DetermineAndMutateFusionComputationPrecision(HloInstruction* fusion);
+
+ // Makes the fusion parameters match the precision of the actual parameters
+ // passed to the fusion node.
+ void AdjustFusionParameters(HloInstruction* fusion);
+
+ // Returns whether all uses of the given HloInstruction can consume BF16
+ // input.
+ bool AllUsersConsumeBF16(const HloInstruction& hlo,
+ const ShapeIndex& index) const;
+
+ // The set of F32 HLO values that must be kept in F32.
+ tensorflow::gtl::FlatSet<const HloValue*> values_that_must_be_kept_as_f32_;
+
+ // ***************************
+ // State used by both passes.
+ const BFloat16Support* bfloat16_support_;
+ std::unique_ptr<HloDataflowAnalysis> dataflow_;
+
+ bool changed_ = false;
+};
+
+} // namespace xla
+
+#endif // TENSORFLOW_COMPILER_XLA_SERVICE_BFLOAT16_PROPAGATION_H_
diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc
new file mode 100644
index 0000000000..4c86c6b26e
--- /dev/null
+++ b/tensorflow/compiler/xla/service/bfloat16_propagation_test.cc
@@ -0,0 +1,335 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/bfloat16_propagation.h"
+#include "tensorflow/compiler/xla/service/bfloat16_support.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/shape_util.h"
+#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/compiler/xla/test_helpers.h"
+#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
+#include "tensorflow/compiler/xla/xla_data.pb.h"
+
+namespace xla {
+
+// A class specifying the BF16 support used to test the propagation pass. It
+// specifies that BF16 and mixed precision are supported in all HloInstructions,
+// and that kDot reduces its operands precision to BF16.
+class TestBFloat16Support : public BFloat16Support {
+ public:
+ TestBFloat16Support() {}
+ ~TestBFloat16Support() override {}
+
+ bool SupportsBF16Operand(const HloInstruction& hlo,
+ int64 operand_index) const override {
+ return true;
+ }
+
+ bool SupportsBF16Output(const HloInstruction& hlo) const override {
+ return true;
+ }
+
+ bool SupportsMixedPrecisions(const HloInstruction& hlo) const override {
+ return true;
+ }
+
+ bool EffectiveOperandPrecisionIsBF16(const HloInstruction& hlo,
+ int64 operand_index) const override {
+ return hlo.opcode() == HloOpcode::kDot;
+ }
+};
+
+class BFloat16PropagationTest : public HloTestBase {
+ protected:
+ // Runs the propagation pass on the given module, and returns whether the
+ // module is changed after this pass.
+ bool PropagatePrecision(HloModule* module) {
+ TestBFloat16Support bfloat16_support;
+ BFloat16Propagation propagation(&bfloat16_support);
+ StatusOr<bool> result = propagation.Run(module);
+ EXPECT_IS_OK(result.status());
+ return result.ValueOrDie();
+ }
+
+ // Returns whether the given HloInstruction's output element type is BF16 or
+ // the only use of it is converting to BF16.
+ bool OutputsBF16(HloInstruction* inst) {
+ if (inst->shape().element_type() == BF16) {
+ return true;
+ }
+ return inst->user_count() == 1 &&
+ inst->users()[0]->opcode() == HloOpcode::kConvert &&
+ inst->users()[0]->shape().element_type() == BF16;
+ }
+};
+
+// Tests that BF16 can propagate through select over non-tuple buffers, but not
+// through add where reducing operand precision can affect the result.
+TEST_F(BFloat16PropagationTest, PropagateThroughSelectButNotAdd) {
+ auto builder = HloComputation::Builder(TestName());
+ Shape shape = ShapeUtil::MakeShape(F32, {2, 4});
+
+ HloInstruction* a =
+ builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
+ HloInstruction* b =
+ builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
+ HloInstruction* c =
+ builder.AddInstruction(HloInstruction::CreateParameter(2, shape, "c"));
+ HloInstruction* add0 = builder.AddInstruction(
+ HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b));
+ HloInstruction* add1 = builder.AddInstruction(
+ HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add0, b));
+ HloInstruction* pred = builder.AddInstruction(
+ HloInstruction::CreateBinary(shape, HloOpcode::kEq, a, b));
+ HloInstruction* sel = builder.AddInstruction(
+ HloInstruction::CreateTernary(shape, HloOpcode::kSelect, pred, c, add1));
+ HloInstruction* xpose =
+ builder.AddInstruction(HloInstruction::CreateTranspose(
+ ShapeUtil::MakeShape(F32, {4, 2}), sel, {1, 0}));
+ HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, xpose, a));
+ HloInstruction* root = builder.AddInstruction(
+ HloInstruction::CreateBinary(shape, HloOpcode::kAdd, dot, dot));
+
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ EXPECT_TRUE(PropagatePrecision(module.get()));
+
+ EXPECT_EQ(computation->root_instruction(), root);
+ EXPECT_TRUE(OutputsBF16(xpose));
+ EXPECT_TRUE(OutputsBF16(sel));
+ EXPECT_TRUE(OutputsBF16(add1));
+ EXPECT_FALSE(OutputsBF16(add0));
+ EXPECT_FALSE(OutputsBF16(a));
+ EXPECT_FALSE(OutputsBF16(b));
+ EXPECT_FALSE(OutputsBF16(c));
+}
+
+// Tests that BF16 can be propagated through nested tuples.
+TEST_F(BFloat16PropagationTest, PropagateThroughTuples) {
+ auto builder = HloComputation::Builder(TestName());
+ Shape shape = ShapeUtil::MakeShape(F32, {2, 4});
+
+ HloInstruction* a =
+ builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
+ HloInstruction* b =
+ builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
+ HloInstruction* add0 = builder.AddInstruction(
+ HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b));
+ HloInstruction* add1 = builder.AddInstruction(
+ HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, a));
+ HloInstruction* add2 = builder.AddInstruction(
+ HloInstruction::CreateBinary(shape, HloOpcode::kAdd, b, b));
+ HloInstruction* xpose =
+ builder.AddInstruction(HloInstruction::CreateTranspose(
+ ShapeUtil::MakeShape(F32, {4, 2}), add1, {1, 0}));
+
+ HloInstruction* tuple0 =
+ builder.AddInstruction(HloInstruction::CreateTuple({add0, add1, add2}));
+ HloInstruction* tuple1 =
+ builder.AddInstruction(HloInstruction::CreateTuple({tuple0, xpose}));
+
+ HloInstruction* lhs = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(xpose->shape(), tuple1, 1));
+ HloInstruction* rhs =
+ builder.AddInstruction(HloInstruction::CreateGetTupleElement(
+ add0->shape(),
+ builder.AddInstruction(HloInstruction::CreateGetTupleElement(
+ tuple0->shape(), tuple1, 0)),
+ 0));
+ HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, lhs, rhs));
+
+ HloInstruction* output_tuple =
+ builder.AddInstruction(HloInstruction::CreateTuple({dot, add2}));
+
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ EXPECT_TRUE(PropagatePrecision(module.get()));
+
+ EXPECT_EQ(computation->root_instruction(), output_tuple);
+ EXPECT_TRUE(OutputsBF16(xpose));
+ EXPECT_TRUE(OutputsBF16(add0));
+ EXPECT_TRUE(OutputsBF16(add1));
+ EXPECT_FALSE(OutputsBF16(add2));
+}
+
+// Tests that even if an instruction does not define a buffer in its output, its
+// shape must match the defining instruction.
+TEST_F(BFloat16PropagationTest, SameValueReferencedTwice) {
+ auto builder = HloComputation::Builder(TestName());
+ Shape shape = ShapeUtil::MakeShape(F32, {2, 4});
+
+ HloInstruction* a =
+ builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
+ HloInstruction* b =
+ builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
+ HloInstruction* add0 = builder.AddInstruction(
+ HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b));
+ HloInstruction* add1 = builder.AddInstruction(
+ HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, a));
+
+ HloInstruction* lhs = builder.AddInstruction(HloInstruction::CreateTranspose(
+ ShapeUtil::MakeShape(F32, {4, 2}), add1, {1, 0}));
+
+ HloInstruction* tuple =
+ builder.AddInstruction(HloInstruction::CreateTuple({add0, add1}));
+ HloInstruction* rhs = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(add1->shape(), tuple, 1));
+
+ // lhs is the transpose of add1, and rhs is a get-tuple-element aliasing add1.
+ HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, lhs, rhs));
+
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ EXPECT_TRUE(PropagatePrecision(module.get()));
+
+ EXPECT_EQ(computation->root_instruction(), dot);
+ EXPECT_TRUE(OutputsBF16(add0));
+ EXPECT_TRUE(OutputsBF16(add1));
+ EXPECT_TRUE(OutputsBF16(lhs));
+ // rhs is a get-tuple-element, which does not define a buffer, but its shape
+ // should also be adjusted accordingly.
+ EXPECT_TRUE(OutputsBF16(rhs));
+}
+
+// Tests that a non-fusion computation's root should not be changed.
+TEST_F(BFloat16PropagationTest, DoNotChangeComputationRoot) {
+ auto builder = HloComputation::Builder(TestName());
+ Shape shape = ShapeUtil::MakeShape(F32, {2, 4});
+
+ HloInstruction* a =
+ builder.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
+ HloInstruction* b =
+ builder.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
+ HloInstruction* add = builder.AddInstruction(
+ HloInstruction::CreateBinary(shape, HloOpcode::kAdd, a, b));
+
+ HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, add, add));
+
+ HloInstruction* tuple =
+ builder.AddInstruction(HloInstruction::CreateTuple({add, dot}));
+
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ EXPECT_FALSE(PropagatePrecision(module.get()));
+
+ EXPECT_EQ(computation->root_instruction(), tuple);
+ EXPECT_FALSE(OutputsBF16(add));
+}
+
+// Tests that BF16 is propagated properly through fused computations.
+TEST_F(BFloat16PropagationTest, PropagateThroughFusion) {
+ auto module = CreateNewModule();
+ auto builder = HloComputation::Builder(TestName());
+ Shape shape = ShapeUtil::MakeShape(F32, {2, 4});
+
+ HloInstruction* param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, shape, "param"));
+ HloInstruction* add = builder.AddInstruction(
+ HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, param));
+
+ auto builder_f0 = HloComputation::Builder("fusion0");
+ HloInstruction* a_f0 =
+ builder_f0.AddInstruction(HloInstruction::CreateParameter(0, shape, "a"));
+ HloInstruction* b_f0 =
+ builder_f0.AddInstruction(HloInstruction::CreateParameter(1, shape, "b"));
+ HloInstruction* tuple_f0 =
+ builder_f0.AddInstruction(HloInstruction::CreateTuple({a_f0, b_f0}));
+ auto comp_f0 = module->AddEmbeddedComputation(builder_f0.Build());
+ auto fusion0 = builder.AddInstruction(HloInstruction::CreateFusion(
+ tuple_f0->shape(), HloInstruction::FusionKind::kCustom, {add, add},
+ comp_f0));
+
+ auto builder_f1 = HloComputation::Builder("fusion1");
+ HloInstruction* p_f1 = builder_f1.AddInstruction(
+ HloInstruction::CreateParameter(0, tuple_f0->shape(), "param"));
+ HloInstruction* a_f1 = builder_f1.AddInstruction(
+ HloInstruction::CreateGetTupleElement(shape, p_f1, 0));
+ HloInstruction* b_f1 = builder_f1.AddInstruction(
+ HloInstruction::CreateGetTupleElement(shape, p_f1, 1));
+ HloInstruction* dot = builder_f1.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, a_f1, b_f1));
+ auto comp_f1 = module->AddEmbeddedComputation(builder_f1.Build());
+ auto fusion1 = builder.AddInstruction(HloInstruction::CreateFusion(
+ dot->shape(), HloInstruction::FusionKind::kCustom, {fusion0}, comp_f1));
+
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ EXPECT_TRUE(PropagatePrecision(module.get()));
+
+ EXPECT_EQ(computation->root_instruction(), fusion1);
+ EXPECT_TRUE(OutputsBF16(add));
+ EXPECT_TRUE(OutputsBF16(a_f0));
+ EXPECT_TRUE(OutputsBF16(b_f0));
+ EXPECT_TRUE(OutputsBF16(a_f1));
+ EXPECT_TRUE(OutputsBF16(b_f1));
+}
+
+// A select over tuples does not define the leaf buffers, so the types in
+// on_true and on_false must match, so that as long as one of them is F32, the
+// other must be F32 as well.
+TEST_F(BFloat16PropagationTest, SelectOverTuples) {
+ auto module = CreateNewModule();
+ auto builder = HloComputation::Builder(TestName());
+ Shape shape = ShapeUtil::MakeShape(F32, {2, 4});
+
+ HloInstruction* param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, shape, "param"));
+ HloInstruction* pred = builder.AddInstruction(HloInstruction::CreateParameter(
+ 1, ShapeUtil::MakeShape(PRED, {}), "pred"));
+
+ HloInstruction* add0 = builder.AddInstruction(
+ HloInstruction::CreateBinary(shape, HloOpcode::kAdd, param, param));
+ HloInstruction* add1 = builder.AddInstruction(
+ HloInstruction::CreateBinary(shape, HloOpcode::kAdd, add0, param));
+ HloInstruction* tuple0 =
+ builder.AddInstruction(HloInstruction::CreateTuple({param, add0}));
+ HloInstruction* tuple1 =
+ builder.AddInstruction(HloInstruction::CreateTuple({param, add1}));
+ HloInstruction* sel = builder.AddInstruction(HloInstruction::CreateTernary(
+ tuple0->shape(), HloOpcode::kSelect, pred, tuple0, tuple1));
+ HloInstruction* gte0 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(shape, sel, 0));
+ HloInstruction* gte1 = builder.AddInstruction(
+ HloInstruction::CreateGetTupleElement(shape, sel, 1));
+ HloInstruction* xpose =
+ builder.AddInstruction(HloInstruction::CreateTranspose(
+ ShapeUtil::MakeShape(F32, {4, 2}), gte0, {1, 0}));
+ HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary(
+ ShapeUtil::MakeShape(F32, {4, 4}), HloOpcode::kDot, xpose, gte1));
+
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ EXPECT_TRUE(PropagatePrecision(module.get()));
+
+ EXPECT_EQ(computation->root_instruction(), dot);
+ EXPECT_FALSE(OutputsBF16(add0));
+ EXPECT_FALSE(OutputsBF16(add1));
+ EXPECT_FALSE(OutputsBF16(gte0));
+ EXPECT_FALSE(OutputsBF16(gte1));
+ EXPECT_TRUE(OutputsBF16(xpose));
+}
+
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/bfloat16_support.h b/tensorflow/compiler/xla/service/bfloat16_support.h
index 29f662d22b..82c2745f44 100644
--- a/tensorflow/compiler/xla/service/bfloat16_support.h
+++ b/tensorflow/compiler/xla/service/bfloat16_support.h
@@ -39,7 +39,7 @@ class BFloat16Support {
// precisions (BF16 and F32).
virtual bool SupportsMixedPrecisions(const HloInstruction& hlo) const;
- // Returns whether the given HLO inherits its BF16 operand precision at the
+ // Returns whether the given HLO preserves its BF16 operand precision at the
// given index, so even if the output is F32, elements in the output that
// depend on the BF16 operand will still have BF16 effective precision even if
// they have F32 format. Similarly, this also means if the output is BF16 then