author    Sanjoy Das <sanjoy@google.com>  2017-11-07 14:08:01 -0800
committer TensorFlower Gardener <gardener@tensorflow.org>  2017-11-07 14:13:56 -0800
commit    d0de8738e3401bbc5fd142846b4fc124951e5e07 (patch)
tree      f9d31064fd04d56dd38fa81491296aa6017d0ec7 /tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
parent    4340c7ab49c6184ffb691df50e5b76712338cf69 (diff)
Fix ParallelTaskAssigner's use of the HloPassPipeline interface
We were creating the ParallelTaskAssignment contained in ParallelTaskAssigner with an unoptimized module and then calling ParallelTaskAssigner::Run on an optimized module. As a result, the flop counts in HloCostAnalysis were cached against stale HloInstruction* pointers from the unoptimized module, so the lookups during parallel task assignment missed and the pass was ineffective.

PiperOrigin-RevId: 174909618
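To see why the cached flop counts went stale, consider a minimal sketch of pointer-keyed caching across two versions of a module. The Instruction and CostCache types below are illustrative stand-ins for HloInstruction and HloCostAnalysis, not the XLA API:

    // Hypothetical sketch: an analysis that caches per-instruction costs
    // keyed by raw pointer, mirroring how HloCostAnalysis caches flop
    // counts per HloInstruction*.
    #include <cstdint>
    #include <iostream>
    #include <memory>
    #include <unordered_map>
    #include <vector>

    struct Instruction {
      int64_t flops;
    };

    class CostCache {
     public:
      // Analyze a module: record flop counts keyed by instruction pointer.
      void Analyze(const std::vector<std::unique_ptr<Instruction>>& module) {
        for (const auto& instr : module) {
          flops_[instr.get()] = instr->flops;
        }
      }
      // A pointer that was never analyzed silently falls back to 0.
      int64_t FlopCount(const Instruction* instr) const {
        auto it = flops_.find(instr);
        return it == flops_.end() ? 0 : it->second;
      }

     private:
      std::unordered_map<const Instruction*, int64_t> flops_;
    };

    int main() {
      // "Unoptimized" module.
      std::vector<std::unique_ptr<Instruction>> unoptimized;
      unoptimized.push_back(std::make_unique<Instruction>(Instruction{1000}));

      CostCache cache;
      cache.Analyze(unoptimized);  // cache keyed by unoptimized pointers

      // "Optimized" module: semantically the same work, but new objects.
      std::vector<std::unique_ptr<Instruction>> optimized;
      optimized.push_back(std::make_unique<Instruction>(Instruction{1000}));

      // The cached entry is never found: the pointers differ, so the cost
      // model reports 0 flops and parallel task assignment sees no work.
      std::cout << cache.FlopCount(optimized[0].get()) << "\n";  // prints 0
    }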
Diffstat (limited to 'tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc')
-rw-r--r--  tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc  |  10
1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
index c2213c8f2e..4a62a80fac 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
+++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
@@ -101,11 +101,9 @@ class DefaultCostModel : public ParallelCostModel {
const std::unique_ptr<HloCostAnalysis> cost_analysis_;
};
-
ParallelTaskAssignment::ParallelTaskAssignment(
const int64 max_parallelism,
- const HloCostAnalysis::ShapeSizeFunction& shape_size,
- HloModule* module) {
+ const HloCostAnalysis::ShapeSizeFunction& shape_size, HloModule* module) {
VLOG(1) << "ParallelTaskAssignment max_parallelism: " << max_parallelism;
// Run cost analysis on 'module'.
auto cost_analysis = MakeUnique<HloCostAnalysis>(shape_size);
@@ -153,7 +151,6 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount(
StatusOr<bool> ParallelTaskAssigner::Run(HloModule* module) {
XLA_VLOG_LINES(2, "ParallelTaskAssigner ENTRY");
XLA_VLOG_LINES(3, module->ToString());
-
// Compute target parallel task counts for all instructions in 'module'.
HloToParallelTasks hlo_to_parallel_tasks;
ComputeTargetParallelTasks(module, &hlo_to_parallel_tasks);
@@ -230,6 +227,9 @@ bool ParallelTaskAssigner::AssignParallelTasksHelper(
void ParallelTaskAssigner::ComputeTargetParallelTasks(
HloModule* module, HloToParallelTasks* hlo_to_parallel_tasks) {
+ ParallelTaskAssignment parallel_task_assignment(max_parallelism_,
+ shape_size_function_, module);
+
// Compute parallel task counts for all instructions in 'module'.
for (auto* computation : module->computations()) {
if (computation->IsFusionComputation()) {
@@ -238,7 +238,7 @@ void ParallelTaskAssigner::ComputeTargetParallelTasks(
for (auto* instruction : computation->instructions()) {
// Query ParallelTaskAssignment for target parallel task count.
const int64 target_parallel_task_count =
- parallel_task_assignment_.GetTargetParallelTaskCount(instruction);
+ parallel_task_assignment.GetTargetParallelTaskCount(instruction);
if (target_parallel_task_count > 1) {
hlo_to_parallel_tasks->insert(
{instruction, target_parallel_task_count});
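The hunk above is the core of the fix: ParallelTaskAssignment (and the HloCostAnalysis inside it) is now constructed inside ComputeTargetParallelTasks from the very module being traversed, so every HloInstruction* used for a lookup was also a key at analysis time. Reusing the illustrative Instruction and CostCache types from the earlier sketch, the corrected shape would look like:

    // Corrected pattern: analyze and query the same module instance.
    void ComputeCosts(const std::vector<std::unique_ptr<Instruction>>& module) {
      CostCache cache;
      cache.Analyze(module);  // keys are pointers from this exact module
      for (const auto& instr : module) {
        // Every lookup hits: instr.get() was inserted by Analyze above.
        std::cout << cache.FlopCount(instr.get()) << "\n";
      }
    }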