author     Sanjoy Das <sanjoy@google.com>                    2018-05-09 19:39:58 -0700
committer  TensorFlower Gardener <gardener@tensorflow.org>   2018-05-09 19:42:49 -0700
commit  bb8315f0cf066266647c6eacdf575ac8f5e9989e (patch)
tree    3701a5004258519f0baa4420416008be22dc0114 /tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
parent  f79dbc73c5b2c0debb916280e4436d98890ed03b (diff)
Don't call into Eigen unless the input and output tensors are aligned
We teach TargetMachineFeatures about the alignment required for Eigen GEMM and Conv, and then pipe TargetMachineFeatures through the places that need to decide whether a dot or a conv needs to be lowered to a call to Eigen.

I also had to fix a minor bug in our LLVM IR implementation for convolution.

PiperOrigin-RevId: 196065557
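The core idea, as a minimal self-contained C++ sketch (TargetFeatures, minimum_alignment_for_gemm(), and IsAligned() below are hypothetical stand-ins for illustration, not the actual XLA API): take the fast Eigen-style path only when every buffer meets the target's required alignment, and otherwise fall back to a path that makes no alignment assumptions.

#include <cstdint>

// Hypothetical stand-in for XLA's TargetMachineFeatures: reports the
// alignment (in bytes) that the fast GEMM kernel needs on this target.
struct TargetFeatures {
  int64_t minimum_alignment_for_gemm() const { return 16; }
};

// True if `ptr` satisfies `alignment` (assumed to be a power of two).
inline bool IsAligned(const void* ptr, int64_t alignment) {
  return reinterpret_cast<std::uintptr_t>(ptr) % alignment == 0;
}

// Dispatch on alignment: this mirrors the "don't call into Eigen unless
// the input and output tensors are aligned" rule from the commit title.
void Gemm(const float* lhs, const float* rhs, float* out, int64_t m,
          int64_t n, int64_t k, const TargetFeatures& features) {
  const int64_t a = features.minimum_alignment_for_gemm();
  if (IsAligned(lhs, a) && IsAligned(rhs, a) && IsAligned(out, a)) {
    // Fast path: an Eigen GEMM call would go here (elided in this sketch).
  } else {
    // Fallback: naive loop nest that tolerates any alignment.
    for (int64_t i = 0; i < m; ++i) {
      for (int64_t j = 0; j < n; ++j) {
        float acc = 0.0f;
        for (int64_t p = 0; p < k; ++p) acc += lhs[i * k + p] * rhs[p * n + j];
        out[i * n + j] = acc;
      }
    }
  }
}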
Diffstat (limited to 'tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc')
-rw-r--r--  tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc | 13
1 file changed, 9 insertions(+), 4 deletions(-)
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
index 47e8405ff2..63d0f7b95c 100644
--- a/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
+++ b/tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc
@@ -104,7 +104,9 @@ class DefaultCostModel : public ParallelCostModel {
 
 ParallelTaskAssignment::ParallelTaskAssignment(
     const int64 max_parallelism,
-    const HloCostAnalysis::ShapeSizeFunction& shape_size, HloModule* module) {
+    const HloCostAnalysis::ShapeSizeFunction& shape_size, HloModule* module,
+    const TargetMachineFeatures* target_machine_features)
+    : target_machine_features_(*target_machine_features) {
   VLOG(1) << "ParallelTaskAssignment max_parallelism: " << max_parallelism;
   // Run cost analysis on 'module'.
   auto cost_analysis = MakeUnique<HloCostAnalysis>(shape_size);
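One detail of the hunk above worth calling out: the new parameter arrives as a pointer, but the initializer list dereferences it into a const-reference member. Callers therefore must pass a non-null TargetMachineFeatures that outlives the ParallelTaskAssignment. A stripped-down sketch of that shape, with hypothetical names in place of the real classes:

struct Features {};  // hypothetical stand-in for TargetMachineFeatures

class Assignment {   // hypothetical stand-in for ParallelTaskAssignment
 public:
  // `features` is dereferenced immediately, so it must be non-null and
  // must outlive this object; only the reference is retained.
  explicit Assignment(const Features* features) : features_(*features) {}

 private:
  const Features& features_;
};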
@@ -139,8 +141,10 @@ int64 ParallelTaskAssignment::GetTargetParallelTaskCount(
opcode == HloOpcode::kFft || opcode == HloOpcode::kInfeed ||
opcode == HloOpcode::kOutfeed || opcode == HloOpcode::kRng ||
(opcode == HloOpcode::kConvolution &&
- PotentiallyImplementedAsEigenConvolution(*instruction)) ||
- PotentiallyImplementedAsEigenDot(*instruction) ||
+ PotentiallyImplementedAsEigenConvolution(*instruction,
+ target_machine_features_)) ||
+ PotentiallyImplementedAsEigenDot(*instruction,
+ target_machine_features_) ||
(opcode == HloOpcode::kFusion &&
instruction->fusion_kind() != HloInstruction::FusionKind::kLoop) ||
ShapeUtil::IsTuple(instruction->shape())) {
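This hunk is the mechanical consequence of the commit message: whether a dot or convolution can become an Eigen call is now a property of the target, not of the instruction alone, so both predicates gain a features argument threaded in from the member. A rough self-contained illustration of the pattern (Opcode, Instr, and the alignment field are hypothetical simplifications, not XLA's types):

#include <cstdint>

enum class Opcode { kDot, kConvolution, kAdd };

struct Features { int64_t eigen_alignment = 16; };  // hypothetical
struct Instr {                                      // hypothetical
  Opcode opcode;
  int64_t buffer_alignment;  // alignment guaranteed for its buffers
};

// Analogue of PotentiallyImplementedAsEigenDot: eligibility now depends
// on the target's alignment guarantee as well as on the instruction.
bool PotentiallyEigenDot(const Instr& instr, const Features& features) {
  return instr.opcode == Opcode::kDot &&
         instr.buffer_alignment >= features.eigen_alignment;
}

// Analogue of the condition above: ops that will lower to a single Eigen
// call are excluded from task parallelization.
bool SkipParallelization(const Instr& instr, const Features& features) {
  return PotentiallyEigenDot(instr, features);
}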
@@ -231,7 +235,8 @@ bool ParallelTaskAssigner::AssignParallelTasksHelper(
 
 void ParallelTaskAssigner::ComputeTargetParallelTasks(
     HloModule* module, HloToParallelTasks* hlo_to_parallel_tasks) {
   ParallelTaskAssignment parallel_task_assignment(max_parallelism_,
-                                                  shape_size_function_, module);
+                                                  shape_size_function_, module,
+                                                  &target_machine_features_);
   // Compute parallel task counts for all instructions in 'module'.
   for (auto* computation : module->computations()) {
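The final hunk closes the loop: ParallelTaskAssigner passes the address of its own target_machine_features_ member, which satisfies the non-null and lifetime requirements noted earlier, since the member outlives the short-lived local ParallelTaskAssignment. The same ownership shape in a compact sketch (all names hypothetical):

struct Features {};  // hypothetical stand-in for TargetMachineFeatures

class Assignment {   // hypothetical stand-in for ParallelTaskAssignment
 public:
  explicit Assignment(const Features* features) : features_(*features) {}
 private:
  const Features& features_;
};

class Assigner {     // hypothetical stand-in for ParallelTaskAssigner
 public:
  void ComputeTasks() {
    // Safe: features_ is a member of the longer-lived Assigner, so the
    // reference captured by the local Assignment never dangles.
    Assignment assignment(&features_);
    (void)assignment;  // the per-module analysis would use it here
  }
 private:
  Features features_;
};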