diff options
author | Sanjoy Das <sanjoy@google.com> | 2017-11-07 14:08:01 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2017-11-07 14:13:56 -0800 |
commit | d0de8738e3401bbc5fd142846b4fc124951e5e07 (patch) | |
tree | f9d31064fd04d56dd38fa81491296aa6017d0ec7 /tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h | |
parent | 4340c7ab49c6184ffb691df50e5b76712338cf69 (diff) |
Fix ParallelTaskAssigner's use of the HloPassPipeline interface
We were creating the ParallelTaskAssignment contained in ParallelTaskAssigner
with an unoptimized module and then trying to run ParallelTaskAssigner::Run on an
optimized module. This meant that the flop counts in HloCostAnalysis were
cached using bogus HloInstruction* pointers, which meant our parallel task
assignment was not effective.
PiperOrigin-RevId: 174909618
Diffstat (limited to 'tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h')
-rw-r--r-- | tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h | 16 |
1 file changed, 7 insertions, 9 deletions
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h index e036da5784..5801ec8d27 100644 --- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h +++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h @@ -37,10 +37,9 @@ class ParallelTaskAssignment { // 'shape_size': shape size function used by HloCostAnalysis during parallel // task assignment. // 'module': the containing HloModule. - ParallelTaskAssignment( - const int64 max_parallelism, - const HloCostAnalysis::ShapeSizeFunction& shape_size, - HloModule* module); + ParallelTaskAssignment(const int64 max_parallelism, + const HloCostAnalysis::ShapeSizeFunction& shape_size, + HloModule* module); ~ParallelTaskAssignment() {} // Computes and returns the target parallel task count for 'instruction'. @@ -63,11 +62,9 @@ class ParallelTaskAssigner : public HloPassInterface { // 'max_parallelism': the maximum parallel task count per instruction. // 'shape_size': shape size function used by HloCostAnalysis during parallel // task assignment. - // 'module': the containing HloModule. ParallelTaskAssigner(const int64 max_parallelism, - const HloCostAnalysis::ShapeSizeFunction& shape_size, - HloModule* module) - : parallel_task_assignment_(max_parallelism, shape_size, module) {} + const HloCostAnalysis::ShapeSizeFunction& shape_size) + : max_parallelism_(max_parallelism), shape_size_function_(shape_size) {} ~ParallelTaskAssigner() override {} tensorflow::StringPiece name() const override { @@ -95,7 +92,8 @@ class ParallelTaskAssigner : public HloPassInterface { void ComputeTargetParallelTasks(HloModule* module, HloToParallelTasks* hlo_to_parallel_tasks); - ParallelTaskAssignment parallel_task_assignment_; + int64 max_parallelism_; + HloCostAnalysis::ShapeSizeFunction shape_size_function_; }; } // namespace cpu |