diff options
Diffstat (limited to 'configure.py')
-rw-r--r-- | configure.py | 82 |
1 files changed, 48 insertions, 34 deletions
diff --git a/configure.py b/configure.py index ada342a50a..31a83b4a15 100644 --- a/configure.py +++ b/configure.py @@ -943,6 +943,35 @@ def set_tf_cudnn_version(environ_cp): write_action_env_to_bazelrc('TF_CUDNN_VERSION', tf_cudnn_version) +def is_cuda_compatible(lib, cuda_ver, cudnn_ver): + """Check compatibility between given library and cudnn/cudart libraries.""" + ldd_bin = which('ldd') or '/usr/bin/ldd' + ldd_out = run_shell([ldd_bin, lib], True) + ldd_out = ldd_out.split(os.linesep) + cudnn_pattern = re.compile('.*libcudnn.so\\.?(.*) =>.*$') + cuda_pattern = re.compile('.*libcudart.so\\.?(.*) =>.*$') + cudnn = None + cudart = None + cudnn_ok = True # assume no cudnn dependency by default + cuda_ok = True # assume no cuda dependency by default + for line in ldd_out: + if 'libcudnn.so' in line: + cudnn = cudnn_pattern.search(line) + cudnn_ok = False + elif 'libcudart.so' in line: + cudart = cuda_pattern.search(line) + cuda_ok = False + if cudnn and len(cudnn.group(1)): + cudnn = convert_version_to_int(cudnn.group(1)) + if cudart and len(cudart.group(1)): + cudart = convert_version_to_int(cudart.group(1)) + if cudnn is not None: + cudnn_ok = (cudnn == cudnn_ver) + if cudart is not None: + cuda_ok = (cudart == cuda_ver) + return cudnn_ok and cuda_ok + + def set_tf_tensorrt_install_path(environ_cp): """Set TENSORRT_INSTALL_PATH and TF_TENSORRT_VERSION. @@ -959,8 +988,8 @@ def set_tf_tensorrt_install_path(environ_cp): raise ValueError('Currently TensorRT is only supported on Linux platform.') # Ask user whether to add TensorRT support. - if str(int(get_var( - environ_cp, 'TF_NEED_TENSORRT', 'TensorRT', False))) != '1': + if str(int(get_var(environ_cp, 'TF_NEED_TENSORRT', 'TensorRT', + False))) != '1': return for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS): @@ -973,47 +1002,29 @@ def set_tf_tensorrt_install_path(environ_cp): # Result returned from "read" will be used unexpanded. That make "~" # unusable. Going through one more level of expansion to handle that. - trt_install_path = os.path.realpath( - os.path.expanduser(trt_install_path)) + trt_install_path = os.path.realpath(os.path.expanduser(trt_install_path)) def find_libs(search_path): """Search for libnvinfer.so in "search_path".""" fl = set() if os.path.exists(search_path) and os.path.isdir(search_path): - fl.update([os.path.realpath(os.path.join(search_path, x)) - for x in os.listdir(search_path) if 'libnvinfer.so' in x]) + fl.update([ + os.path.realpath(os.path.join(search_path, x)) + for x in os.listdir(search_path) + if 'libnvinfer.so' in x + ]) return fl possible_files = find_libs(trt_install_path) possible_files.update(find_libs(os.path.join(trt_install_path, 'lib'))) possible_files.update(find_libs(os.path.join(trt_install_path, 'lib64'))) - - def is_compatible(tensorrt_lib, cuda_ver, cudnn_ver): - """Check the compatibility between tensorrt and cudnn/cudart libraries.""" - ldd_bin = which('ldd') or '/usr/bin/ldd' - ldd_out = run_shell([ldd_bin, tensorrt_lib]).split(os.linesep) - cudnn_pattern = re.compile('.*libcudnn.so\\.?(.*) =>.*$') - cuda_pattern = re.compile('.*libcudart.so\\.?(.*) =>.*$') - cudnn = None - cudart = None - for line in ldd_out: - if 'libcudnn.so' in line: - cudnn = cudnn_pattern.search(line) - elif 'libcudart.so' in line: - cudart = cuda_pattern.search(line) - if cudnn and len(cudnn.group(1)): - cudnn = convert_version_to_int(cudnn.group(1)) - if cudart and len(cudart.group(1)): - cudart = convert_version_to_int(cudart.group(1)) - return (cudnn == cudnn_ver) and (cudart == cuda_ver) - cuda_ver = convert_version_to_int(environ_cp['TF_CUDA_VERSION']) cudnn_ver = convert_version_to_int(environ_cp['TF_CUDNN_VERSION']) nvinfer_pattern = re.compile('.*libnvinfer.so.?(.*)$') highest_ver = [0, None, None] for lib_file in possible_files: - if is_compatible(lib_file, cuda_ver, cudnn_ver): + if is_cuda_compatible(lib_file, cuda_ver, cudnn_ver): matches = nvinfer_pattern.search(lib_file) if len(matches.groups()) == 0: continue @@ -1029,12 +1040,13 @@ def set_tf_tensorrt_install_path(environ_cp): # Try another alternative from ldconfig. ldconfig_bin = which('ldconfig') or '/sbin/ldconfig' ldconfig_output = run_shell([ldconfig_bin, '-p']) - search_result = re.search( - '.*libnvinfer.so\\.?([0-9.]*).* => (.*)', ldconfig_output) + search_result = re.search('.*libnvinfer.so\\.?([0-9.]*).* => (.*)', + ldconfig_output) if search_result: libnvinfer_path_from_ldconfig = search_result.group(2) if os.path.exists(libnvinfer_path_from_ldconfig): - if is_compatible(libnvinfer_path_from_ldconfig, cuda_ver, cudnn_ver): + if is_cuda_compatible(libnvinfer_path_from_ldconfig, cuda_ver, + cudnn_ver): trt_install_path = os.path.dirname(libnvinfer_path_from_ldconfig) tf_tensorrt_version = search_result.group(1) break @@ -1122,7 +1134,9 @@ def set_tf_nccl_install_path(environ_cp): nccl_lib_path = os.path.join(nccl_install_path, nccl_lib_path) nccl_hdr_path = os.path.join(nccl_install_path, 'include/nccl.h') - if os.path.exists(nccl_lib_path) and os.path.exists(nccl_hdr_path): + nccl_license_path = os.path.join(nccl_install_path, 'NCCL-SLA.txt') + if os.path.exists(nccl_lib_path) and os.path.exists( + nccl_hdr_path) and os.path.exists(nccl_license_path): # Set NCCL_INSTALL_PATH environ_cp['NCCL_INSTALL_PATH'] = nccl_install_path write_action_env_to_bazelrc('NCCL_INSTALL_PATH', nccl_install_path) @@ -1435,7 +1449,7 @@ def main(): setup_python(environ_cp) if is_windows(): - environ_cp['TF_NEED_S3'] = '0' + environ_cp['TF_NEED_AWS'] = '0' environ_cp['TF_NEED_GCP'] = '0' environ_cp['TF_NEED_HDFS'] = '0' environ_cp['TF_NEED_JEMALLOC'] = '0' @@ -1459,8 +1473,8 @@ def main(): 'with_gcp_support', True, 'gcp') set_build_var(environ_cp, 'TF_NEED_HDFS', 'Hadoop File System', 'with_hdfs_support', True, 'hdfs') - set_build_var(environ_cp, 'TF_NEED_S3', 'Amazon S3 File System', - 'with_s3_support', True, 's3') + set_build_var(environ_cp, 'TF_NEED_AWS', 'Amazon AWS Platform', + 'with_aws_support', True, 'aws') set_build_var(environ_cp, 'TF_NEED_KAFKA', 'Apache Kafka Platform', 'with_kafka_support', True, 'kafka') set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support', |