aboutsummaryrefslogtreecommitdiffhomepage
path: root/configure.py
diff options
context:
space:
mode:
authorGravatar Ankur Taly <ataly@google.com>2018-02-16 18:22:55 -0800
committerGravatar TensorFlower Gardener <gardener@tensorflow.org>2018-02-16 18:27:19 -0800
commit0e6f39d1bd7fe8daa86944f6ab0dd94fbeb4962a (patch)
treeee0dabaff4147ecc9bc92acd2a50dadbfd694f39 /configure.py
parent128572c316e6f2eb6346f920314ef98e88e75069 (diff)
Merge changes from github.
PiperOrigin-RevId: 186073337
Diffstat (limited to 'configure.py')
-rw-r--r--configure.py28
1 files changed, 28 insertions, 0 deletions
diff --git a/configure.py b/configure.py
index 27519b4aba..6b1fa7f1a8 100644
--- a/configure.py
+++ b/configure.py
@@ -827,6 +827,28 @@ def set_gcc_host_compiler_path(environ_cp):
write_action_env_to_bazelrc('GCC_HOST_COMPILER_PATH', gcc_host_compiler_path)
+def reformat_version_sequence(version_str, sequence_count):
+ """Reformat the version string to have the given number of sequences.
+
+ For example:
+ Given (7, 2) -> 7.0
+ (7.0.1, 2) -> 7.0
+ (5, 1) -> 5
+ (5.0.3.2, 1) -> 5
+
+ Args:
+ version_str: String, the version string.
+ sequence_count: int, an integer.
+ Returns:
+ string, reformatted version string.
+ """
+ v = version_str.split('.')
+ if len(v) < sequence_count:
+ v = v + (['0'] * (sequence_count - len(v)))
+
+ return '.'.join(v[:sequence_count])
+
+
def set_tf_cuda_version(environ_cp):
"""Set CUDA_TOOLKIT_PATH and TF_CUDA_VERSION."""
ask_cuda_version = (
@@ -837,6 +859,7 @@ def set_tf_cuda_version(environ_cp):
# Configure the Cuda SDK version to use.
tf_cuda_version = get_from_env_or_user_or_default(
environ_cp, 'TF_CUDA_VERSION', ask_cuda_version, _DEFAULT_CUDA_VERSION)
+ tf_cuda_version = reformat_version_sequence(str(tf_cuda_version), 2)
# Find out where the CUDA toolkit is installed
default_cuda_path = _DEFAULT_CUDA_PATH
@@ -893,6 +916,7 @@ def set_tf_cudnn_version(environ_cp):
tf_cudnn_version = get_from_env_or_user_or_default(
environ_cp, 'TF_CUDNN_VERSION', ask_cudnn_version,
_DEFAULT_CUDNN_VERSION)
+ tf_cudnn_version = reformat_version_sequence(str(tf_cudnn_version), 1)
default_cudnn_path = environ_cp.get('CUDA_TOOLKIT_PATH')
ask_cudnn_path = (r'Please specify the location where cuDNN %s library is '
@@ -1400,6 +1424,10 @@ def main():
if is_linux():
set_tf_tensorrt_install_path(environ_cp)
set_tf_cuda_compute_capabilities(environ_cp)
+ if 'LD_LIBRARY_PATH' in environ_cp and environ_cp.get(
+ 'LD_LIBRARY_PATH') != '1':
+ write_action_env_to_bazelrc('LD_LIBRARY_PATH',
+ environ_cp.get('LD_LIBRARY_PATH'))
set_tf_cuda_clang(environ_cp)
if environ_cp.get('TF_CUDA_CLANG') == '1':