diff options
Diffstat (limited to 'tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc')
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc | 4 |
1 files changed, 3 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc index dfdcf1875d..0b3b429710 100644 --- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc @@ -40,6 +40,7 @@ limitations under the License. #include "tensorflow/compiler/xla/service/gpu/cudnn_batchnorm_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_algorithm_picker.h" #include "tensorflow/compiler/xla/service/gpu/cudnn_convolution_rewriter.h" +#include "tensorflow/compiler/xla/service/gpu/cudnn_fused_convolution_rewriter.h" #include "tensorflow/compiler/xla/service/gpu/fusion_merger.h" #include "tensorflow/compiler/xla/service/gpu/gpu_constants.h" #include "tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.h" @@ -208,6 +209,7 @@ Status OptimizeHloModule(HloModule* hlo_module, se::StreamExecutor* stream_exec, pipeline.AddInvariantChecker<HloVerifier>(/*layout_sensitive=*/false, /*allow_mixed_precision=*/false); pipeline.AddPass<CudnnConvolutionRewriter>(); + pipeline.AddPass<CudnnFusedConvolutionRewriter>(); pipeline.AddPass<PadInsertion>(); if (IsVoltaOrLater(*stream_exec)) { pipeline.AddPass<PadForTensorCores>(); @@ -402,7 +404,7 @@ void WarnIfBadPtxasVersion(const string& ptxas_path) { LOG(WARNING) << "*** WARNING *** You are using ptxas " << vmaj << "." << vmin << "." << vdot - << ", which older than 9.2.88. ptxas 9.x before 9.2.88 is known to " + << ", which is older than 9.2.88. ptxas 9.x before 9.2.88 is known to " "miscompile XLA code, leading to incorrect results or " "invalid-address errors.\n\nYou do not need to update to CUDA " "9.2.88; cherry-picking the ptxas binary is sufficient."; |