diff options
author | Sanjoy Das <sanjoy@google.com> | 2017-11-07 14:08:01 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-11-07 14:13:56 -0800 |
commit | d0de8738e3401bbc5fd142846b4fc124951e5e07 (patch) | |
tree | f9d31064fd04d56dd38fa81491296aa6017d0ec7 /tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc | |
parent | 4340c7ab49c6184ffb691df50e5b76712338cf69 (diff) |
Fix ParallelTaskAssigner's use of the HloPassPipeline interface
We were creating the ParallelTaskAssignment contained in ParallelTaskAssigner
with an unoptimized module and then calling ParallelTaskAssignment::Run on an
optimized module. This meant that the flop counts in HloCostAnalysis were
cached using bogus HloInstruction* pointers, which meant our parallel task
assignment was not effective.
PiperOrigin-RevId: 174909618
Diffstat (limited to 'tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc | 10 |
1 file changed, 5 insertions, 5 deletions
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc index c2213c8f2e..4a62a80fac 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc @@ -101,11 +101,9 @@ class DefaultCostModel : public ParallelCostModel { const std::unique_ptr<HloCostAnalysis> cost_analysis_; }; - ParallelTaskAssignment::ParallelTaskAssignment( const int64 max_parallelism, - const HloCostAnalysis::ShapeSizeFunction& shape_size, - HloModule* module) { + const HloCostAnalysis::ShapeSizeFunction& shape_size, HloModule* module) { VLOG(1) << "ParallelTaskAssignment max_parallelism: " << max_parallelism; // Run cost analysis on 'module'. auto cost_analysis = MakeUnique<HloCostAnalysis>(shape_size); @@ -153,7 +151,6 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount( StatusOr<bool> ParallelTaskAssigner::Run(HloModule* module) { XLA_VLOG_LINES(2, "ParallelTaskAssigner ENTRY"); XLA_VLOG_LINES(3, module->ToString()); - // Compute target parallel task counts for all instructions in 'module'. HloToParallelTasks hlo_to_parallel_tasks; ComputeTargetParallelTasks(module, &hlo_to_parallel_tasks); @@ -230,6 +227,9 @@ bool ParallelTaskAssigner::AssignParallelTasksHelper( void ParallelTaskAssigner::ComputeTargetParallelTasks( HloModule* module, HloToParallelTasks* hlo_to_parallel_tasks) { + ParallelTaskAssignment parallel_task_assignment(max_parallelism_, + shape_size_function_, module); + // Compute parallel task counts for all instructions in 'module'. for (auto* computation : module->computations()) { if (computation->IsFusionComputation()) { @@ -238,7 +238,7 @@ void ParallelTaskAssigner::ComputeTargetParallelTasks( for (auto* instruction : computation->instructions()) { // Query ParallelTaskAssignment for target parallel task count. 
const int64 target_parallel_task_count = - parallel_task_assignment_.GetTargetParallelTaskCount(instruction); + parallel_task_assignment.GetTargetParallelTaskCount(instruction); if (target_parallel_task_count > 1) { hlo_to_parallel_tasks->insert( {instruction, target_parallel_task_count}); |