aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h
diff options
context:
space:
mode:
authorGravatar Sanjoy Das <sanjoy@google.com>2017-11-07 14:08:01 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2017-11-07 14:13:56 -0800
commitd0de8738e3401bbc5fd142846b4fc124951e5e07 (patch)
treef9d31064fd04d56dd38fa81491296aa6017d0ec7 /tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h
parent4340c7ab49c6184ffb691df50e5b76712338cf69 (diff)
Fix ParallelTaskAssigner's use of the HloPassPipeline interface
We were creating the ParallelTaskAssignment contained in ParallelTaskAssigner with an unoptimized module and then trying to call ParallelTaskAssignment::Run on an optimized module. This meant that the flop counts in HloCostAnalysis were cached using bogus HloInstruction* pointers, which meant our parallel task assignment was not effective. PiperOrigin-RevId: 174909618
Diffstat (limited to 'tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h')
-rw-r--r--tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h | 16
1 file changed, 7 insertions, 9 deletions
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h
index e036da5784..5801ec8d27 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h
+++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h
@@ -37,10 +37,9 @@ class ParallelTaskAssignment {
// 'shape_size': shape size function used by HloCostAnalysis during parallel
// task assignment.
// 'module': the containing HloModule.
- ParallelTaskAssignment(
- const int64 max_parallelism,
- const HloCostAnalysis::ShapeSizeFunction& shape_size,
- HloModule* module);
+ ParallelTaskAssignment(const int64 max_parallelism,
+ const HloCostAnalysis::ShapeSizeFunction& shape_size,
+ HloModule* module);
~ParallelTaskAssignment() {}
// Computes and returns the target parallel task count for 'instruction'.
@@ -63,11 +62,9 @@ class ParallelTaskAssigner : public HloPassInterface {
// 'max_parallelism': the maximum parallel task count per instruction.
// 'shape_size': shape size function used by HloCostAnalysis during parallel
// task assignment.
- // 'module': the containing HloModule.
ParallelTaskAssigner(const int64 max_parallelism,
- const HloCostAnalysis::ShapeSizeFunction& shape_size,
- HloModule* module)
- : parallel_task_assignment_(max_parallelism, shape_size, module) {}
+ const HloCostAnalysis::ShapeSizeFunction& shape_size)
+ : max_parallelism_(max_parallelism), shape_size_function_(shape_size) {}
~ParallelTaskAssigner() override {}
tensorflow::StringPiece name() const override {
@@ -95,7 +92,8 @@ class ParallelTaskAssigner : public HloPassInterface {
void ComputeTargetParallelTasks(HloModule* module,
HloToParallelTasks* hlo_to_parallel_tasks);
- ParallelTaskAssignment parallel_task_assignment_;
+ int64 max_parallelism_;
+ HloCostAnalysis::ShapeSizeFunction shape_size_function_;
};
} // namespace cpu