aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorGravatar A. Unique TensorFlower <gardener@tensorflow.org>2017-11-10 08:52:44 -0800
committerGravatar Andrew Selle <aselle@andyselle.com>2017-11-10 16:14:42 -0800
commit8ca7c8f3d4e0c39ec699eaea68d60c94fb624426 (patch)
treea95daa19c665ff835e98cce958f9ca543a8661e5
parent8614ef614245cfcfdd09bda0d633d5aa4f6e856e (diff)
[XLA] Make TuplePointsToAnalysis and LogicalBufferAnalysis track nested fusion instructions.
PiperOrigin-RevId: 175295981
-rw-r--r--tensorflow/compiler/xla/service/logical_buffer_analysis.cc23
-rw-r--r--tensorflow/compiler/xla/service/tuple_points_to_analysis.cc32
2 files changed, 48 insertions, 7 deletions
diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc
index 02dc49e78c..6aca6ba385 100644
--- a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc
+++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc
@@ -23,6 +23,23 @@ limitations under the License.
namespace xla {
+namespace {
+
+// Gather fusion instructions from 'instruction' into 'fusion_instructions'.
+void GatherFusionInstructions(
+ HloInstruction* instruction,
+ std::vector<HloInstruction*>* fusion_instructions) {
+ CHECK_EQ(HloOpcode::kFusion, instruction->opcode());
+ for (auto* fused : instruction->fused_instructions()) {
+ if (fused->opcode() == HloOpcode::kFusion) {
+ GatherFusionInstructions(fused, fusion_instructions);
+ }
+ }
+ fusion_instructions->push_back(instruction);
+}
+
+} // namespace
+
/* static */ StatusOr<std::unique_ptr<LogicalBufferAnalysis>>
LogicalBufferAnalysis::Run(const HloModule* module) {
std::unique_ptr<LogicalBufferAnalysis> analysis(
@@ -41,15 +58,19 @@ Status LogicalBufferAnalysis::Analyze() {
// We filter out fusion computations, and get to them through fusion
// instructions. This is because it's possible to have orphaned (unreachable)
// fusion computations, and we don't want to try to assign buffers to those.
+ std::vector<HloInstruction*> fusion_instructions;
for (auto* computation : module_->MakeNonfusionComputations()) {
TF_RETURN_IF_ERROR(computation->Accept(this));
for (auto* instruction : computation->instructions()) {
if (instruction->opcode() != HloOpcode::kFusion) {
continue;
}
- TF_RETURN_IF_ERROR(instruction->fused_expression_root()->Accept(this));
+ GatherFusionInstructions(instruction, &fusion_instructions);
}
}
+ for (auto* instruction : fusion_instructions) {
+ TF_RETURN_IF_ERROR(instruction->fused_expression_root()->Accept(this));
+ }
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
index a1f9451dd4..0c84856647 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc
@@ -120,6 +120,23 @@ void PointsToSet::add_tuple_source(const ShapeIndex& index,
tree_.mutable_element(index)->tuple_sources.insert(tuple);
}
+namespace {
+
+// Gather fusion instructions from 'instruction' into 'fusion_instructions'.
+void GatherFusionInstructions(
+ HloInstruction* instruction,
+ std::vector<HloInstruction*>* fusion_instructions) {
+ CHECK_EQ(HloOpcode::kFusion, instruction->opcode());
+ for (auto* fused : instruction->fused_instructions()) {
+ if (fused->opcode() == HloOpcode::kFusion) {
+ GatherFusionInstructions(fused, fusion_instructions);
+ }
+ }
+ fusion_instructions->push_back(instruction);
+}
+
+} // namespace
+
/* static */ StatusOr<std::unique_ptr<TuplePointsToAnalysis>>
TuplePointsToAnalysis::Run(const HloModule* module) {
auto logical_buffer_analysis = LogicalBufferAnalysis::Run(module);
@@ -137,20 +154,23 @@ Status TuplePointsToAnalysis::Analyze() {
logical_buffer_aliases_.resize(
logical_buffer_analysis_->num_logical_buffers());
+ std::vector<HloInstruction*> fusion_instructions;
for (auto* computation : module_->MakeNonfusionComputations()) {
TF_RETURN_IF_ERROR(computation->Accept(this));
TF_RETURN_IF_ERROR(
PopulateDefinedBuffersAndAliases(computation->instructions()));
- // Run points-to analysis on fusion instructions in 'computation'.
for (auto* instruction : computation->instructions()) {
- if (instruction->opcode() != HloOpcode::kFusion) {
- continue;
+ if (instruction->opcode() == HloOpcode::kFusion) {
+ GatherFusionInstructions(instruction, &fusion_instructions);
}
- TF_RETURN_IF_ERROR(instruction->fused_expression_root()->Accept(this));
- TF_RETURN_IF_ERROR(
- PopulateDefinedBuffersAndAliases(instruction->fused_instructions()));
}
}
+ // Run points-to analysis on fusion instructions in 'computation'.
+ for (auto* instruction : fusion_instructions) {
+ TF_RETURN_IF_ERROR(instruction->fused_expression_root()->Accept(this));
+ TF_RETURN_IF_ERROR(
+ PopulateDefinedBuffersAndAliases(instruction->fused_instructions()));
+ }
XLA_VLOG_LINES(3, ToString());