aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rwxr-xr-xconfigure3
-rw-r--r--configure.py65
2 files changed, 32 insertions, 36 deletions
diff --git a/configure b/configure
index 9c21d2b03a..66b66ba54e 100755
--- a/configure
+++ b/configure
@@ -8,7 +8,8 @@ if [ -z "$PYTHON_BIN_PATH" ]; then
fi
# Set all env variables
-"$PYTHON_BIN_PATH" configure.py
+CONFIGURE_DIR=$(dirname "$0")
+"$PYTHON_BIN_PATH" "${CONFIGURE_DIR}/configure.py" "$@"
echo "Configuration finished"
diff --git a/configure.py b/configure.py
index 60f144f315..f77a048d86 100644
--- a/configure.py
+++ b/configure.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import argparse
import errno
import os
import platform
@@ -32,10 +33,6 @@ except ImportError:
from distutils.spawn import find_executable as which
# pylint: enable=g-import-not-at-top
-_TF_BAZELRC = os.path.join(os.path.dirname(os.path.abspath(__file__)),
- '.tf_configure.bazelrc')
-_TF_WORKSPACE = os.path.join(os.path.dirname(os.path.abspath(__file__)),
- 'WORKSPACE')
_DEFAULT_CUDA_VERSION = '9.0'
_DEFAULT_CUDNN_VERSION = '7'
_DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,5.2'
@@ -51,6 +48,11 @@ _SUPPORTED_ANDROID_NDK_VERSIONS = [10, 11, 12, 13, 14, 15]
_DEFAULT_PROMPT_ASK_ATTEMPTS = 10
+_TF_WORKSPACE_ROOT = os.path.abspath(os.path.dirname(__file__))
+_TF_BAZELRC_FILENAME = '.tf_configure.bazelrc'
+_TF_BAZELRC = os.path.join(_TF_WORKSPACE_ROOT, _TF_BAZELRC_FILENAME)
+_TF_WORKSPACE = os.path.join(_TF_WORKSPACE_ROOT, 'WORKSPACE')
+
class UserInputError(Exception):
pass
@@ -119,22 +121,6 @@ def sed_in_place(filename, old, new):
f.write(newdata)
-def remove_line_with(filename, token):
- """Remove lines that contain token from file.
-
- Args:
- filename: string for filename.
- token: string token to check if to remove a line from file or not.
- """
- with open(filename, 'r') as f:
- filedata = f.read()
-
- with open(filename, 'w') as f:
- for line in filedata.strip().split('\n'):
- if token not in line:
- f.write(line + '\n')
-
-
def write_to_bazelrc(line):
with open(_TF_BAZELRC, 'a') as f:
f.write(line + '\n')
@@ -245,25 +231,26 @@ def setup_python(environ_cp):
environ_cp['PYTHON_BIN_PATH'] = python_bin_path
# Write tools/python_bin_path.sh
- with open('tools/python_bin_path.sh', 'w') as f:
+ with open(os.path.join(
+ _TF_WORKSPACE_ROOT, 'tools', 'python_bin_path.sh'), 'w') as f:
f.write('export PYTHON_BIN_PATH="%s"' % python_bin_path)
-def reset_tf_configure_bazelrc():
+def reset_tf_configure_bazelrc(workspace_path):
"""Reset file that contains customized config settings."""
open(_TF_BAZELRC, 'w').close()
+ bazelrc_path = os.path.join(workspace_path, '.bazelrc')
- home = os.path.expanduser('~')
- if not os.path.exists('.bazelrc'):
- if os.path.exists(os.path.join(home, '.bazelrc')):
- with open('.bazelrc', 'a') as f:
- f.write('import %s/.bazelrc\n' % home.replace('\\', '/'))
- else:
- open('.bazelrc', 'w').close()
-
- remove_line_with('.bazelrc', 'tf_configure')
- with open('.bazelrc', 'a') as f:
- f.write('import %workspace%/.tf_configure.bazelrc\n')
+ data = []
+ if os.path.exists(bazelrc_path):
+ with open(bazelrc_path, 'r') as f:
+ data = f.read().splitlines()
+ with open(bazelrc_path, 'w') as f:
+ for l in data:
+ if _TF_BAZELRC_FILENAME in l:
+ continue
+ f.write('%s\n' % l)
+ f.write('import %s\n' % _TF_BAZELRC)
def cleanup_makefile():
@@ -271,7 +258,8 @@ def cleanup_makefile():
These files could interfere with Bazel parsing.
"""
- makefile_download_dir = 'tensorflow/contrib/makefile/downloads'
+ makefile_download_dir = os.path.join(
+ _TF_WORKSPACE_ROOT, 'tensorflow', 'contrib', 'makefile', 'downloads')
if os.path.isdir(makefile_download_dir):
for root, _, filenames in os.walk(makefile_download_dir):
for f in filenames:
@@ -1373,13 +1361,20 @@ def config_info_line(name, help_text):
def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--workspace",
+ type=str,
+ default=_TF_WORKSPACE_ROOT,
+ help="The absolute path to your active Bazel workspace.")
+ args = parser.parse_args()
+
# Make a copy of os.environ to be clear when functions and getting and setting
# environment variables.
environ_cp = dict(os.environ)
check_bazel_version('0.5.4')
- reset_tf_configure_bazelrc()
+ reset_tf_configure_bazelrc(args.workspace)
cleanup_makefile()
setup_python(environ_cp)