author    Michael Kuperstein <mkuper@google.com>          2018-08-23 19:02:56 -0700
committer TensorFlower Gardener <gardener@tensorflow.org> 2018-08-23 19:06:38 -0700
commit    cffe562473d11ac5f12ac189b686bd5839850344 (patch)
tree      266417e9a4f01fbb008beef9e8a2216e16062178 /tensorflow/compiler/xla/service
parent    9a2dab1f73a8cd765f22b67809f4c7d20f343fab (diff)
[XLA] Centralize BF16Normalization Tuple handling
As more tuple-producing ops are added, handling tuple-shaped outputs in one place makes more sense than adding a dedicated handler for each op.

PiperOrigin-RevId: 210039265
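The patch replaces the per-opcode overrides (HandleCrossReplicaSum, HandleSort) with one tuple check inside DefaultAction. Below is a minimal standalone C++ sketch of that dispatch pattern, using hypothetical stand-in types rather than XLA's real visitor and Shape classes:

    #include <iostream>

    // Stand-ins for HloOpcode / HloInstruction; hypothetical, for illustration only.
    enum class Opcode { kAdd, kSort, kCrossReplicaSum };

    struct Instruction {
      Opcode opcode;
      bool shape_is_tuple;  // stand-in for ShapeUtil::IsTuple(hlo->shape())
    };

    struct Visitor {
      void HandleMultipleOutputs(const Instruction&) { std::cout << "multi-output path\n"; }
      void HandleInstruction(const Instruction&) { std::cout << "single-output path\n"; }

      // Centralized dispatch: every tuple-producing opcode is routed here, so
      // supporting the next such op only extends this one condition.
      void DefaultAction(const Instruction& hlo) {
        if ((hlo.opcode == Opcode::kSort || hlo.opcode == Opcode::kCrossReplicaSum) &&
            hlo.shape_is_tuple) {
          return HandleMultipleOutputs(hlo);
        }
        return HandleInstruction(hlo);
      }
    };

    int main() {
      Visitor v;
      v.DefaultAction({Opcode::kSort, /*shape_is_tuple=*/true});   // multi-output path
      v.DefaultAction({Opcode::kSort, /*shape_is_tuple=*/false});  // single-output path
    }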
Diffstat (limited to 'tensorflow/compiler/xla/service')
 tensorflow/compiler/xla/service/BUILD                     |  1 +
 tensorflow/compiler/xla/service/bfloat16_normalization.cc | 28 ++++++----------------------
 2 files changed, 7 insertions(+), 22 deletions(-)
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index 364e6e0c45..a3e4d78400 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -99,6 +99,7 @@ cc_library(
         ":bfloat16_support",
         ":hlo",
         ":hlo_pass",
+        "//tensorflow/compiler/xla:shape_util",
         "//tensorflow/compiler/xla:status_macros",
         "//tensorflow/compiler/xla:xla_data_proto",
         "//tensorflow/core:lib",
diff --git a/tensorflow/compiler/xla/service/bfloat16_normalization.cc b/tensorflow/compiler/xla/service/bfloat16_normalization.cc
index 16e99b5722..32573ed355 100644
--- a/tensorflow/compiler/xla/service/bfloat16_normalization.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_normalization.cc
@@ -18,6 +18,7 @@ limitations under the License.
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "tensorflow/compiler/xla/shape_util.h"
 #include "tensorflow/compiler/xla/status_macros.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 #include "tensorflow/core/lib/gtl/array_slice.h"
@@ -34,11 +35,6 @@ class BFloat16NormalizationVisitor : public DfsHloVisitorWithDefault {
   Status DefaultAction(HloInstruction* hlo) override;
 
-  // Special handling for cross-replica-sum and sort which can have a tuple
-  // output.
-  Status HandleCrossReplicaSum(HloInstruction* crs) override;
-  Status HandleSort(HloInstruction* sort) override;
-
   static bool Run(HloComputation* computation,
                   const BFloat16Support* bfloat16_support) {
     BFloat16NormalizationVisitor visitor(computation, bfloat16_support);
@@ -150,23 +146,6 @@ Status BFloat16NormalizationVisitor::ConvertCalledComputations(
   return Status::OK();
 }
 
-Status BFloat16NormalizationVisitor::HandleCrossReplicaSum(
-    HloInstruction* crs) {
-  if (!ShapeUtil::IsTuple(crs->shape())) {
-    return HandleInstruction(crs);
-  } else {
-    return HandleMultipleOutputs(crs);
-  }
-}
-
-Status BFloat16NormalizationVisitor::HandleSort(HloInstruction* sort) {
-  if (!ShapeUtil::IsTuple(sort->shape())) {
-    return HandleInstruction(sort);
-  } else {
-    return HandleMultipleOutputs(sort);
-  }
-}
-
 Status BFloat16NormalizationVisitor::HandleMultipleOutputs(
     HloInstruction* hlo) {
   std::vector<PrimitiveType> operand_types(hlo->operand_count());
@@ -380,6 +359,11 @@ Status BFloat16NormalizationVisitor::DefaultAction(HloInstruction* hlo) {
       hlo->opcode() == HloOpcode::kConditional) {
     return Status::OK();
   }
+  if ((hlo->opcode() == HloOpcode::kSort ||
+       hlo->opcode() == HloOpcode::kCrossReplicaSum) &&
+      ShapeUtil::IsTuple(hlo->shape())) {
+    return HandleMultipleOutputs(hlo);
+  }
   return HandleInstruction(hlo);
 }
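For context on the tuple check: as the removed comment notes, cross-replica-sum and sort can have tuple outputs — kSort yields a tuple shape when sorting keys together with values, and kCrossReplicaSum yields one when reducing several operands at once; either op with a plain array shape still takes the ordinary HandleInstruction path. A small standalone sketch of the same condition factored into a predicate (hypothetical helper name; the patch itself keeps the inline if):

    // Hypothetical predicate form of the centralized check, to show that the
    // next tuple-producing opcode is a one-line addition to this switch.
    enum class Op { kAdd, kSort, kCrossReplicaSum };

    bool RoutesToMultipleOutputs(Op opcode, bool shape_is_tuple) {
      if (!shape_is_tuple) return false;  // plain array output: ordinary handling
      switch (opcode) {
        case Op::kSort:             // sorting (keys, values) produces a tuple
        case Op::kCrossReplicaSum:  // reducing several operands produces a tuple
          return true;
        default:
          return false;
      }
    }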