diff options
author | 2018-01-08 16:19:30 -0800 | |
---|---|---|
committer | 2018-01-08 16:23:23 -0800 | |
commit | fffcacb1d0537a353c4e9303f98d1e9b2e83bf23 (patch) | |
tree | da0bb8901c10861e5de3801be0a45f9555fbaaae | |
parent | 5976ab9b91ee6e236335ba4a322f5a514b29da7f (diff) |
[XLA:GPU] Warn if ptxas or the driver JIT has known bugs.
We try to compile ptx->sass using ptxas, and fall back to the driver JIT
only if this fails. So we only warn about a bad driver JIT version if
we actually use the driver to compile our ptx.
This change also quiets a LOG(INFO) to a VLOG(1) in cuda_diagnostics.cc
so we don't spit out the full contents of the driver file when we query
it.
PiperOrigin-RevId: 181235275
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/BUILD | 1 | ||||
-rw-r--r-- | tensorflow/compiler/xla/service/gpu/gpu_compiler.cc | 96 | ||||
-rw-r--r-- | tensorflow/stream_executor/cuda/cuda_diagnostics.cc | 4 |
3 files changed, 99 insertions, 2 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD index e4832b2ee6..a86d3583a6 100644 --- a/tensorflow/compiler/xla/service/gpu/BUILD +++ b/tensorflow/compiler/xla/service/gpu/BUILD @@ -481,6 +481,7 @@ cc_library( "//tensorflow/core:cuda_libdevice_path", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", + "//tensorflow/core:regexp_internal", "//tensorflow/core:stream_executor_no_cuda", "@llvm//:core", "@llvm//:support", diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 234f06fe2d..9f34866ff5 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -18,6 +18,7 @@ limitations under the License. #include <stdlib.h> #include <atomic> #include <functional> +#include <mutex> // NOLINT(build/c++11): only using std::call_once, not mutex. #include <utility> #include "llvm/IR/DiagnosticInfo.h" @@ -77,9 +78,11 @@ limitations under the License. #include "tensorflow/core/platform/cuda_libdevice_path.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/regexp.h" #include "tensorflow/core/platform/stream_executor_no_cuda.h" #include "tensorflow/core/platform/subprocess.h" #include "tensorflow/core/platform/tracing.h" +#include "tensorflow/stream_executor/cuda/cuda_diagnostics.h" namespace se = ::perftools::gputools; @@ -241,6 +244,93 @@ tensorflow::Status PrepareHloModuleForIrEmitting( return pipeline.Run(hlo_module).status(); } +// Prints a warning if the ptxas at ptxas_path has known bugs. +// +// Only prints a warning the first time it's called for a particular value of +// ptxas_path. 
+void WarnIfBadPtxasVersion(const string& ptxas_path) { + static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED); + static std::unordered_set<string>* seen_ptxas_paths GUARDED_BY(mu) = + new std::unordered_set<string>(); + + tensorflow::mutex_lock lock(mu); + if (!seen_ptxas_paths->insert(ptxas_path).second) { + // Already checked this ptx binary, nothing to do. + return; + } + + tensorflow::SubProcess ptxas; + ptxas.SetProgram(ptxas_path, {ptxas_path, "--version"}); + ptxas.SetChannelAction(tensorflow::CHAN_STDOUT, tensorflow::ACTION_PIPE); + if (!ptxas.Start()) { + LOG(WARNING) << "Couldn't invoke " << ptxas_path << " --version"; + return; + } + + string out; + int exit_code = ptxas.Communicate(/*stdin_input=*/nullptr, &out, + /*stderr_output=*/nullptr); + if (exit_code != 0) { + LOG(WARNING) << "Running " << ptxas_path << " --version returned " + << exit_code; + return; + } + + int64 vmaj, vmin, vdot; + string vmaj_str, vmin_str, vdot_str; + using tensorflow::strings::safe_strto64; + if (!RE2::PartialMatch(out, R"(\bV(\d+)\.(\d+)\.(\d+)\b)", &vmaj_str, + &vmin_str, &vdot_str) || + !safe_strto64(vmaj_str, &vmaj) || !safe_strto64(vmin_str, &vmin) || + !safe_strto64(vdot_str, &vdot)) { + LOG(WARNING) << "Couldn't parse ptxas version in output of " << ptxas_path + << " --version:\n" + << out; + return; + } + + // ptxas 9.0 before 9.0.276 miscompiles some address calculations with large + // offsets (e.g. "load ptr + large_constant"), b/70245379. + if (vmaj == 9 && vmin == 0 && vdot < 276) { + LOG(WARNING) << "*** WARNING *** You are using ptxas " << vmaj << "." + << vmin << "." << vdot + << ", which is in range [9.0.0, 9.0.276). These versions are " + "known to miscompile XLA code, leading to incorrect " + "results or invalid-address errors."; + } +} + +// Prints a warning if the ptx->sass JIT in the driver has known bugs. 
+// +// Using such a driver is only a problem if we fail to use ptxas to compile our ptx +// and have to use the driver instead, so you should only call this function if +// we're going to use the driver JIT. +// +// Only prints a warning the first time it's called. +void WarnIfBadDriverJITVersion() { + static std::once_flag run_once; + std::call_once(run_once, [] { + auto version_or_status = se::cuda::Diagnostician::FindKernelDriverVersion(); + if (!version_or_status.ok()) { + LOG(WARNING) << "Couldn't read CUDA driver version."; + return; + } + se::cuda::DriverVersion version = version_or_status.ValueOrDie(); + + // The driver JIT in 384 before 384.108 miscompiles some address + // calculations with large offsets (e.g. "load ptr + large_constant"), + // b/70245379. + if (std::get<0>(version) == 384 && std::get<1>(version) < 108) { + LOG(WARNING) + << "*** WARNING *** Invoking the PTX->SASS JIT from driver version " + << se::cuda::DriverVersionToString(version) + << ", which is in range [384.0.0, 384.108.0). These versions are " + "known to miscompile XLA code, leading to incorrect results or " + "invalid-address errors."; + } + }); +} + // Compiles the given PTX string using ptxas and returns the resulting machine // code (i.e. a cubin) as a byte array. StatusOr<std::vector<uint8>> CompilePtx(const string& ptx, int cc_major, @@ -252,6 +342,8 @@ StatusOr<std::vector<uint8>> CompilePtx(const string& ptx, int cc_major, auto env = tensorflow::Env::Default(); TF_RETURN_IF_ERROR(env->FileExists(ptxas_path)); + WarnIfBadPtxasVersion(ptxas_path); + // Write ptx into a temporary file. string ptx_path; if (!env->LocalTempFilename(&ptx_path)) { @@ -555,6 +647,10 @@ std::vector<uint8> GpuCompiler::CompilePtxOrGetCachedResult(const string& ptx, "GPU driver compile the ptx. " << maybe_cubin.status(); } + + // We're going to use the driver to JIT our PTX->SASS, so warn if + // the JIT in the driver has known bugs. 
+ WarnIfBadDriverJITVersion(); } } cache_value->compilation_done = true; diff --git a/tensorflow/stream_executor/cuda/cuda_diagnostics.cc b/tensorflow/stream_executor/cuda/cuda_diagnostics.cc index 00506fa54b..d5a3ee32e9 100644 --- a/tensorflow/stream_executor/cuda/cuda_diagnostics.cc +++ b/tensorflow/stream_executor/cuda/cuda_diagnostics.cc @@ -366,8 +366,8 @@ port::StatusOr<DriverVersion> Diagnostician::FindKernelDriverVersion() { contents[kContentsSize - 1] = '\0'; if (retcode != 0) { - LOG(INFO) << "driver version file contents: \"\"\"" << contents.begin() - << "\"\"\""; + VLOG(1) << "driver version file contents: \"\"\"" << contents.begin() + << "\"\"\""; fclose(driver_version_file); return FindKernelModuleVersion(contents.begin()); } |