diff options
author | A. Unique TensorFlower <gardener@tensorflow.org> | 2017-11-10 08:52:44 -0800 |
---|---|---|
committer | Andrew Selle <aselle@andyselle.com> | 2017-11-10 16:14:42 -0800 |
commit | 8ca7c8f3d4e0c39ec699eaea68d60c94fb624426 (patch) | |
tree | a95daa19c665ff835e98cce958f9ca543a8661e5 | |
parent | 8614ef614245cfcfdd09bda0d633d5aa4f6e856e (diff) |
[XLA] Make TuplePointsToAnalysis and LogicalBufferAnalysis track nested fusion instructions.
PiperOrigin-RevId: 175295981
-rw-r--r-- | tensorflow/compiler/xla/service/logical_buffer_analysis.cc | 23 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/tuple_points_to_analysis.cc | 32 |
2 files changed, 48 insertions, 7 deletions
diff --git a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc index 02dc49e78c..6aca6ba385 100644 --- a/tensorflow/compiler/xla/service/logical_buffer_analysis.cc +++ b/tensorflow/compiler/xla/service/logical_buffer_analysis.cc @@ -23,6 +23,23 @@ limitations under the License. namespace xla { +namespace { + +// Gather fusion instructions from 'instruction' into 'fusion_instructions'. +void GatherFusionInstructions( + HloInstruction* instruction, + std::vector<HloInstruction*>* fusion_instructions) { + CHECK_EQ(HloOpcode::kFusion, instruction->opcode()); + for (auto* fused : instruction->fused_instructions()) { + if (fused->opcode() == HloOpcode::kFusion) { + GatherFusionInstructions(fused, fusion_instructions); + } + } + fusion_instructions->push_back(instruction); +} + +} // namespace + /* static */ StatusOr<std::unique_ptr<LogicalBufferAnalysis>> LogicalBufferAnalysis::Run(const HloModule* module) { std::unique_ptr<LogicalBufferAnalysis> analysis( @@ -41,15 +58,19 @@ Status LogicalBufferAnalysis::Analyze() { // We filter out fusion computations, and get to them through fusion // instructions. This is because it's possible to have orphaned (unreachable) // fusion computations, and we don't want to try to assign buffers to those. 
+ std::vector<HloInstruction*> fusion_instructions; for (auto* computation : module_->MakeNonfusionComputations()) { TF_RETURN_IF_ERROR(computation->Accept(this)); for (auto* instruction : computation->instructions()) { if (instruction->opcode() != HloOpcode::kFusion) { continue; } - TF_RETURN_IF_ERROR(instruction->fused_expression_root()->Accept(this)); + GatherFusionInstructions(instruction, &fusion_instructions); } } + for (auto* instruction : fusion_instructions) { + TF_RETURN_IF_ERROR(instruction->fused_expression_root()->Accept(this)); + } return Status::OK(); } diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc index a1f9451dd4..0c84856647 100644 --- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc +++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.cc @@ -120,6 +120,23 @@ void PointsToSet::add_tuple_source(const ShapeIndex& index, tree_.mutable_element(index)->tuple_sources.insert(tuple); } +namespace { + +// Gather fusion instructions from 'instruction' into 'fusion_instructions'. 
+void GatherFusionInstructions( + HloInstruction* instruction, + std::vector<HloInstruction*>* fusion_instructions) { + CHECK_EQ(HloOpcode::kFusion, instruction->opcode()); + for (auto* fused : instruction->fused_instructions()) { + if (fused->opcode() == HloOpcode::kFusion) { + GatherFusionInstructions(fused, fusion_instructions); + } + } + fusion_instructions->push_back(instruction); +} + +} // namespace + /* static */ StatusOr<std::unique_ptr<TuplePointsToAnalysis>> TuplePointsToAnalysis::Run(const HloModule* module) { auto logical_buffer_analysis = LogicalBufferAnalysis::Run(module); @@ -137,20 +154,23 @@ Status TuplePointsToAnalysis::Analyze() { logical_buffer_aliases_.resize( logical_buffer_analysis_->num_logical_buffers()); + std::vector<HloInstruction*> fusion_instructions; for (auto* computation : module_->MakeNonfusionComputations()) { TF_RETURN_IF_ERROR(computation->Accept(this)); TF_RETURN_IF_ERROR( PopulateDefinedBuffersAndAliases(computation->instructions())); - // Run points-to analysis on fusion instructions in 'computation'. for (auto* instruction : computation->instructions()) { - if (instruction->opcode() != HloOpcode::kFusion) { - continue; + if (instruction->opcode() == HloOpcode::kFusion) { + GatherFusionInstructions(instruction, &fusion_instructions); } - TF_RETURN_IF_ERROR(instruction->fused_expression_root()->Accept(this)); - TF_RETURN_IF_ERROR( - PopulateDefinedBuffersAndAliases(instruction->fused_instructions())); } } + // Run points-to analysis on all gathered fusion instructions, including nested ones. + for (auto* instruction : fusion_instructions) { + TF_RETURN_IF_ERROR(instruction->fused_expression_root()->Accept(this)); + TF_RETURN_IF_ERROR( + PopulateDefinedBuffersAndAliases(instruction->fused_instructions())); + } XLA_VLOG_LINES(3, ToString()); |