aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc41
1 files changed, 24 insertions, 17 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
index 09ef62c87f..8786bb6262 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_layout_assignment.cc
@@ -52,31 +52,38 @@ HeuristicLayoutAssignment(const HloInstruction* instr,
// W <=> X
//
// Therefore kOutputInputYX and kBatchDepthYX mean NCHW.
+ //
+ // If you have trouble keeping these straight, consider that all that matters
+ // is the location of the channel dim: Is it major (NCHW), or minor (NHWC)?
+
+ constexpr auto kAllNCHW =
+ std::make_tuple(DataLayout::kBatchDepthYX, FilterLayout::kOutputInputYX,
+ DataLayout::kBatchDepthYX);
+ constexpr auto kAllNHWC =
+ std::make_tuple(DataLayout::kBatchYXDepth, FilterLayout::kOutputYXInput,
+ DataLayout::kBatchYXDepth);
- // As of today, our empirical evidence is that cudnn 7.0 is faster on V100 x
- // fp16 with the mostly-NHWC layout. The heuristic may change as cudnn version
- // changes, as well as the hardware updates.
+ // If we're not Volta or not fp16, the decision is easy: Use NCHW.
if (!(instr->operand(0)->shape().element_type() == xla::PrimitiveType::F16 &&
IsVoltaOrLater(*stream_executor))) {
- return std::make_tuple(DataLayout::kBatchDepthYX,
- FilterLayout::kOutputInputYX,
- DataLayout::kBatchDepthYX);
+ return kAllNCHW;
}
+
VLOG(2) << "Using heuristic to figure out layouts for " << instr->ToString();
- // For BackwardInput that has stride, full NHWC layouts run significantly
- // slower than (NHWC, NCHW, NCHW) or (NHWC, NCHW, NHWC).
- //
- // TODO(timshen): more closely compare (NHWC, NCHW, NCHW) and (NHWC, NCHW,
- // NHWC).
+
+ // Empirically we've found with Volta and cudnn 7 that backward-input convs
+ // with stride are significantly faster with input in NHWC and filter/output
+ // in NCHW.
if (instr->custom_call_target() == kCudnnConvBackwardInputCallTarget &&
window_util::HasStride(instr->window())) {
- return std::make_tuple(DataLayout::kBatchYXDepth,
- FilterLayout::kOutputInputYX,
- DataLayout::kBatchDepthYX);
+ return std::make_tuple(DataLayout::kBatchYXDepth, // NHWC
+ FilterLayout::kOutputInputYX, // NCHW
+ DataLayout::kBatchDepthYX // NCHW
+ );
}
- return std::make_tuple(DataLayout::kBatchYXDepth,
- FilterLayout::kOutputYXInput,
- DataLayout::kBatchYXDepth);
+
+ // For other Volta f16 convolutions, use NHWC.
+ return kAllNHWC;
}
// Adds layout constraints on the cudnn custom-call instruction. The layout