diff options
author | 2018-08-23 19:02:56 -0700 | |
---|---|---|
committer | 2018-08-23 19:06:38 -0700 | |
commit | cffe562473d11ac5f12ac189b686bd5839850344 (patch) | |
tree | 266417e9a4f01fbb008beef9e8a2216e16062178 /tensorflow/compiler/xla/service | |
parent | 9a2dab1f73a8cd765f22b67809f4c7d20f343fab (diff) |
[XLA] Centralize BF16Normalization Tuple handling
As more tuple-producing ops are added, this probably makes more sense.
PiperOrigin-RevId: 210039265
Diffstat (limited to 'tensorflow/compiler/xla/service')
-rw-r--r-- | tensorflow/compiler/xla/service/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/bfloat16_normalization.cc | 28 |
2 files changed, 7 insertions, 22 deletions
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD index 364e6e0c45..a3e4d78400 100644 --- a/tensorflow/compiler/xla/service/BUILD +++ b/tensorflow/compiler/xla/service/BUILD @@ -99,6 +99,7 @@ cc_library( ":bfloat16_support", ":hlo", ":hlo_pass", + "//tensorflow/compiler/xla:shape_util", "//tensorflow/compiler/xla:status_macros", "//tensorflow/compiler/xla:xla_data_proto", "//tensorflow/core:lib", diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.cc b/tensorflow/compiler/xla/service/bfloat16_normalization.cc index 16e99b5722..32573ed355 100644 --- a/tensorflow/compiler/xla/service/bfloat16_normalization.cc +++ b/tensorflow/compiler/xla/service/bfloat16_normalization.cc @@ -18,6 +18,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/hlo_computation.h" #include "tensorflow/compiler/xla/service/hlo_instruction.h" #include "tensorflow/compiler/xla/service/hlo_opcode.h" +#include "tensorflow/compiler/xla/shape_util.h" #include "tensorflow/compiler/xla/status_macros.h" #include "tensorflow/compiler/xla/xla_data.pb.h" #include "tensorflow/core/lib/gtl/array_slice.h" @@ -34,11 +35,6 @@ class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault { Status DefaultAction(HloInstruction* hlo) override; - // Special handling for cross-replica-sum and sort which can have a tuple - // output. - Status HandleCrossReplicaSum(HloInstruction* crs) override; - Status HandleSort(HloInstruction* sort) override; - static bool Run(HloComputation* computation, const BFloat16Support* bfloat16_support) { BFloat16NormalizationVisitor visitor(computation, bfloat16_support); @@ -150,23 +146,6 @@ Status BFloat16NormalizationVisitor::ConvertCalledComputations( return Status::OK(); } -Status BFloat16NormalizationVisitor::HandleCrossReplicaSum( - HloInstruction* crs) { - if (!ShapeUtil::IsTuple(crs->shape())) { - return HandleInstruction(crs); - } else { - return HandleMultipleOutputs(crs); - } -} - -Status BFloat16NormalizationVisitor::HandleSort(HloInstruction* sort) { - if (!ShapeUtil::IsTuple(sort->shape())) { - return HandleInstruction(sort); - } else { - return HandleMultipleOutputs(sort); - } -} - Status BFloat16NormalizationVisitor::HandleMultipleOutputs( HloInstruction* hlo) { std::vector<PrimitiveType> operand_types(hlo->operand_count()); @@ -380,6 +359,11 @@ Status BFloat16NormalizationVisitor::DefaultAction(HloInstruction* hlo) { hlo->opcode() == HloOpcode::kConditional) { return Status::OK(); } + if ((hlo->opcode() == HloOpcode::kSort || + hlo->opcode() == HloOpcode::kCrossReplicaSum) && + ShapeUtil::IsTuple(hlo->shape())) { + return HandleMultipleOutputs(hlo); + } return HandleInstruction(hlo); } |