aboutsummaryrefslogtreecommitdiffhomepage
path: root/configure.py
diff options
context:
space:
mode:
Diffstat (limited to 'configure.py')
-rw-r--r--configure.py82
1 files changed, 48 insertions, 34 deletions
diff --git a/configure.py b/configure.py
index ada342a50a..31a83b4a15 100644
--- a/configure.py
+++ b/configure.py
@@ -943,6 +943,35 @@ def set_tf_cudnn_version(environ_cp):
write_action_env_to_bazelrc('TF_CUDNN_VERSION', tf_cudnn_version)
+def is_cuda_compatible(lib, cuda_ver, cudnn_ver):
+ """Check compatibility between given library and cudnn/cudart libraries."""
+ ldd_bin = which('ldd') or '/usr/bin/ldd'
+ ldd_out = run_shell([ldd_bin, lib], True)
+ ldd_out = ldd_out.split(os.linesep)
+ cudnn_pattern = re.compile('.*libcudnn.so\\.?(.*) =>.*$')
+ cuda_pattern = re.compile('.*libcudart.so\\.?(.*) =>.*$')
+ cudnn = None
+ cudart = None
+ cudnn_ok = True # assume no cudnn dependency by default
+ cuda_ok = True # assume no cuda dependency by default
+ for line in ldd_out:
+ if 'libcudnn.so' in line:
+ cudnn = cudnn_pattern.search(line)
+ cudnn_ok = False
+ elif 'libcudart.so' in line:
+ cudart = cuda_pattern.search(line)
+ cuda_ok = False
+ if cudnn and len(cudnn.group(1)):
+ cudnn = convert_version_to_int(cudnn.group(1))
+ if cudart and len(cudart.group(1)):
+ cudart = convert_version_to_int(cudart.group(1))
+ if cudnn is not None:
+ cudnn_ok = (cudnn == cudnn_ver)
+ if cudart is not None:
+ cuda_ok = (cudart == cuda_ver)
+ return cudnn_ok and cuda_ok
+
+
def set_tf_tensorrt_install_path(environ_cp):
"""Set TENSORRT_INSTALL_PATH and TF_TENSORRT_VERSION.
@@ -959,8 +988,8 @@ def set_tf_tensorrt_install_path(environ_cp):
raise ValueError('Currently TensorRT is only supported on Linux platform.')
# Ask user whether to add TensorRT support.
- if str(int(get_var(
- environ_cp, 'TF_NEED_TENSORRT', 'TensorRT', False))) != '1':
+ if str(int(get_var(environ_cp, 'TF_NEED_TENSORRT', 'TensorRT',
+ False))) != '1':
return
for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS):
@@ -973,47 +1002,29 @@ def set_tf_tensorrt_install_path(environ_cp):
# Result returned from "read" will be used unexpanded. That make "~"
# unusable. Going through one more level of expansion to handle that.
- trt_install_path = os.path.realpath(
- os.path.expanduser(trt_install_path))
+ trt_install_path = os.path.realpath(os.path.expanduser(trt_install_path))
def find_libs(search_path):
"""Search for libnvinfer.so in "search_path"."""
fl = set()
if os.path.exists(search_path) and os.path.isdir(search_path):
- fl.update([os.path.realpath(os.path.join(search_path, x))
- for x in os.listdir(search_path) if 'libnvinfer.so' in x])
+ fl.update([
+ os.path.realpath(os.path.join(search_path, x))
+ for x in os.listdir(search_path)
+ if 'libnvinfer.so' in x
+ ])
return fl
possible_files = find_libs(trt_install_path)
possible_files.update(find_libs(os.path.join(trt_install_path, 'lib')))
possible_files.update(find_libs(os.path.join(trt_install_path, 'lib64')))
-
- def is_compatible(tensorrt_lib, cuda_ver, cudnn_ver):
- """Check the compatibility between tensorrt and cudnn/cudart libraries."""
- ldd_bin = which('ldd') or '/usr/bin/ldd'
- ldd_out = run_shell([ldd_bin, tensorrt_lib]).split(os.linesep)
- cudnn_pattern = re.compile('.*libcudnn.so\\.?(.*) =>.*$')
- cuda_pattern = re.compile('.*libcudart.so\\.?(.*) =>.*$')
- cudnn = None
- cudart = None
- for line in ldd_out:
- if 'libcudnn.so' in line:
- cudnn = cudnn_pattern.search(line)
- elif 'libcudart.so' in line:
- cudart = cuda_pattern.search(line)
- if cudnn and len(cudnn.group(1)):
- cudnn = convert_version_to_int(cudnn.group(1))
- if cudart and len(cudart.group(1)):
- cudart = convert_version_to_int(cudart.group(1))
- return (cudnn == cudnn_ver) and (cudart == cuda_ver)
-
cuda_ver = convert_version_to_int(environ_cp['TF_CUDA_VERSION'])
cudnn_ver = convert_version_to_int(environ_cp['TF_CUDNN_VERSION'])
nvinfer_pattern = re.compile('.*libnvinfer.so.?(.*)$')
highest_ver = [0, None, None]
for lib_file in possible_files:
- if is_compatible(lib_file, cuda_ver, cudnn_ver):
+ if is_cuda_compatible(lib_file, cuda_ver, cudnn_ver):
matches = nvinfer_pattern.search(lib_file)
if len(matches.groups()) == 0:
continue
@@ -1029,12 +1040,13 @@ def set_tf_tensorrt_install_path(environ_cp):
# Try another alternative from ldconfig.
ldconfig_bin = which('ldconfig') or '/sbin/ldconfig'
ldconfig_output = run_shell([ldconfig_bin, '-p'])
- search_result = re.search(
- '.*libnvinfer.so\\.?([0-9.]*).* => (.*)', ldconfig_output)
+ search_result = re.search('.*libnvinfer.so\\.?([0-9.]*).* => (.*)',
+ ldconfig_output)
if search_result:
libnvinfer_path_from_ldconfig = search_result.group(2)
if os.path.exists(libnvinfer_path_from_ldconfig):
- if is_compatible(libnvinfer_path_from_ldconfig, cuda_ver, cudnn_ver):
+ if is_cuda_compatible(libnvinfer_path_from_ldconfig, cuda_ver,
+ cudnn_ver):
trt_install_path = os.path.dirname(libnvinfer_path_from_ldconfig)
tf_tensorrt_version = search_result.group(1)
break
@@ -1122,7 +1134,9 @@ def set_tf_nccl_install_path(environ_cp):
nccl_lib_path = os.path.join(nccl_install_path, nccl_lib_path)
nccl_hdr_path = os.path.join(nccl_install_path, 'include/nccl.h')
- if os.path.exists(nccl_lib_path) and os.path.exists(nccl_hdr_path):
+ nccl_license_path = os.path.join(nccl_install_path, 'NCCL-SLA.txt')
+ if os.path.exists(nccl_lib_path) and os.path.exists(
+ nccl_hdr_path) and os.path.exists(nccl_license_path):
# Set NCCL_INSTALL_PATH
environ_cp['NCCL_INSTALL_PATH'] = nccl_install_path
write_action_env_to_bazelrc('NCCL_INSTALL_PATH', nccl_install_path)
@@ -1435,7 +1449,7 @@ def main():
setup_python(environ_cp)
if is_windows():
- environ_cp['TF_NEED_S3'] = '0'
+ environ_cp['TF_NEED_AWS'] = '0'
environ_cp['TF_NEED_GCP'] = '0'
environ_cp['TF_NEED_HDFS'] = '0'
environ_cp['TF_NEED_JEMALLOC'] = '0'
@@ -1459,8 +1473,8 @@ def main():
'with_gcp_support', True, 'gcp')
set_build_var(environ_cp, 'TF_NEED_HDFS', 'Hadoop File System',
'with_hdfs_support', True, 'hdfs')
- set_build_var(environ_cp, 'TF_NEED_S3', 'Amazon S3 File System',
- 'with_s3_support', True, 's3')
+ set_build_var(environ_cp, 'TF_NEED_AWS', 'Amazon AWS Platform',
+ 'with_aws_support', True, 'aws')
set_build_var(environ_cp, 'TF_NEED_KAFKA', 'Apache Kafka Platform',
'with_kafka_support', True, 'kafka')
set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support',