diff options
author | Ankur Taly <ataly@google.com> | 2018-02-16 18:22:55 -0800 |
---|---|---|
committer | TensorFlower Gardener <gardener@tensorflow.org> | 2018-02-16 18:27:19 -0800 |
commit | 0e6f39d1bd7fe8daa86944f6ab0dd94fbeb4962a (patch) | |
tree | ee0dabaff4147ecc9bc92acd2a50dadbfd694f39 /configure.py | |
parent | 128572c316e6f2eb6346f920314ef98e88e75069 (diff) |
Merge changes from github.
PiperOrigin-RevId: 186073337
Diffstat (limited to 'configure.py')
-rw-r--r-- | configure.py | 28 |
1 files changed, 28 insertions, 0 deletions
diff --git a/configure.py b/configure.py index 27519b4aba..6b1fa7f1a8 100644 --- a/configure.py +++ b/configure.py @@ -827,6 +827,28 @@ def set_gcc_host_compiler_path(environ_cp): write_action_env_to_bazelrc('GCC_HOST_COMPILER_PATH', gcc_host_compiler_path) +def reformat_version_sequence(version_str, sequence_count): + """Reformat the version string to have the given number of sequences. + + For example: + Given (7, 2) -> 7.0 + (7.0.1, 2) -> 7.0 + (5, 1) -> 5 + (5.0.3.2, 1) -> 5 + + Args: + version_str: String, the version string. + sequence_count: int, an integer. + Returns: + string, reformatted version string. + """ + v = version_str.split('.') + if len(v) < sequence_count: + v = v + (['0'] * (sequence_count - len(v))) + + return '.'.join(v[:sequence_count]) + + def set_tf_cuda_version(environ_cp): """Set CUDA_TOOLKIT_PATH and TF_CUDA_VERSION.""" ask_cuda_version = ( @@ -837,6 +859,7 @@ def set_tf_cuda_version(environ_cp): # Configure the Cuda SDK version to use. tf_cuda_version = get_from_env_or_user_or_default( environ_cp, 'TF_CUDA_VERSION', ask_cuda_version, _DEFAULT_CUDA_VERSION) + tf_cuda_version = reformat_version_sequence(str(tf_cuda_version), 2) # Find out where the CUDA toolkit is installed default_cuda_path = _DEFAULT_CUDA_PATH @@ -893,6 +916,7 @@ def set_tf_cudnn_version(environ_cp): tf_cudnn_version = get_from_env_or_user_or_default( environ_cp, 'TF_CUDNN_VERSION', ask_cudnn_version, _DEFAULT_CUDNN_VERSION) + tf_cudnn_version = reformat_version_sequence(str(tf_cudnn_version), 1) default_cudnn_path = environ_cp.get('CUDA_TOOLKIT_PATH') ask_cudnn_path = (r'Please specify the location where cuDNN %s library is ' @@ -1400,6 +1424,10 @@ def main(): if is_linux(): set_tf_tensorrt_install_path(environ_cp) set_tf_cuda_compute_capabilities(environ_cp) + if 'LD_LIBRARY_PATH' in environ_cp and environ_cp.get( + 'LD_LIBRARY_PATH') != '1': + write_action_env_to_bazelrc('LD_LIBRARY_PATH', + environ_cp.get('LD_LIBRARY_PATH')) set_tf_cuda_clang(environ_cp) if environ_cp.get('TF_CUDA_CLANG') == '1': |