aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/compiler/xla
diff options
context:
space:
mode:
authorGravatar Justin Lebar <jlebar@google.com>2018-10-04 18:46:53 -0700
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-10-04 18:50:44 -0700
commit83ff640fa5026b8bd3cb9c2ceff9e99e8e03823a (patch)
tree8b9289829b524364446c216e49628e1c96a56a0b /tensorflow/compiler/xla
parent4a00f2fc6514ad5ee60ab0a9645863fdf263499f (diff)
[XLA:GPU] Fix old-ptxas-version detection logic.
This was completely broken for CUDA versions > 9 and resulted in spurious warnings. Reported in #22706#issuecomment-426861394 -- thank you! PiperOrigin-RevId: 215841354
Diffstat (limited to 'tensorflow/compiler/xla')
-rw-r--r--tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc2
1 files changed, 1 insertions, 1 deletions
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
index b4ae2e42c7..50e47542c4 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc
@@ -401,7 +401,7 @@ void WarnIfBadPtxasVersion(const string& ptxas_path) {
"prefers >= 9.2.88). Compilation of XLA kernels below will likely "
"fail.\n\nYou do not need to update CUDA; cherry-picking the ptxas "
"binary is sufficient.";
- } else if ((vmaj < 9 || vmin < 2 || vdot < 88)) {
+ } else if (std::make_tuple(vmaj, vmin, vdot) < std::make_tuple(9, 2, 88)) {
LOG(WARNING)
<< "*** WARNING *** You are using ptxas " << vmaj << "." << vmin << "."
<< vdot