aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rwxr-xr-xconfigure73
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_device.cc4
-rw-r--r--tensorflow/core/kernels/lrn_op.cc2
-rw-r--r--tensorflow/core/kernels/matmul_op.cc2
-rw-r--r--tensorflow/core/ops/ops.pbtxt36
-rw-r--r--tensorflow/core/platform/default/build_config/BUILD18
-rw-r--r--tensorflow/core/platform/default/gpu/BUILD6
-rw-r--r--tensorflow/core/platform/default/gpu/cupti_wrapper.h2
-rw-r--r--tensorflow/core/util/port.cc2
-rw-r--r--tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.train.Optimizer.md5
-rw-r--r--tensorflow/g3doc/api_docs/python/train.md5
-rw-r--r--tensorflow/g3doc/get_started/os_setup.md9
-rw-r--r--tensorflow/python/ops/ctc_ops.py18
-rw-r--r--tensorflow/python/ops/variables.py3
-rw-r--r--tensorflow/python/training/optimizer.py6
-rw-r--r--tensorflow/stream_executor/BUILD9
-rw-r--r--tensorflow/stream_executor/cuda/cuda_blas.cc4
-rw-r--r--tensorflow/stream_executor/cuda/cuda_dnn.cc2
-rw-r--r--tensorflow/stream_executor/cuda/cuda_driver.h2
-rw-r--r--tensorflow/stream_executor/cuda/cuda_fft.h2
-rw-r--r--tensorflow/stream_executor/cuda/cuda_helpers.h4
-rw-r--r--tensorflow/stream_executor/cuda/cuda_kernel.h2
-rw-r--r--tensorflow/stream_executor/cuda/cuda_rng.cc2
-rw-r--r--tensorflow/stream_executor/dso_loader.cc60
-rw-r--r--tensorflow/stream_executor/dso_loader.h1
-rw-r--r--tensorflow/tensorboard/backend/handler.py5
-rw-r--r--tensorflow/tensorflow.bzl32
-rw-r--r--tensorflow/tools/ci_build/Dockerfile.gpu3
-rw-r--r--tensorflow/tools/dist_test/README.md2
-rw-r--r--tensorflow/workspace.bzl3
-rw-r--r--third_party/gpus/BUILD0
-rw-r--r--third_party/gpus/crosstool/BUILD.tpl42
-rw-r--r--third_party/gpus/crosstool/CROSSTOOL.tpl254
-rwxr-xr-xthird_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl315
-rw-r--r--third_party/gpus/cuda/BUILD224
-rw-r--r--third_party/gpus/cuda/BUILD.tpl172
-rw-r--r--third_party/gpus/cuda/build_defs.bzl.tpl14
-rw-r--r--third_party/gpus/cuda/cuda_config.h.tpl24
-rw-r--r--third_party/gpus/cuda/platform.bzl.tpl57
-rw-r--r--third_party/gpus/cuda_configure.bzl423
-rw-r--r--tools/bazel.rc.template2
41 files changed, 1469 insertions, 382 deletions
diff --git a/configure b/configure
index 9ab6ea6b1c..bcef37bd26 100755
--- a/configure
+++ b/configure
@@ -80,6 +80,7 @@ while [ "$TF_NEED_CUDA" == "" ]; do
esac
done
+export TF_NEED_CUDA
if [ "$TF_NEED_CUDA" == "0" ]; then
echo "Configuration finished"
exit
@@ -97,6 +98,7 @@ while true; do
fi
fi
if [ -e "$GCC_HOST_COMPILER_PATH" ]; then
+ export CC=$GCC_HOST_COMPILER_PATH
break
fi
echo "Invalid gcc path. ${GCC_HOST_COMPILER_PATH} cannot be found" 1>&2
@@ -107,7 +109,6 @@ while true; do
# Retry
done
-
# Find out where the CUDA toolkit is installed
OSNAME=`uname -s`
@@ -140,6 +141,8 @@ while true; do
fi
if [ -e "${CUDA_TOOLKIT_PATH}/${CUDA_RT_LIB_PATH}" ]; then
+ export CUDA_TOOLKIT_PATH
+ export CUDA_VERSION=$TF_CUDA_VERSION
break
fi
echo "Invalid path to CUDA $TF_CUDA_VERSION toolkit. ${CUDA_TOOLKIT_PATH}/${CUDA_RT_LIB_PATH} cannot be found"
@@ -200,13 +203,16 @@ while true; do
fi
if [ -e "$CUDNN_INSTALL_PATH/${CUDA_DNN_LIB_ALT_PATH}" -o -e "$CUDNN_INSTALL_PATH/${CUDA_DNN_LIB_PATH}" ]; then
+ export CUDNN_VERSION=$TF_CUDNN_VERSION
+ export CUDNN_INSTALL_PATH
break
fi
if [ "$OSNAME" == "Linux" ]; then
CUDNN_PATH_FROM_LDCONFIG="$(ldconfig -p | sed -n 's/.*libcudnn.so .* => \(.*\)/\1/p')"
if [ -e "${CUDNN_PATH_FROM_LDCONFIG}${TF_CUDNN_EXT}" ]; then
- CUDNN_INSTALL_PATH="$(dirname ${CUDNN_PATH_FROM_LDCONFIG})"
+ export CUDNN_VERSION=$TF_CUDNN_VERSION
+ export CUDNN_INSTALL_PATH="$(dirname ${CUDNN_PATH_FROM_LDCONFIG})"
break
fi
fi
@@ -225,42 +231,11 @@ while true; do
CUDNN_INSTALL_PATH=""
done
-cat > third_party/gpus/cuda/cuda.config <<EOF
-# CUDA_TOOLKIT_PATH refers to the CUDA toolkit.
-CUDA_TOOLKIT_PATH="$CUDA_TOOLKIT_PATH"
-# CUDNN_INSTALL_PATH refers to the cuDNN toolkit. The cuDNN header and library
-# files can be either in this directory, or under include/ and lib64/
-# directories separately.
-CUDNN_INSTALL_PATH="$CUDNN_INSTALL_PATH"
-
-# The Cuda SDK version that should be used in this build (empty to use libcudart.so symlink)
-TF_CUDA_VERSION=$TF_CUDA_VERSION
-
-# The Cudnn version that should be used in this build
-TF_CUDNN_VERSION=$TF_CUDNN_VERSION
-EOF
-
-# Configure the gcc host compiler to use
-export WARNING=$DO_NOT_SUBMIT_WARNING
-perl -pi -e "s,CPU_COMPILER = \('.*'\),# \$ENV{WARNING}\nCPU_COMPILER = ('$GCC_HOST_COMPILER_PATH'),s" third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc
-perl -pi -e "s,GCC_HOST_COMPILER_PATH = \('.*'\),# \$ENV{WARNING}\nGCC_HOST_COMPILER_PATH = ('$GCC_HOST_COMPILER_PATH'),s" third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc
-
-# Configure the platform name.
-perl -pi -e "s,PLATFORM = \".*\",PLATFORM = \"$OSNAME\",s" third_party/gpus/cuda/platform.bzl
-
-# Configure the Cuda toolkit version to work with.
-perl -pi -e "s,(GetCudaVersion.*return )\"[0-9\.]*\",\1\"$TF_CUDA_VERSION\",s" tensorflow/stream_executor/dso_loader.cc
-perl -pi -e "s,CUDA_VERSION = \"[0-9\.]*\",CUDA_VERSION = \"$TF_CUDA_VERSION\",s" third_party/gpus/cuda/platform.bzl
-
-# Configure the Cudnn version to work with.
-perl -pi -e "s,(GetCudnnVersion.*return )\"[0-9\.]*\",\1\"$TF_CUDNN_VERSION\",s" tensorflow/stream_executor/dso_loader.cc
-perl -pi -e "s,CUDNN_VERSION = \"[0-9\.]*\",CUDNN_VERSION = \"$TF_CUDNN_VERSION\",s" third_party/gpus/cuda/platform.bzl
-
-
# Configure the compute capabilities that TensorFlow builds for.
# Since Cuda toolkit is not backward-compatible, this is not guaranteed to work.
while true; do
fromuser=""
+ default_cuda_compute_capabilities="3.5,5.2"
if [ -z "$TF_CUDA_COMPUTE_CAPABILITIES" ]; then
cat << EOF
Please specify a list of comma-separated Cuda compute capabilities you want to build with.
@@ -270,6 +245,9 @@ EOF
read -p "[Default is: \"3.5,5.2\"]: " TF_CUDA_COMPUTE_CAPABILITIES
fromuser=1
fi
+ if [ -z "$TF_CUDA_COMPUTE_CAPABILITIES" ]; then
+ TF_CUDA_COMPUTE_CAPABILITIES=$default_cuda_compute_capabilities
+ fi
# Check whether all capabilities from the input is valid
COMPUTE_CAPABILITIES=${TF_CUDA_COMPUTE_CAPABILITIES//,/ }
ALL_VALID=1
@@ -285,34 +263,13 @@ EOF
exit 1
fi
else
+ export CUDA_COMPUTE_CAPABILITIES=$TF_CUDA_COMPUTE_CAPABILITIES
break
fi
TF_CUDA_COMPUTE_CAPABILITIES=""
done
-if [ ! -z "$TF_CUDA_COMPUTE_CAPABILITIES" ]; then
- export WARNING=$DO_NOT_SUBMIT_WARNING
- function CudaGenCodeOpts() {
- OUTPUT=""
- for CAPABILITY in $@; do
- OUTPUT=${OUTPUT}" \"${CAPABILITY}\", "
- done
- echo $OUTPUT
- }
- export CUDA_GEN_CODES_OPTS=$(CudaGenCodeOpts ${TF_CUDA_COMPUTE_CAPABILITIES//,/ })
- perl -pi -0 -e 's,\n( *)([^\n]*supported_cuda_compute_capabilities\s*=\s*\[).*?(\]),\n\1# $ENV{WARNING}\n\1\2$ENV{CUDA_GEN_CODES_OPTS}\3,s' third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc
- function CudaVersionOpts() {
- OUTPUT=""
- for CAPABILITY in $@; do
- OUTPUT=$OUTPUT"CudaVersion(\"${CAPABILITY}\"), "
- done
- echo $OUTPUT
- }
- export CUDA_VERSION_OPTS=$(CudaVersionOpts ${TF_CUDA_COMPUTE_CAPABILITIES//,/ })
- perl -pi -0 -e 's,\n( *)([^\n]*supported_cuda_compute_capabilities\s*=\s*\{).*?(\}),\n\1// $ENV{WARNING}\n\1\2$ENV{CUDA_VERSION_OPTS}\3,s' tensorflow/core/common_runtime/gpu/gpu_device.cc
-fi
-
-# Invoke the cuda_config.sh and set up the TensorFlow's canonical view of the Cuda libraries
-(cd third_party/gpus/cuda; ./cuda_config.sh;) || exit -1
+bazel clean --expunge
+bazel fetch //...
echo "Configuration finished"
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc
index 0ff64dd939..c08abc5689 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc
@@ -785,10 +785,8 @@ struct CudaVersion {
int minor_part = -1;
};
-// "configure" uses the specific name to substitute the following string.
-// If you change it, make sure you modify "configure" as well.
std::vector<CudaVersion> supported_cuda_compute_capabilities = {
- CudaVersion("3.5"), CudaVersion("5.2")};
+ TF_CUDA_CAPABILITIES,};
std::vector<CudaVersion> GetSupportedCudaComputeCapabilities() {
auto cuda_caps = supported_cuda_compute_capabilities;
diff --git a/tensorflow/core/kernels/lrn_op.cc b/tensorflow/core/kernels/lrn_op.cc
index 7b99aee6c6..3435486c95 100644
--- a/tensorflow/core/kernels/lrn_op.cc
+++ b/tensorflow/core/kernels/lrn_op.cc
@@ -31,7 +31,7 @@ limitations under the License.
#endif
#if GOOGLE_CUDA
-#include "third_party/gpus/cuda/include/cuda.h"
+#include "cuda/include/cuda.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/util/stream_executor_util.h"
#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/matmul_op.cc b/tensorflow/core/kernels/matmul_op.cc
index 7852b90c73..03a6d29839 100644
--- a/tensorflow/core/kernels/matmul_op.cc
+++ b/tensorflow/core/kernels/matmul_op.cc
@@ -25,7 +25,7 @@ limitations under the License.
#include "tensorflow/core/kernels/fill_functor.h"
#if GOOGLE_CUDA
-#include "third_party/gpus/cuda/include/cuda.h"
+#include "cuda/include/cuda.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 2583128d8d..bd5593a6a2 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -4681,6 +4681,42 @@ op {
description: "The attr `channels` indicates the desired number of color channels for the\ndecoded image.\n\nAccepted values are:\n\n* 0: Use the number of channels in the PNG-encoded image.\n* 1: output a grayscale image.\n* 3: output an RGB image.\n* 4: output an RGBA image.\n\nIf needed, the PNG-encoded image is transformed to match the requested number\nof color channels."
}
op {
+ name: "DecodeGif"
+ input_arg {
+ name: "contents"
+ description: "0-D. The GIF-encoded image."
+ type: DT_STRING
+ }
+ output_arg {
+ name: "image"
+ description: "3-D with shape `[height, width, channels]`."
+ type_attr: "dtype"
+ }
+ attr {
+ name: "channels"
+ type: "int"
+ default_value {
+ i: 0
+ }
+ description: "Number of color channels for the decoded image."
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ default_value {
+ type: DT_UINT8
+ }
+ allowed_values {
+ list {
+ type: DT_UINT8
+ type: DT_UINT16
+ }
+ }
+ }
+ summary: "Decode a GIF-encoded image to a uint8 or uint16 tensor."
+ description: "The attr `channels` indicates the desired number of color channels for the\ndecoded image.\n\nAccepted values are:\n\n* 0: Use the number of channels in the GIF-encoded image.\n* 1: output a grayscale image.\n* 3: output an RGB image.\n* 4: output an RGBA image.\n\nIf needed, the GIF-encoded image is transformed to match the requested number\nof color channels."
+}
+op {
name: "DecodeRaw"
input_arg {
name: "bytes"
diff --git a/tensorflow/core/platform/default/build_config/BUILD b/tensorflow/core/platform/default/build_config/BUILD
index f372d2ef0d..158c42b5ad 100644
--- a/tensorflow/core/platform/default/build_config/BUILD
+++ b/tensorflow/core/platform/default/build_config/BUILD
@@ -9,7 +9,7 @@ exports_files(["LICENSE"])
load("//tensorflow:tensorflow.bzl", "tf_copts")
load("//tensorflow:tensorflow.bzl", "tf_cuda_library")
-load("//third_party/gpus/cuda:platform.bzl", "cuda_library_path")
+load("@local_config_cuda//cuda:platform.bzl", "cuda_library_path")
cc_library(
name = "gtest",
@@ -32,7 +32,7 @@ tf_cuda_library(
deps = [
"//tensorflow/stream_executor",
] + select({
- "//third_party/gpus/cuda:darwin": ["IOKit"],
+ "@local_config_cuda//cuda:darwin": ["IOKit"],
"//conditions:default": [],
}),
)
@@ -91,20 +91,20 @@ filegroup(
cc_library(
name = "cuda",
data = [
- "//third_party/gpus/cuda:{}".format(cuda_library_path("cudart")),
+ "@local_config_cuda//cuda:{}".format(cuda_library_path("cudart")),
],
linkopts = select({
- "//third_party/gpus/cuda:darwin": [
- "-Wl,-rpath,third_party/gpus/cuda/lib",
- "-Wl,-rpath,third_party/gpus/cuda/extras/CUPTI/lib",
+ "@local_config_cuda//cuda:darwin": [
+ "-Wl,-rpath,../local_config_cuda/cuda/lib",
+ "-Wl,-rpath,../local_config_cuda/cuda/extras/CUPTI/lib",
],
"//conditions:default": [
- "-Wl,-rpath,third_party/gpus/cuda/lib64",
- "-Wl,-rpath,third_party/gpus/cuda/extras/CUPTI/lib64",
+ "-Wl,-rpath,../local_config_cuda/cuda/lib64",
+ "-Wl,-rpath,../local_config_cuda/cuda/extras/CUPTI/lib64",
],
}),
deps = [
- "//third_party/gpus/cuda:cudart",
+ "@local_config_cuda//cuda:cudart",
],
)
diff --git a/tensorflow/core/platform/default/gpu/BUILD b/tensorflow/core/platform/default/gpu/BUILD
index 93b6227848..6b0c919a89 100644
--- a/tensorflow/core/platform/default/gpu/BUILD
+++ b/tensorflow/core/platform/default/gpu/BUILD
@@ -15,9 +15,9 @@ tf_cuda_library(
copts = tf_copts(),
cuda_deps = [
"//tensorflow/core:stream_executor",
- "//third_party/gpus/cuda:cuda_headers",
- "//third_party/gpus/cuda:cupti_headers",
+ "@local_config_cuda//cuda:cuda_headers",
+ "@local_config_cuda//cuda:cupti_headers",
],
- data = ["//third_party/gpus/cuda:cupti_dsos"],
+ data = ["@local_config_cuda//cuda:cupti_dsos"],
visibility = ["//visibility:public"],
)
diff --git a/tensorflow/core/platform/default/gpu/cupti_wrapper.h b/tensorflow/core/platform/default/gpu/cupti_wrapper.h
index 5829172c47..e482f8607f 100644
--- a/tensorflow/core/platform/default/gpu/cupti_wrapper.h
+++ b/tensorflow/core/platform/default/gpu/cupti_wrapper.h
@@ -21,7 +21,7 @@ limitations under the License.
#include <stddef.h>
#include <stdint.h>
-#include "third_party/gpus/cuda/extras/CUPTI/include/cupti.h"
+#include "cuda/extras/CUPTI/include/cupti.h"
namespace perftools {
namespace gputools {
diff --git a/tensorflow/core/util/port.cc b/tensorflow/core/util/port.cc
index 42375770f4..d93b971f85 100644
--- a/tensorflow/core/util/port.cc
+++ b/tensorflow/core/util/port.cc
@@ -16,7 +16,7 @@ limitations under the License.
#include "tensorflow/core/util/port.h"
#if GOOGLE_CUDA
-#include "third_party/gpus/cuda/include/cuda.h"
+#include "cuda/include/cuda.h"
#endif
namespace tensorflow {
diff --git a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.train.Optimizer.md b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.train.Optimizer.md
index 02bb5b5ac4..ff14086bb4 100644
--- a/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.train.Optimizer.md
+++ b/tensorflow/g3doc/api_docs/python/functions_and_classes/shard5/tf.train.Optimizer.md
@@ -184,9 +184,8 @@ applies gradients.
### Gating Gradients
-Both `minimize()` and `compute_gradients()` accept a `gate_gradient` argument
-that controls the degree of parallelism during the application of the
-gradients.
+Both `minimize()` and `compute_gradients()` accept a `gate_gradients` argument
+that controls the degree of parallelism during the application of the gradients.
The possible values are: `GATE_NONE`, `GATE_OP`, and `GATE_GRAPH`.
diff --git a/tensorflow/g3doc/api_docs/python/train.md b/tensorflow/g3doc/api_docs/python/train.md
index 94fa7a3ed5..3d0329ad3f 100644
--- a/tensorflow/g3doc/api_docs/python/train.md
+++ b/tensorflow/g3doc/api_docs/python/train.md
@@ -204,9 +204,8 @@ applies gradients.
### Gating Gradients
-Both `minimize()` and `compute_gradients()` accept a `gate_gradient` argument
-that controls the degree of parallelism during the application of the
-gradients.
+Both `minimize()` and `compute_gradients()` accept a `gate_gradients` argument
+that controls the degree of parallelism during the application of the gradients.
The possible values are: `GATE_NONE`, `GATE_OP`, and `GATE_GRAPH`.
diff --git a/tensorflow/g3doc/get_started/os_setup.md b/tensorflow/g3doc/get_started/os_setup.md
index d85ae429fa..53afca4ccb 100644
--- a/tensorflow/g3doc/get_started/os_setup.md
+++ b/tensorflow/g3doc/get_started/os_setup.md
@@ -7,11 +7,10 @@ github source.
The TensorFlow Python API supports Python 2.7 and Python 3.3+.
-The GPU version (Linux only) works best with Cuda Toolkit 7.5 and
-cuDNN v4. other versions are supported (Cuda toolkit >= 7.0 and
-cuDNN 6.5(v2), 7.0(v3), v5) only when installing from sources.
-Please see [Cuda installation](#optional-install-cuda-gpus-on-linux)
-for details.
+The GPU version (Linux & Mac OS X only) works best with Cuda Toolkit 7.5 and
+cuDNN v4. other versions are supported (Cuda toolkit >= 7.0 and cuDNN 6.5(v2),
+7.0(v3), v5) only when installing from sources. Please see [Cuda installation]
+(#optional-install-cuda-gpus-on-linux) for details.
## Overview
diff --git a/tensorflow/python/ops/ctc_ops.py b/tensorflow/python/ops/ctc_ops.py
index bab9dc0ef5..2f85112172 100644
--- a/tensorflow/python/ops/ctc_ops.py
+++ b/tensorflow/python/ops/ctc_ops.py
@@ -48,6 +48,18 @@ def ctc_loss(inputs, labels, sequence_length,
<= sequence_length(b) for all b.
```
+ Notes:
+
+ This class performs the softmax operation for you, so inputs should
+ be e.g. linear projections of outputs by an LSTM.
+
+ The `inputs` Tensor's innermost dimension size, `num_classes`, represents
+ `num_labels + 1` classes, where num_labels is the number of true labels, and
+ the largest value `(num_classes - 1)` is reserved for the blank label.
+
+ For example, for a vocabulary containing 3 labels `[a, b, c]`,
+ `num_classes = 4` and the labels indexing is `{a: 0, b: 1, c: 2, blank: 3}`.
+
Regarding the arguments `preprocess_collapse_repeated` and
`ctc_merge_repeated`:
@@ -84,10 +96,12 @@ def ctc_loss(inputs, labels, sequence_length,
Args:
inputs: 3-D `float` `Tensor` sized
- `[max_time x batch_size x num_classes]`. The logits.
+ `[max_time x batch_size x num_classes]`. The logits.
labels: An `int32` `SparseTensor`.
`labels.indices[i, :] == [b, t]` means `labels.values[i]` stores
- the id for (batch b, time t). See `core/ops/ctc_ops.cc` for more details.
+ the id for (batch b, time t).
+ `labels.values[i]` must take on values in `[0, num_labels)`.
+ See `core/ops/ctc_ops.cc` for more details.
sequence_length: 1-D `int32` vector, size `[batch_size]`.
The sequence lengths.
preprocess_collapse_repeated: Boolean. Default: False.
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index a720d7d26a..d3d78dad5f 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -1036,8 +1036,7 @@ def report_uninitialized_variables(var_list=None,
Returns:
A 1-D tensor containing names of the uninitialized variables, or an empty
- 1-D
- tensor if there are no variables or no uninitialized variables.
+ 1-D tensor if there are no variables or no uninitialized variables.
"""
if var_list is None:
var_list = all_variables() + local_variables()
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index 2627f9fce0..deeec2f6e3 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -89,9 +89,9 @@ class Optimizer(object):
### Gating Gradients
- Both `minimize()` and `compute_gradients()` accept a `gate_gradient` argument
- that controls the degree of parallelism during the application of the
- gradients.
+ Both `minimize()` and `compute_gradients()` accept a `gate_gradients`
+ argument that controls the degree of parallelism during the application of
+ the gradients.
The possible values are: `GATE_NONE`, `GATE_OP`, and `GATE_GRAPH`.
diff --git a/tensorflow/stream_executor/BUILD b/tensorflow/stream_executor/BUILD
index 24e8305d31..256b128750 100644
--- a/tensorflow/stream_executor/BUILD
+++ b/tensorflow/stream_executor/BUILD
@@ -27,9 +27,10 @@ cc_library(
]),
data = [
"//tensorflow/core:cuda",
- "//third_party/gpus/cuda:cublas",
- "//third_party/gpus/cuda:cudnn",
- "//third_party/gpus/cuda:cufft",
+ "@local_config_cuda//cuda:cublas",
+ "@local_config_cuda//cuda:cudnn",
+ "@local_config_cuda//cuda:cufft",
+ "@local_config_cuda//cuda:curand",
],
linkopts = [
"-ldl",
@@ -37,7 +38,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
"//tensorflow/core:lib",
- "//third_party/gpus/cuda:cuda_headers",
+ "@local_config_cuda//cuda:cuda_headers",
],
alwayslink = 1,
)
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc
index a9dd2953e5..e2611cd3d0 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.cc
+++ b/tensorflow/stream_executor/cuda/cuda_blas.cc
@@ -18,8 +18,8 @@ limitations under the License.
// cuda.h). This ensures that Eigen's Half.h does not attempt to make its own
// __half typedef if CUDA has already defined one (and conversely, that we do
// not include <cuda_fp16.h> after Half.h has made its typedef).
-#include "third_party/gpus/cuda/include/cuda.h"
-#include "third_party/gpus/cuda/include/cublas_v2.h"
+#include "cuda/include/cuda.h"
+#include "cuda/include/cublas_v2.h"
#if CUDA_VERSION >= 7050
#define EIGEN_HAS_CUDA_FP16
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index 7fbafa3d7e..55535f9ce5 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -39,7 +39,7 @@ limitations under the License.
#include "tensorflow/stream_executor/stream.h"
#include "tensorflow/stream_executor/stream_executor_pimpl.h"
// clang-format off
-#include "third_party/gpus/cuda/include/cudnn.h"
+#include "cuda/include/cudnn.h"
// clang-format on
namespace {
diff --git a/tensorflow/stream_executor/cuda/cuda_driver.h b/tensorflow/stream_executor/cuda/cuda_driver.h
index a5de5d0f59..ab118e5d40 100644
--- a/tensorflow/stream_executor/cuda/cuda_driver.h
+++ b/tensorflow/stream_executor/cuda/cuda_driver.h
@@ -25,7 +25,7 @@ limitations under the License.
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/platform/port.h"
-#include "third_party/gpus/cuda/include/cuda.h"
+#include "cuda/include/cuda.h"
namespace perftools {
namespace gputools {
diff --git a/tensorflow/stream_executor/cuda/cuda_fft.h b/tensorflow/stream_executor/cuda/cuda_fft.h
index 0c7aa34df3..95b3e8de63 100644
--- a/tensorflow/stream_executor/cuda/cuda_fft.h
+++ b/tensorflow/stream_executor/cuda/cuda_fft.h
@@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/stream_executor/fft.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/plugin_registry.h"
-#include "third_party/gpus/cuda/include/cufft.h"
+#include "cuda/include/cufft.h"
namespace perftools {
namespace gputools {
diff --git a/tensorflow/stream_executor/cuda/cuda_helpers.h b/tensorflow/stream_executor/cuda/cuda_helpers.h
index 7753866560..6a6134bf88 100644
--- a/tensorflow/stream_executor/cuda/cuda_helpers.h
+++ b/tensorflow/stream_executor/cuda/cuda_helpers.h
@@ -24,8 +24,8 @@ limitations under the License.
#include <stddef.h>
#include <complex>
-#include "third_party/gpus/cuda/include/cuComplex.h"
-#include "third_party/gpus/cuda/include/cuda.h"
+#include "cuda/include/cuComplex.h"
+#include "cuda/include/cuda.h"
namespace perftools {
namespace gputools {
diff --git a/tensorflow/stream_executor/cuda/cuda_kernel.h b/tensorflow/stream_executor/cuda/cuda_kernel.h
index 412e7d9a40..88d29fddd0 100644
--- a/tensorflow/stream_executor/cuda/cuda_kernel.h
+++ b/tensorflow/stream_executor/cuda/cuda_kernel.h
@@ -28,7 +28,7 @@ limitations under the License.
#include "tensorflow/stream_executor/lib/casts.h"
#include "tensorflow/stream_executor/platform/port.h"
#include "tensorflow/stream_executor/platform/logging.h"
-#include "third_party/gpus/cuda/include/cuda.h"
+#include "cuda/include/cuda.h"
#ifdef PLATFORMS_GPUS_CUDA_DYNAMIC_LIBCUDA_DYNAMIC_LIBCUDA_H_
#error \
diff --git a/tensorflow/stream_executor/cuda/cuda_rng.cc b/tensorflow/stream_executor/cuda/cuda_rng.cc
index 334c6af970..367eba4d51 100644
--- a/tensorflow/stream_executor/cuda/cuda_rng.cc
+++ b/tensorflow/stream_executor/cuda/cuda_rng.cc
@@ -28,7 +28,7 @@ limitations under the License.
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/platform/logging.h"
#include "tensorflow/stream_executor/rng.h"
-#include "third_party/gpus/cuda/include/curand.h"
+#include "cuda/include/curand.h"
// Formats curandStatus_t to output prettified values into a log stream.
std::ostream &operator<<(std::ostream &in, const curandStatus_t &status) {
diff --git a/tensorflow/stream_executor/dso_loader.cc b/tensorflow/stream_executor/dso_loader.cc
index cce31ef4dc..83f2cadfad 100644
--- a/tensorflow/stream_executor/dso_loader.cc
+++ b/tensorflow/stream_executor/dso_loader.cc
@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+// TODO(jhen): Replace hardcoded, platform specific path strings in GetXXXPath()
+// with a function in e.g. cuda.h.
+
#include "tensorflow/stream_executor/dso_loader.h"
#include <dlfcn.h>
@@ -32,19 +35,17 @@ limitations under the License.
#include "tensorflow/stream_executor/lib/stringprintf.h"
#include "tensorflow/stream_executor/platform/logging.h"
#include "tensorflow/stream_executor/platform/port.h"
-#include "tensorflow/stream_executor/lib/str_util.h"
namespace perftools {
namespace gputools {
namespace internal {
-// TensorFlow OSS configure uses the following lines to configure versions. For
-// any modifications of the format, please make sure the script still works.
-string GetCudaVersion() { return ""; }
-string GetCudnnVersion() { return ""; }
+string GetCudaVersion() { return TF_CUDA_VERSION; }
+string GetCudnnVersion() { return TF_CUDNN_VERSION; }
/* static */ port::Status DsoLoader::GetCublasDsoHandle(void** dso_handle) {
- return GetDsoHandle(FindDsoPath(tensorflow::internal::FormatLibraryFileName("cublas", GetCudaVersion()),
+ return GetDsoHandle(FindDsoPath(tensorflow::internal::FormatLibraryFileName(
+ "cublas", GetCudaVersion()),
GetCudaLibraryDirPath()),
dso_handle);
}
@@ -53,35 +54,38 @@ string GetCudnnVersion() { return ""; }
// libcudnn is versioned differently than the other libraries and may have a
// different version number than other CUDA libraries. See b/22397368 for
// some details about the complications surrounding this.
- return GetDsoHandle(
- FindDsoPath(tensorflow::internal::FormatLibraryFileName("cudnn", GetCudnnVersion()),
- GetCudaLibraryDirPath()),
+ return GetDsoHandle(FindDsoPath(tensorflow::internal::FormatLibraryFileName(
+ "cudnn", GetCudnnVersion()),
+ GetCudaLibraryDirPath()),
dso_handle);
}
/* static */ port::Status DsoLoader::GetCufftDsoHandle(void** dso_handle) {
- return GetDsoHandle(FindDsoPath(tensorflow::internal::FormatLibraryFileName("cufft", GetCudaVersion()),
+ return GetDsoHandle(FindDsoPath(tensorflow::internal::FormatLibraryFileName(
+ "cufft", GetCudaVersion()),
GetCudaLibraryDirPath()),
dso_handle);
}
/* static */ port::Status DsoLoader::GetCurandDsoHandle(void** dso_handle) {
- return GetDsoHandle(FindDsoPath(tensorflow::internal::FormatLibraryFileName("curand", GetCudaVersion()),
+ return GetDsoHandle(FindDsoPath(tensorflow::internal::FormatLibraryFileName(
+ "curand", GetCudaVersion()),
GetCudaLibraryDirPath()),
dso_handle);
}
/* static */ port::Status DsoLoader::GetLibcudaDsoHandle(void** dso_handle) {
- return GetDsoHandle(FindDsoPath(tensorflow::internal::FormatLibraryFileName("cuda", "1"),
- GetCudaDriverLibraryPath()),
- dso_handle);
+ return GetDsoHandle(
+ FindDsoPath(tensorflow::internal::FormatLibraryFileName("cuda", "1"),
+ GetCudaDriverLibraryPath()),
+ dso_handle);
}
/* static */ port::Status DsoLoader::GetLibcuptiDsoHandle(void** dso_handle) {
- return GetDsoHandle(
- FindDsoPath(tensorflow::internal::FormatLibraryFileName("cupti", GetCudaVersion()),
- GetCudaCuptiLibraryPath()),
- dso_handle);
+ return GetDsoHandle(FindDsoPath(tensorflow::internal::FormatLibraryFileName(
+ "cupti", GetCudaVersion()),
+ GetCudaCuptiLibraryPath()),
+ dso_handle);
}
/* static */ void DsoLoader::RegisterRpath(port::StringPiece path) {
@@ -89,11 +93,9 @@ string GetCudnnVersion() { return ""; }
GetRpaths()->push_back(path.ToString());
}
-
/* static */ port::Status DsoLoader::GetDsoHandle(port::StringPiece path,
void** dso_handle,
LoadKind load_kind) {
-
int dynload_flags =
RTLD_LAZY | (load_kind == LoadKind::kLocal ? RTLD_LOCAL : RTLD_GLOBAL);
string path_string = path.ToString();
@@ -138,9 +140,9 @@ string GetCudnnVersion() { return ""; }
static std::vector<string>* CreatePrimordialRpaths() {
auto rpaths = new std::vector<string>;
#if defined(__APPLE__)
- rpaths->push_back("driver/driver_sh.runfiles/org_tensorflow/third_party/gpus/cuda/lib");
+ rpaths->push_back("driver/driver_sh.runfiles/local_config_cuda/cuda/lib");
#else
- rpaths->push_back("driver/driver_sh.runfiles/org_tensorflow/third_party/gpus/cuda/lib64");
+ rpaths->push_back("driver/driver_sh.runfiles/local_config_cuda/cuda/lib64");
#endif
return rpaths;
}
@@ -165,7 +167,6 @@ static std::vector<string>* CreatePrimordialRpaths() {
/* static */ string DsoLoader::FindDsoPath(port::StringPiece library_name,
port::StringPiece runfiles_relpath) {
-
// Keep a record of the paths we attempted so we can dump out meaningful
// diagnostics if no path is found.
std::vector<string> attempted;
@@ -191,29 +192,28 @@ static std::vector<string>* CreatePrimordialRpaths() {
/* static */ string DsoLoader::GetCudaLibraryDirPath() {
#if defined(__APPLE__)
- return "third_party/gpus/cuda/lib";
+ return "external/local_config_cuda/cuda/lib";
#else
- return "third_party/gpus/cuda/lib64";
+ return "external/local_config_cuda/cuda/lib64";
#endif
}
/* static */ string DsoLoader::GetCudaDriverLibraryPath() {
#if defined(__APPLE__)
- return "third_party/gpus/cuda/driver/lib";
+ return "external/local_config_cuda/cuda/driver/lib";
#else
- return "third_party/gpus/cuda/driver/lib64";
+ return "external/local_config_cuda/cuda/driver/lib64";
#endif
}
/* static */ string DsoLoader::GetCudaCuptiLibraryPath() {
#if defined(__APPLE__)
- return "third_party/gpus/cuda/extras/CUPTI/lib";
+ return "external/local_config_cuda/cuda/extras/CUPTI/lib";
#else
- return "third_party/gpus/cuda/extras/CUPTI/lib64";
+ return "external/local_config_cuda/cuda/extras/CUPTI/lib64";
#endif
}
-
// -- CachedDsoLoader
/* static */ port::StatusOr<void*> CachedDsoLoader::GetCublasDsoHandle() {
diff --git a/tensorflow/stream_executor/dso_loader.h b/tensorflow/stream_executor/dso_loader.h
index a3d4678255..64419e46f9 100644
--- a/tensorflow/stream_executor/dso_loader.h
+++ b/tensorflow/stream_executor/dso_loader.h
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/stream_executor/platform/port.h"
#include <vector>
+#include "cuda/cuda_config.h"
#include "tensorflow/stream_executor/lib/status.h"
#include "tensorflow/stream_executor/lib/statusor.h"
#include "tensorflow/stream_executor/lib/stringpiece.h"
diff --git a/tensorflow/tensorboard/backend/handler.py b/tensorflow/tensorboard/backend/handler.py
index a8ecf73a5b..1c1e9411a4 100644
--- a/tensorflow/tensorboard/backend/handler.py
+++ b/tensorflow/tensorboard/backend/handler.py
@@ -32,6 +32,7 @@ import os
import re
from six import BytesIO
+from six import StringIO
from six.moves import BaseHTTPServer
from six.moves import urllib
from six.moves import xrange # pylint: disable=redefined-builtin
@@ -276,7 +277,7 @@ class TensorboardHandler(BaseHTTPServer.BaseHTTPRequestHandler):
values = self._multiplexer.Scalars(run, tag)
if query_params.get('format') == _OutputFormat.CSV:
- string_io = BytesIO()
+ string_io = StringIO()
writer = csv.writer(string_io)
writer.writerow(['Wall time', 'Step', 'Value'])
writer.writerows(values)
@@ -353,7 +354,7 @@ class TensorboardHandler(BaseHTTPServer.BaseHTTPRequestHandler):
run = query_params.get('run')
compressed_histograms = self._multiplexer.CompressedHistograms(run, tag)
if query_params.get('format') == _OutputFormat.CSV:
- string_io = BytesIO()
+ string_io = StringIO()
writer = csv.writer(string_io)
# Build the headers; we have two columns for timing and two columns for
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index 898a8b7309..ffa1965c3b 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -32,7 +32,7 @@ load(
"tf_cuda_tests_tags",
)
load(
- "//third_party/gpus/cuda:build_defs.bzl",
+ "@local_config_cuda//cuda:build_defs.bzl",
"if_cuda",
)
@@ -324,11 +324,11 @@ def tf_cc_tests(tests, deps, linkstatic=0, tags=[], size="medium", args=None,
tf_cc_test(t, deps, linkstatic, tags=tags, size=size, args=args,
linkopts=linkopts)
-def tf_cc_tests_gpu(tests, deps, linkstatic=0, tags=[], size="medium", args=None):
+def tf_cc_tests_gpu(tests, deps, linkstatic=0, tags=[], size="medium",
+ args=None):
tf_cc_tests(tests, deps, linkstatic, tags=tags, size=size, args=args)
-
def tf_cuda_cc_tests(tests, deps, tags=[], size="medium", linkstatic=0,
args=None, linkopts=[]):
for t in tests:
@@ -345,29 +345,29 @@ def _cuda_copts():
common_cuda_opts = ["-x", "cuda", "-DGOOGLE_CUDA=1"]
return select({
"//conditions:default": [],
- "//third_party/gpus/cuda:using_nvcc": (
+ "@local_config_cuda//cuda:using_nvcc": (
common_cuda_opts +
[
"-nvcc_options=relaxed-constexpr",
"-nvcc_options=ftz=true",
]
),
- "//third_party/gpus/cuda:using_gcudacc": (
+ "@local_config_cuda//cuda:using_gcudacc": (
common_cuda_opts +
["--gcudacc_flag=-ftz=true"]
),
- "//third_party/gpus/cuda:using_clang": (
+ "@local_config_cuda//cuda:using_clang": (
common_cuda_opts +
[
"-fcuda-flush-denormals-to-zero",
- "--cuda-path=third_party/gpus/cuda",
+ "--cuda-path=external/local_config_cuda/cuda",
"--cuda-gpu-arch=sm_35",
]
),
}) + select({
# Pass -O3 when building CUDA code with clang; some important
# optimizations are not enabled at O2.
- "//third_party/gpus/cuda:using_clang_opt": ["-O3"],
+ "@local_config_cuda//cuda:using_clang_opt": ["-O3"],
"//conditions:default": [],
})
@@ -438,7 +438,8 @@ def tf_kernel_library(name, prefix=None, srcs=None, gpu_srcs=None, hdrs=None,
* srcs = ["cwise_op_abs.cc", ..., "cwise_op_tanh.cc"],
* hdrs = ["cwise_ops.h", "cwise_ops_common.h"],
* gpu_srcs = ["cwise_op_gpu_abs.cu.cc", ..., "cwise_op_gpu_tanh.cu.cc",
- "cwise_ops.h", "cwise_ops_common.h", "cwise_ops_gpu_common.cu.h"]
+ "cwise_ops.h", "cwise_ops_common.h",
+ "cwise_ops_gpu_common.cu.h"]
* "cwise_ops_test.cc" is excluded
"""
if not srcs:
@@ -642,7 +643,7 @@ check_deps = rule(
def tf_custom_op_library(name, srcs=[], gpu_srcs=[], deps=[]):
cuda_deps = [
"//tensorflow/core:stream_executor_headers_lib",
- "//third_party/gpus/cuda:cudart_static",
+ "@local_config_cuda//cuda:cudart_static",
]
deps = deps + tf_custom_op_library_additional_deps()
if gpu_srcs:
@@ -692,7 +693,7 @@ def tf_py_wrap_cc(name, srcs, swig_includes=[], deps=[], copts=[], **kwargs):
module_name=module_name,
py_module_name=name)
extra_linkopts = select({
- "//third_party/gpus/cuda:darwin": [
+ "@local_config_cuda//cuda:darwin": [
"-Wl,-exported_symbols_list",
"//tensorflow:tf_exported_symbols.lds"
],
@@ -701,7 +702,7 @@ def tf_py_wrap_cc(name, srcs, swig_includes=[], deps=[], copts=[], **kwargs):
"//tensorflow:tf_version_script.lds"
]})
extra_deps += select({
- "//third_party/gpus/cuda:darwin": [
+ "@local_config_cuda//cuda:darwin": [
"//tensorflow:tf_exported_symbols.lds"
],
"//conditions:default": [
@@ -775,13 +776,14 @@ def py_tests(name,
data=data,
additional_deps=additional_deps)
-def cuda_py_tests(name, srcs, size="medium", additional_deps=[], data=[], shard_count=1, tags=[], prefix=""):
+def cuda_py_tests(name, srcs, size="medium", additional_deps=[], data=[],
+ shard_count=1, tags=[], prefix=""):
test_tags = tags + tf_cuda_tests_tags()
py_tests(name=name, size=size, srcs=srcs, additional_deps=additional_deps,
data=data, tags=test_tags, shard_count=shard_count,prefix=prefix)
-# Creates a genrule named <name> for running tools/proto_text's generator to make
-# the proto_text functions, for the protos passed in <srcs>.
+# Creates a genrule named <name> for running tools/proto_text's generator to
+# make the proto_text functions, for the protos passed in <srcs>.
#
# Return a struct with fields (hdrs, srcs) containing the names of the
# generated files.
diff --git a/tensorflow/tools/ci_build/Dockerfile.gpu b/tensorflow/tools/ci_build/Dockerfile.gpu
index 7b09691178..b36edb9bde 100644
--- a/tensorflow/tools/ci_build/Dockerfile.gpu
+++ b/tensorflow/tools/ci_build/Dockerfile.gpu
@@ -22,5 +22,6 @@ ENV LD_LIBRARY_PATH /usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64
# Configure the build for our CUDA configuration.
ENV CUDA_TOOLKIT_PATH /usr/local/cuda
-ENV CUDNN_INSTALL_PATH /usr/local/cuda
+ENV CUDNN_INSTALL_PATH /usr/lib/x86_64-linux-gnu
ENV TF_NEED_CUDA 1
+ENV CUDA_COMPUTE_CAPABILITIES 3.0,5.2
diff --git a/tensorflow/tools/dist_test/README.md b/tensorflow/tools/dist_test/README.md
index b7f411f847..b042bcf3a2 100644
--- a/tensorflow/tools/dist_test/README.md
+++ b/tensorflow/tools/dist_test/README.md
@@ -35,7 +35,7 @@ For example:
export TF_DIST_GCLOUD_PROJECT="tensorflow-testing"
export TF_DIST_GCLOUD_COMPUTE_ZONE="us-central1-f"
- export CONTAINER_CLUSTER="test-cluster-1"
+ export TF_DIST_CONTAINER_CLUSTER="test-cluster-1"
export TF_DIST_GCLOUD_KEY_FILE_DIR="/tmp/gcloud-secrets"
./remote_test.sh
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 3b2112a026..d47c446bbe 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -1,9 +1,12 @@
# TensorFlow external dependencies that can be loaded in WORKSPACE files.
+load("//third_party/gpus:cuda_configure.bzl", "cuda_configure")
+
# If TensorFlow is linked as a submodule, path_prefix is TensorFlow's directory
# within the workspace (e.g. "tensorflow/"), and tf_repo_name is the name of the
# local_repository rule (e.g. "@tf").
def tf_workspace(path_prefix = "", tf_repo_name = ""):
+ cuda_configure(name = "local_config_cuda")
# These lines need to be changed when updating Eigen. They are parsed from
# this file by the cmake and make builds to determine the eigen version and hash.
diff --git a/third_party/gpus/BUILD b/third_party/gpus/BUILD
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/third_party/gpus/BUILD
diff --git a/third_party/gpus/crosstool/BUILD.tpl b/third_party/gpus/crosstool/BUILD.tpl
new file mode 100644
index 0000000000..7c9c8ab884
--- /dev/null
+++ b/third_party/gpus/crosstool/BUILD.tpl
@@ -0,0 +1,42 @@
+licenses(["restricted"])
+
+package(default_visibility = ["//visibility:public"])
+
+filegroup(
+ name = "crosstool",
+ srcs = ["CROSSTOOL"],
+ output_licenses = ["unencumbered"],
+)
+
+cc_toolchain(
+ name = "cc-compiler-local",
+ all_files = ":empty",
+ compiler_files = ":empty",
+ cpu = "local",
+ dwp_files = ":empty",
+ dynamic_runtime_libs = [":empty"],
+ linker_files = ":empty",
+ objcopy_files = ":empty",
+ static_runtime_libs = [":empty"],
+ strip_files = ":empty",
+ supports_param_files = 0,
+)
+
+cc_toolchain(
+ name = "cc-compiler-darwin",
+ all_files = ":empty",
+ compiler_files = ":empty",
+ cpu = "darwin",
+ dwp_files = ":empty",
+ dynamic_runtime_libs = [":empty"],
+ linker_files = ":empty",
+ objcopy_files = ":empty",
+ static_runtime_libs = [":empty"],
+ strip_files = ":empty",
+ supports_param_files = 0,
+)
+
+filegroup(
+ name = "empty",
+ srcs = [],
+)
diff --git a/third_party/gpus/crosstool/CROSSTOOL.tpl b/third_party/gpus/crosstool/CROSSTOOL.tpl
new file mode 100644
index 0000000000..a367aa8f66
--- /dev/null
+++ b/third_party/gpus/crosstool/CROSSTOOL.tpl
@@ -0,0 +1,254 @@
+major_version: "local"
+minor_version: ""
+default_target_cpu: "same_as_host"
+
+default_toolchain {
+ cpu: "k8"
+ toolchain_identifier: "local_linux"
+}
+default_toolchain {
+ cpu: "piii"
+ toolchain_identifier: "local_linux"
+}
+default_toolchain {
+ cpu: "arm"
+ toolchain_identifier: "local_linux"
+}
+default_toolchain {
+ cpu: "darwin"
+ toolchain_identifier: "local_darwin"
+}
+default_toolchain {
+ cpu: "ppc"
+ toolchain_identifier: "local_linux"
+}
+
+toolchain {
+ abi_version: "local"
+ abi_libc_version: "local"
+ builtin_sysroot: ""
+ compiler: "compiler"
+ host_system_name: "local"
+ needsPic: true
+ supports_gold_linker: false
+ supports_incremental_linker: false
+ supports_fission: false
+ supports_interface_shared_objects: false
+ supports_normalizing_ar: false
+ supports_start_end_lib: false
+ supports_thin_archives: false
+ target_libc: "local"
+ target_cpu: "local"
+ target_system_name: "local"
+ toolchain_identifier: "local_linux"
+
+ tool_path { name: "ar" path: "/usr/bin/ar" }
+ tool_path { name: "compat-ld" path: "/usr/bin/ld" }
+ tool_path { name: "cpp" path: "/usr/bin/cpp" }
+ tool_path { name: "dwp" path: "/usr/bin/dwp" }
+ # As part of the TensorFlow release, we place some cuda-related compilation
+ # files in @local_config_cuda//crosstool/clang/bin, and this relative
+ # path, combined with the rest of our Bazel configuration causes our
+ # compilation to use those files.
+ tool_path { name: "gcc" path: "clang/bin/crosstool_wrapper_driver_is_not_gcc" }
+ # Use "-std=c++11" for nvcc. For consistency, force both the host compiler
+ # and the device compiler to use "-std=c++11".
+ cxx_flag: "-std=c++11"
+ linker_flag: "-lstdc++"
+ linker_flag: "-B/usr/bin/"
+
+ # TODO(bazel-team): In theory, the path here ought to exactly match the path
+ # used by gcc. That works because bazel currently doesn't track files at
+ # absolute locations and has no remote execution, yet. However, this will need
+ # to be fixed, maybe with auto-detection?
+ cxx_builtin_include_directory: "/usr/lib/gcc/"
+ cxx_builtin_include_directory: "/usr/local/include"
+ cxx_builtin_include_directory: "/usr/include"
+ tool_path { name: "gcov" path: "/usr/bin/gcov" }
+
+ # C(++) compiles invoke the compiler (as that is the one knowing where
+ # to find libraries), but we provide LD so other rules can invoke the linker.
+ tool_path { name: "ld" path: "/usr/bin/ld" }
+
+ tool_path { name: "nm" path: "/usr/bin/nm" }
+ tool_path { name: "objcopy" path: "/usr/bin/objcopy" }
+ objcopy_embed_flag: "-I"
+ objcopy_embed_flag: "binary"
+ tool_path { name: "objdump" path: "/usr/bin/objdump" }
+ tool_path { name: "strip" path: "/usr/bin/strip" }
+
+ # Anticipated future default.
+ unfiltered_cxx_flag: "-no-canonical-prefixes"
+
+ # Make C++ compilation deterministic. Use linkstamping instead of these
+ # compiler symbols.
+ unfiltered_cxx_flag: "-Wno-builtin-macro-redefined"
+ unfiltered_cxx_flag: "-D__DATE__=\"redacted\""
+ unfiltered_cxx_flag: "-D__TIMESTAMP__=\"redacted\""
+ unfiltered_cxx_flag: "-D__TIME__=\"redacted\""
+
+ # Security hardening on by default.
+ # Conservative choice; -D_FORTIFY_SOURCE=2 may be unsafe in some cases.
+ # We need to undef it before redefining it as some distributions now have
+ # it enabled by default.
+ compiler_flag: "-U_FORTIFY_SOURCE"
+ compiler_flag: "-D_FORTIFY_SOURCE=1"
+ compiler_flag: "-fstack-protector"
+ compiler_flag: "-fPIE"
+ linker_flag: "-pie"
+ linker_flag: "-Wl,-z,relro,-z,now"
+
+ # Enable coloring even if there's no attached terminal. Bazel removes the
+ # escape sequences if --nocolor is specified. This isn't supported by gcc
+ # on Ubuntu 14.04.
+ # compiler_flag: "-fcolor-diagnostics"
+
+ # All warnings are enabled. Maybe enable -Werror as well?
+ compiler_flag: "-Wall"
+ # Enable a few more warnings that aren't part of -Wall.
+ compiler_flag: "-Wunused-but-set-parameter"
+ # But disable some that are problematic.
+ compiler_flag: "-Wno-free-nonheap-object" # has false positives
+
+ # Keep stack frames for debugging, even in opt mode.
+ compiler_flag: "-fno-omit-frame-pointer"
+
+ # Anticipated future default.
+ linker_flag: "-no-canonical-prefixes"
+ unfiltered_cxx_flag: "-fno-canonical-system-headers"
+ # Have gcc return the exit code from ld.
+ linker_flag: "-pass-exit-codes"
+ # Stamp the binary with a unique identifier.
+ linker_flag: "-Wl,--build-id=md5"
+ linker_flag: "-Wl,--hash-style=gnu"
+ # Gold linker only? Can we enable this by default?
+ # linker_flag: "-Wl,--warn-execstack"
+ # linker_flag: "-Wl,--detect-odr-violations"
+
+ # Include directory for cuda headers.
+ cxx_builtin_include_directory: "/usr/local/cuda%{cuda_version}/include"
+
+ compilation_mode_flags {
+ mode: DBG
+ # Enable debug symbols.
+ compiler_flag: "-g"
+ }
+ compilation_mode_flags {
+ mode: OPT
+
+ # No debug symbols.
+ # Maybe we should enable https://gcc.gnu.org/wiki/DebugFission for opt or
+ # even generally? However, that can't happen here, as it requires special
+ # handling in Bazel.
+ compiler_flag: "-g0"
+
+ # Conservative choice for -O
+ # -O3 can increase binary size and even slow down the resulting binaries.
+ # Profile first and / or use FDO if you need better performance than this.
+ compiler_flag: "-O2"
+
+ # Disable assertions
+ compiler_flag: "-DNDEBUG"
+
+ # Removal of unused code and data at link time (can this increase binary size in some cases?).
+ compiler_flag: "-ffunction-sections"
+ compiler_flag: "-fdata-sections"
+ linker_flag: "-Wl,--gc-sections"
+ }
+ linking_mode_flags { mode: DYNAMIC }
+}
+
+toolchain {
+ abi_version: "local"
+ abi_libc_version: "local"
+ builtin_sysroot: ""
+ compiler: "compiler"
+ host_system_name: "local"
+ needsPic: true
+ target_libc: "macosx"
+ target_cpu: "darwin"
+ target_system_name: "local"
+ toolchain_identifier: "local_darwin"
+
+ tool_path { name: "ar" path: "/usr/bin/libtool" }
+ tool_path { name: "compat-ld" path: "/usr/bin/ld" }
+ tool_path { name: "cpp" path: "/usr/bin/cpp" }
+ tool_path { name: "dwp" path: "/usr/bin/dwp" }
+ tool_path { name: "gcc" path: "clang/bin/crosstool_wrapper_driver_is_not_gcc" }
+ cxx_flag: "-std=c++11"
+ ar_flag: "-static"
+ ar_flag: "-s"
+ ar_flag: "-o"
+ linker_flag: "-lc++"
+ linker_flag: "-undefined"
+ linker_flag: "dynamic_lookup"
+ # TODO(ulfjack): This is wrong on so many levels. Figure out a way to auto-detect the proper
+ # setting from the local compiler, and also how to make incremental builds correct.
+ cxx_builtin_include_directory: "/"
+ tool_path { name: "gcov" path: "/usr/bin/gcov" }
+ tool_path { name: "ld" path: "/usr/bin/ld" }
+ tool_path { name: "nm" path: "/usr/bin/nm" }
+ tool_path { name: "objcopy" path: "/usr/bin/objcopy" }
+ objcopy_embed_flag: "-I"
+ objcopy_embed_flag: "binary"
+ tool_path { name: "objdump" path: "/usr/bin/objdump" }
+ tool_path { name: "strip" path: "/usr/bin/strip" }
+
+ # Anticipated future default.
+ unfiltered_cxx_flag: "-no-canonical-prefixes"
+ # Make C++ compilation deterministic. Use linkstamping instead of these
+ # compiler symbols.
+ unfiltered_cxx_flag: "-Wno-builtin-macro-redefined"
+ unfiltered_cxx_flag: "-D__DATE__=\"redacted\""
+ unfiltered_cxx_flag: "-D__TIMESTAMP__=\"redacted\""
+ unfiltered_cxx_flag: "-D__TIME__=\"redacted\""
+
+ # Security hardening on by default.
+ # Conservative choice; -D_FORTIFY_SOURCE=2 may be unsafe in some cases.
+ compiler_flag: "-D_FORTIFY_SOURCE=1"
+ compiler_flag: "-fstack-protector"
+
+ # Enable coloring even if there's no attached terminal. Bazel removes the
+ # escape sequences if --nocolor is specified.
+ compiler_flag: "-fcolor-diagnostics"
+
+ # All warnings are enabled. Maybe enable -Werror as well?
+ compiler_flag: "-Wall"
+ # Enable a few more warnings that aren't part of -Wall.
+ compiler_flag: "-Wthread-safety"
+ compiler_flag: "-Wself-assign"
+
+ # Keep stack frames for debugging, even in opt mode.
+ compiler_flag: "-fno-omit-frame-pointer"
+
+ # Anticipated future default.
+ linker_flag: "-no-canonical-prefixes"
+
+ # Include directory for cuda headers.
+ cxx_builtin_include_directory: "/usr/local/cuda%{cuda_version}/include"
+
+ compilation_mode_flags {
+ mode: DBG
+ # Enable debug symbols.
+ compiler_flag: "-g"
+ }
+ compilation_mode_flags {
+ mode: OPT
+ # No debug symbols.
+ # Maybe we should enable https://gcc.gnu.org/wiki/DebugFission for opt or even generally?
+ # However, that can't happen here, as it requires special handling in Bazel.
+ compiler_flag: "-g0"
+
+ # Conservative choice for -O
+ # -O3 can increase binary size and even slow down the resulting binaries.
+ # Profile first and / or use FDO if you need better performance than this.
+ compiler_flag: "-O2"
+
+ # Disable assertions
+ compiler_flag: "-DNDEBUG"
+
+ # Removal of unused code and data at link time (can this increase binary size in some cases?).
+ compiler_flag: "-ffunction-sections"
+ compiler_flag: "-fdata-sections"
+ }
+}
diff --git a/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
new file mode 100755
index 0000000000..20449a1137
--- /dev/null
+++ b/third_party/gpus/crosstool/clang/bin/crosstool_wrapper_driver_is_not_gcc.tpl
@@ -0,0 +1,315 @@
+#!/usr/bin/env python
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Crosstool wrapper for compiling CUDA programs.
+
+SYNOPSIS:
+ crosstool_wrapper_is_not_gcc [options passed in by cc_library()
+ or cc_binary() rule]
+
+DESCRIPTION:
+ This script is expected to be called by the cc_library() or cc_binary() bazel
+ rules. When the option "-x cuda" is present in the list of arguments passed
+ to this script, it invokes the nvcc CUDA compiler. Most arguments are passed
+ as is as a string to --compiler-options of nvcc. When "-x cuda" is not
+ present, this wrapper invokes hybrid_driver_is_not_gcc with the input
+ arguments as is.
+
+NOTES:
+ Changes to the contents of this file must be propagated from
+ //third_party/gpus/crosstool/crosstool_wrapper_is_not_gcc to
+ //third_party/gpus/crosstool/v*/*/clang/bin/crosstool_wrapper_is_not_gcc
+"""
+
+from __future__ import print_function
+
+__author__ = 'keveman@google.com (Manjunath Kudlur)'
+
+from argparse import ArgumentParser
+import os
+import subprocess
+import re
+import sys
+import pipes
+
+# Template values set by cuda_autoconf.
+CPU_COMPILER = ('%{cpu_compiler}')
+GCC_HOST_COMPILER_PATH = ('%{gcc_host_compiler_path}')
+
+CURRENT_DIR = os.path.dirname(sys.argv[0])
+NVCC_PATH = CURRENT_DIR + '/../../../cuda/bin/nvcc'
+LLVM_HOST_COMPILER_PATH = ('/usr/bin/gcc')
+PREFIX_DIR = os.path.dirname(GCC_HOST_COMPILER_PATH)
+
+def Log(s):
+ print('gpus/crosstool: {0}'.format(s))
+
+
+def GetOptionValue(argv, option):
+ """Extract the list of values for option from the argv list.
+
+ Args:
+ argv: A list of strings, possibly the argv passed to main().
+ option: The option whose value to extract, without the leading '-'.
+
+ Returns:
+ A list of values, either directly following the option,
+ (eg., -opt val1 val2) or values collected from multiple occurrences of
+ the option (eg., -opt val1 -opt val2).
+ """
+
+ parser = ArgumentParser()
+ parser.add_argument('-' + option, nargs='*', action='append')
+ args, _ = parser.parse_known_args(argv)
+ if not args or not vars(args)[option]:
+ return []
+ else:
+ return sum(vars(args)[option], [])
+
+
+def GetHostCompilerOptions(argv):
+ """Collect the -isystem, -iquote, and --sysroot option values from argv.
+
+ Args:
+ argv: A list of strings, possibly the argv passed to main().
+
+ Returns:
+ The string that can be used as the --compiler-options to nvcc.
+ """
+
+ parser = ArgumentParser()
+ parser.add_argument('-isystem', nargs='*', action='append')
+ parser.add_argument('-iquote', nargs='*', action='append')
+ parser.add_argument('--sysroot', nargs=1)
+ parser.add_argument('-g', nargs='*', action='append')
+
+ args, _ = parser.parse_known_args(argv)
+
+ opts = ''
+
+ if args.isystem:
+ opts += ' -isystem ' + ' -isystem '.join(sum(args.isystem, []))
+ if args.iquote:
+ opts += ' -iquote ' + ' -iquote '.join(sum(args.iquote, []))
+ if args.g:
+ opts += ' -g' + ' -g'.join(sum(args.g, []))
+ if args.sysroot:
+ opts += ' --sysroot ' + args.sysroot[0]
+
+ return opts
+
+def GetNvccOptions(argv):
+ """Collect the -nvcc_options values from argv.
+
+ Args:
+ argv: A list of strings, possibly the argv passed to main().
+
+ Returns:
+ The string that can be passed directly to nvcc.
+ """
+
+ parser = ArgumentParser()
+ parser.add_argument('-nvcc_options', nargs='*', action='append')
+
+ args, _ = parser.parse_known_args(argv)
+
+ if args.nvcc_options:
+ return ' '.join(['--'+a for a in sum(args.nvcc_options, [])])
+ return ''
+
+
+def StripAndTransformNvccOptions(argv):
+ """Strips the -nvcc_options values from argv and transforms define-macros.
+
+ Args:
+ argv: A list of strings, possibly the argv passed to main().
+
+ Returns:
+ A list of strings that can be passed directly to gcudacc.
+ """
+ parser = ArgumentParser()
+ parser.add_argument('-nvcc_options', nargs='*', action='store')
+ args, leftover = parser.parse_known_args(argv)
+ if args.nvcc_options:
+ for option in args.nvcc_options:
+ (flag, _, value) = option.partition('=')
+ if 'define-macro' in flag:
+ leftover.append('-D' + value)
+ return leftover
+
+
+def InvokeGcudacc(argv, gcudacc_version, gcudacc_flags, log=False):
+ """Call gcudacc with arguments assembled from argv.
+
+ Args:
+ argv: A list of strings, possibly the argv passed to main().
+ gcudacc_version: The version of gcudacc; this is a subdirectory name under
+ the gcudacc bin/ directory.
+ gcudacc_flags: A list of extra arguments passed just for gcudacc.
+ log: True if logging is requested.
+
+ Returns:
+ The return value of calling os.system('gcudacc ' + args)
+ """
+
+ gcudacc_cmd = os.path.join(GCUDACC_PATH_BASE, gcudacc_version, 'gcudacc.par')
+ gcudacc_cmd = (
+ gcudacc_cmd +
+ ' --google_host_compiler={0} '.format(LLVM_HOST_COMPILER_PATH) +
+ ' '.join(sum(gcudacc_flags, [])) +
+ ' -- ' +
+ ' '.join(StripAndTransformNvccOptions(argv)))
+ if log: Log(gcudacc_cmd)
+ return os.system(gcudacc_cmd)
+
+
+def InvokeNvcc(argv, log=False):
+ """Call nvcc with arguments assembled from argv.
+
+ Args:
+ argv: A list of strings, possibly the argv passed to main().
+ log: True if logging is requested.
+
+ Returns:
+ The return value of calling os.system('nvcc ' + args)
+ """
+
+ host_compiler_options = GetHostCompilerOptions(argv)
+ nvcc_compiler_options = GetNvccOptions(argv)
+ opt_option = GetOptionValue(argv, 'O')
+ m_options = GetOptionValue(argv, 'm')
+ m_options = ''.join([' -m' + m for m in m_options if m in ['32', '64']])
+ include_options = GetOptionValue(argv, 'I')
+ out_file = GetOptionValue(argv, 'o')
+ depfiles = GetOptionValue(argv, 'MF')
+ defines = GetOptionValue(argv, 'D')
+ defines = ''.join([' -D' + define for define in defines])
+ undefines = GetOptionValue(argv, 'U')
+ undefines = ''.join([' -U' + define for define in undefines])
+ std_options = GetOptionValue(argv, 'std')
+ # currently only c++11 is supported by Cuda 7.0 std argument
+ nvcc_allowed_std_options = ["c++11"]
+ std_options = ''.join([' -std=' + define
+ for define in std_options if define in nvcc_allowed_std_options])
+
+ # The list of source files get passed after the -c option. I don't know of
+ # any other reliable way to just get the list of source files to be compiled.
+ src_files = GetOptionValue(argv, 'c')
+
+ if len(src_files) == 0:
+ return 1
+ if len(out_file) != 1:
+ return 1
+
+ opt = (' -O2' if (len(opt_option) > 0 and int(opt_option[0]) > 0)
+ else ' -g -G')
+
+ includes = (' -I ' + ' -I '.join(include_options)
+ if len(include_options) > 0
+ else '')
+
+ # Unfortunately, there are other options that have -c prefix too.
+ # So allowing only those look like C/C++ files.
+ src_files = [f for f in src_files if
+ re.search('\.cpp$|\.cc$|\.c$|\.cxx$|\.C$', f)]
+ srcs = ' '.join(src_files)
+ out = ' -o ' + out_file[0]
+
+ supported_cuda_compute_capabilities = [ %{cuda_compute_capabilities} ]
+ nvccopts = ''
+ for capability in supported_cuda_compute_capabilities:
+ capability = capability.replace('.', '')
+ nvccopts += r'-gencode=arch=compute_%s,\"code=sm_%s,compute_%s\" ' % (
+ capability, capability, capability)
+ nvccopts += ' ' + nvcc_compiler_options
+ nvccopts += undefines
+ nvccopts += defines
+ nvccopts += std_options
+ nvccopts += m_options
+
+ if depfiles:
+ # Generate the dependency file
+ depfile = depfiles[0]
+ cmd = (NVCC_PATH + ' ' + nvccopts +
+ ' --compiler-options "' + host_compiler_options + '"' +
+ ' --compiler-bindir=' + GCC_HOST_COMPILER_PATH +
+ ' -I .' +
+ ' -x cu ' + includes + ' ' + srcs + ' -M -o ' + depfile)
+ if log: Log(cmd)
+ exit_status = os.system(cmd)
+ if exit_status != 0:
+ return exit_status
+
+ cmd = (NVCC_PATH + ' ' + nvccopts +
+ ' --compiler-options "' + host_compiler_options + ' -fPIC"' +
+ ' --compiler-bindir=' + GCC_HOST_COMPILER_PATH +
+ ' -I .' +
+ ' -x cu ' + opt + includes + ' -c ' + srcs + out)
+
+ # TODO(zhengxq): for some reason, 'gcc' needs this help to find 'as'.
+ # Need to investigate and fix.
+ cmd = 'PATH=' + PREFIX_DIR + ' ' + cmd
+ if log: Log(cmd)
+ return os.system(cmd)
+
+
+def main():
+ parser = ArgumentParser()
+ parser.add_argument('-x', nargs=1)
+ parser.add_argument('--cuda_log', action='store_true')
+ parser.add_argument('--use_gcudacc', action='store_true')
+ parser.add_argument('--gcudacc_version', action='store', default='v8')
+ parser.add_argument('--gcudacc_flag', nargs='*', action='append', default=[])
+ args, leftover = parser.parse_known_args(sys.argv[1:])
+
+ if args.x and args.x[0] == 'cuda':
+ if args.cuda_log: Log('-x cuda')
+ leftover = [pipes.quote(s) for s in leftover]
+ if args.use_gcudacc:
+ if args.cuda_log: Log('using gcudacc')
+ return InvokeGcudacc(argv=leftover,
+ gcudacc_version=args.gcudacc_version,
+ gcudacc_flags=args.gcudacc_flag,
+ log=args.cuda_log)
+ if args.cuda_log: Log('using nvcc')
+ return InvokeNvcc(leftover, log=args.cuda_log)
+
+ # Strip our flags before passing through to the CPU compiler for files which
+ # are not -x cuda. We can't just pass 'leftover' because it also strips -x.
+ # We not only want to pass -x to the CPU compiler, but also keep it in its
+ # relative location in the argv list (the compiler is actually sensitive to
+ # this).
+ cpu_compiler_flags = [flag for flag in sys.argv[1:]
+ if not flag.startswith(('--cuda_log',
+ '--use_gcudacc',
+ '--gcudacc_version',
+ '--gcudacc_flag'))]
+ if args.use_gcudacc:
+ # This macro is defined for TUs that are not marked with "-x cuda" but are
+ # built as part of a -config=cuda --use_gcudacc compilation. They are
+ # compiled with the default CPU compiler. Since the objects built from
+ # these TUs are later linked with objects that come from gcudacc, some
+ # parts of the code need to be marked for these special cases. For example,
+ # some types have to be defined similarly for gcudacc-compiled TUs and
+ # default CPU compiler-compiled TUs linked with them, but differently when
+ # nvcc is used.
+ # TODO(eliben): rename to a more descriptive name.
+ cpu_compiler_flags.append('-D__GCUDACC_HOST__')
+
+ return subprocess.call([CPU_COMPILER] + cpu_compiler_flags)
+
+if __name__ == '__main__':
+ sys.exit(main())
diff --git a/third_party/gpus/cuda/BUILD b/third_party/gpus/cuda/BUILD
index 79c6227687..e69de29bb2 100644
--- a/third_party/gpus/cuda/BUILD
+++ b/third_party/gpus/cuda/BUILD
@@ -1,224 +0,0 @@
-licenses(["restricted"]) # MPL2, portions GPL v3, LGPL v3, BSD-like
-
-load("//third_party/gpus/cuda:build_defs.bzl", "if_cuda")
-load("platform", "cuda_library_path")
-load("platform", "cuda_static_library_path")
-load("platform", "cudnn_library_path")
-load("platform", "cupti_library_path")
-load("platform", "readlink_command")
-
-package(default_visibility = ["//visibility:public"])
-
-config_setting(
- name = "using_gcudacc",
- values = {
- "define": "using_cuda_gcudacc=true",
- },
- visibility = ["//visibility:public"],
-)
-
-config_setting(
- name = "using_nvcc",
- values = {
- "define": "using_cuda_nvcc=true",
- },
-)
-
-config_setting(
- name = "using_clang",
- values = {
- "define": "using_cuda_clang=true",
- },
-)
-
-# Equivalent to using_clang && -c opt.
-config_setting(
- name = "using_clang_opt",
- values = {
- "define": "using_cuda_clang=true",
- "compilation_mode": "opt",
- },
-)
-
-config_setting(
- name = "darwin",
- values = {"cpu": "darwin"},
- visibility = ["//visibility:public"],
-)
-
-cc_library(
- name = "cuda_headers",
- hdrs = glob([
- "**/*.h",
- ]),
- includes = [
- ".",
- "include",
- ],
- visibility = ["//visibility:public"],
-)
-
-cc_library(
- name = "cudart_static",
- srcs = [
- cuda_static_library_path("cudart"),
- ],
- includes = ["include/"],
- linkopts = [
- "-ldl",
- "-lpthread",
- ] + select({
- "//tensorflow:darwin": [],
- "//conditions:default": ["-lrt"],
- }),
- visibility = ["//visibility:public"],
-)
-
-cc_library(
- name = "cudart",
- srcs = [
- cuda_library_path("cudart"),
- ],
- data = [
- cuda_library_path("cudart"),
- ],
- includes = ["include/"],
- linkstatic = 1,
- visibility = ["//visibility:public"],
-)
-
-cc_library(
- name = "cublas",
- srcs = [
- cuda_library_path("cublas"),
- ],
- data = [
- cuda_library_path("cublas"),
- ],
- includes = ["include/"],
- linkstatic = 1,
- visibility = ["//visibility:public"],
-)
-
-cc_library(
- name = "cudnn",
- srcs = [
- cudnn_library_path(),
- ],
- data = [
- cudnn_library_path(),
- ],
- includes = ["include/"],
- linkstatic = 1,
- visibility = ["//visibility:public"],
-)
-
-cc_library(
- name = "cufft",
- srcs = [
- cuda_library_path("cufft"),
- ],
- data = [
- cuda_library_path("cufft"),
- ],
- includes = ["include/"],
- linkstatic = 1,
- visibility = ["//visibility:public"],
-)
-
-cc_library(
- name = "cuda",
- visibility = ["//visibility:public"],
- deps = [
- ":cublas",
- ":cuda_headers",
- ":cudart",
- ":cudnn",
- ":cufft",
- ],
-)
-
-cc_library(
- name = "cupti_headers",
- hdrs = glob([
- "**/*.h",
- ]),
- includes = [
- ".",
- "extras/CUPTI/include/",
- ],
- visibility = ["//visibility:public"],
-)
-
-cc_library(
- name = "cupti_dsos",
- data = [
- cupti_library_path(),
- ],
- visibility = ["//visibility:public"],
-)
-
-# TODO(opensource): for now, we have to invoke the cuda_config.sh manually in the source tree.
-# This rule checks if Cuda libraries in the source tree has been properly configured.
-# The output list makes bazel runs this rule first if the Cuda files are missing.
-# This gives us an opportunity to check and print a meaningful error message.
-# But we will need to create the output file list to make bazel happy in a successful run.
-genrule(
- name = "cuda_check",
- srcs = [
- "cuda.config",
- "cuda_config.sh",
- ],
- outs = [
- "include/cuda.h",
- "include/cublas.h",
- "include/cudnn.h",
- "extras/CUPTI/include/cupti.h",
- cuda_static_library_path("cudart"),
- cuda_library_path("cublas"),
- cudnn_library_path(),
- cuda_library_path("cudart"),
- cuda_library_path("cufft"),
- cupti_library_path(),
- ],
- cmd = if_cuda(
- # Under cuda config, create all the symbolic links to the actual cuda files
- "OUTPUTDIR=`{} -f $(@D)/../../..`; cd `dirname $(location :cuda_config.sh)`; OUTPUTDIR=$$OUTPUTDIR ./cuda_config.sh --check;".format(readlink_command()),
-
- # Under non-cuda config, create all dummy files to make the build go through
- ";".join([
- "mkdir -p $(@D)/include",
- "mkdir -p $(@D)/lib64",
- "mkdir -p $(@D)/extras/CUPTI/include",
- "mkdir -p $(@D)/extras/CUPTI/lib64",
- "touch $(@D)/include/cuda.h",
- "touch $(@D)/include/cublas.h",
- "touch $(@D)/include/cudnn.h",
- "touch $(@D)/extras/CUPTI/include/cupti.h",
- "touch $(@D)/{}".format(cuda_static_library_path("cudart")),
- "touch $(@D)/{}".format(cuda_library_path("cublas")),
- "touch $(@D)/{}".format(cudnn_library_path()),
- "touch $(@D)/{}".format(cuda_library_path("cudart")),
- "touch $(@D)/{}".format(cuda_library_path("cufft")),
- "touch $(@D)/{}".format(cupti_library_path()),
- ]),
- ),
- local = 1,
-)
-
-genrule(
- name = "cuda_config_check",
- outs = [
- "cuda.config",
- ],
- cmd = if_cuda(
- # Under cuda config, create the symbolic link to the actual cuda.config
- "configfile=$(location :cuda.config); ln -sf `{} -f $${{configfile#*/*/*/}}` $(@D)/;".format(readlink_command()),
-
- # Under non-cuda config, create the dummy file
- ";".join([
- "touch $(@D)/cuda.config",
- ]),
- ),
- local = 1,
-)
diff --git a/third_party/gpus/cuda/BUILD.tpl b/third_party/gpus/cuda/BUILD.tpl
new file mode 100644
index 0000000000..ab98c2d528
--- /dev/null
+++ b/third_party/gpus/cuda/BUILD.tpl
@@ -0,0 +1,172 @@
+licenses(["restricted"]) # MPL2, portions GPL v3, LGPL v3, BSD-like
+
+load("@local_config_cuda//cuda:platform.bzl", "cuda_library_path")
+load("@local_config_cuda//cuda:platform.bzl", "cuda_static_library_path")
+load("@local_config_cuda//cuda:platform.bzl", "cudnn_library_path")
+load("@local_config_cuda//cuda:platform.bzl", "cupti_library_path")
+load("@local_config_cuda//cuda:platform.bzl", "readlink_command")
+
+package(default_visibility = ["//visibility:public"])
+
+config_setting(
+ name = "using_gcudacc",
+ values = {
+ "define": "using_cuda_gcudacc=true",
+ },
+ visibility = ["//visibility:public"],
+)
+
+config_setting(
+ name = "using_nvcc",
+ values = {
+ "define": "using_cuda_nvcc=true",
+ },
+)
+
+config_setting(
+ name = "using_clang",
+ values = {
+ "define": "using_cuda_clang=true",
+ },
+)
+
+# Equivalent to using_clang && -c opt.
+config_setting(
+ name = "using_clang_opt",
+ values = {
+ "define": "using_cuda_clang=true",
+ "compilation_mode": "opt",
+ },
+)
+
+config_setting(
+ name = "darwin",
+ values = {"cpu": "darwin"},
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "cuda_headers",
+ hdrs = glob([
+ "**/*.h",
+ ]),
+ includes = [
+ ".",
+ "include",
+ ],
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "cudart_static",
+ srcs = [
+ cuda_static_library_path("cudart"),
+ ],
+ includes = ["include/"],
+ linkopts = [
+ "-ldl",
+ "-lpthread",
+ ] + select({
+ "@//tensorflow:darwin": [],
+ "//conditions:default": ["-lrt"],
+ }),
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "cudart",
+ srcs = [
+ cuda_library_path("cudart"),
+ ],
+ data = [
+ cuda_library_path("cudart"),
+ ],
+ includes = ["include/"],
+ linkstatic = 1,
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "cublas",
+ srcs = [
+ cuda_library_path("cublas"),
+ ],
+ data = [
+ cuda_library_path("cublas"),
+ ],
+ includes = ["include/"],
+ linkstatic = 1,
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "cudnn",
+ srcs = [
+ cudnn_library_path(),
+ ],
+ data = [
+ cudnn_library_path(),
+ ],
+ includes = ["include/"],
+ linkstatic = 1,
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "cufft",
+ srcs = [
+ cuda_library_path("cufft"),
+ ],
+ data = [
+ cuda_library_path("cufft"),
+ ],
+ includes = ["include/"],
+ linkstatic = 1,
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "curand",
+ srcs = [
+ cuda_library_path("curand"),
+ ],
+ data = [
+ cuda_library_path("curand"),
+ ],
+ includes = ["include/"],
+ linkstatic = 1,
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "cuda",
+ deps = [
+ ":cuda_headers",
+ ":cudart",
+ ":cublas",
+ ":cudnn",
+ ":cufft",
+ ":curand",
+ ],
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "cupti_headers",
+ hdrs = glob([
+ "**/*.h",
+ ]),
+ includes = [
+ ".",
+ "extras/CUPTI/include/",
+ ],
+ visibility = ["//visibility:public"],
+)
+
+cc_library(
+ name = "cupti_dsos",
+ data = [
+ cupti_library_path(),
+ ],
+ visibility = ["//visibility:public"],
+) \ No newline at end of file
diff --git a/third_party/gpus/cuda/build_defs.bzl.tpl b/third_party/gpus/cuda/build_defs.bzl.tpl
new file mode 100644
index 0000000000..8f7dc00630
--- /dev/null
+++ b/third_party/gpus/cuda/build_defs.bzl.tpl
@@ -0,0 +1,14 @@
+# Macros for building CUDA code.
+
+def if_cuda(if_true, if_false = []):
+ """Shorthand for select()'ing on whether we're building with CUDA.
+
+ Returns a select statement which evaluates to if_true if we're building
+ with CUDA enabled. Otherwise, the select statement evaluates to if_false.
+
+ """
+ return select({
+ "@local_config_cuda//cuda:using_nvcc": if_true,
+ "@local_config_cuda//cuda:using_gcudacc": if_true,
+ "//conditions:default": if_false
+ })
diff --git a/third_party/gpus/cuda/cuda_config.h.tpl b/third_party/gpus/cuda/cuda_config.h.tpl
new file mode 100644
index 0000000000..ea51fbb26f
--- /dev/null
+++ b/third_party/gpus/cuda/cuda_config.h.tpl
@@ -0,0 +1,24 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef CUDA_CUDA_CONFIG_H_
+#define CUDA_CUDA_CONFIG_H_
+
+#define TF_CUDA_CAPABILITIES %{cuda_compute_capabilities}
+
+#define TF_CUDA_VERSION "%{cuda_version}"
+#define TF_CUDNN_VERSION "%{cudnn_version}"
+
+#endif // CUDA_CUDA_CONFIG_H_
diff --git a/third_party/gpus/cuda/platform.bzl.tpl b/third_party/gpus/cuda/platform.bzl.tpl
new file mode 100644
index 0000000000..7565dfc129
--- /dev/null
+++ b/third_party/gpus/cuda/platform.bzl.tpl
@@ -0,0 +1,57 @@
+CUDA_VERSION = "%{cuda_version}"
+CUDNN_VERSION = "%{cudnn_version}"
+PLATFORM = "%{platform}"
+
+def cuda_sdk_version():
+ return CUDA_VERSION
+
+def cudnn_sdk_version():
+ return CUDNN_VERSION
+
+def cuda_library_path(name, version = cuda_sdk_version()):
+ if PLATFORM == "Darwin":
+ if not version:
+ return "lib/lib{}.dylib".format(name)
+ else:
+ return "lib/lib{}.{}.dylib".format(name, version)
+ else:
+ if not version:
+ return "lib64/lib{}.so".format(name)
+ else:
+ return "lib64/lib{}.so.{}".format(name, version)
+
+def cuda_static_library_path(name):
+ if PLATFORM == "Darwin":
+ return "lib/lib{}_static.a".format(name)
+ else:
+ return "lib64/lib{}_static.a".format(name)
+
+def cudnn_library_path(version = cudnn_sdk_version()):
+ if PLATFORM == "Darwin":
+ if not version:
+ return "lib/libcudnn.dylib"
+ else:
+ return "lib/libcudnn.{}.dylib".format(version)
+ else:
+ if not version:
+ return "lib64/libcudnn.so"
+ else:
+ return "lib64/libcudnn.so.{}".format(version)
+
+def cupti_library_path(version = cuda_sdk_version()):
+ if PLATFORM == "Darwin":
+ if not version:
+ return "extras/CUPTI/lib/libcupti.dylib"
+ else:
+ return "extras/CUPTI/lib/libcupti.{}.dylib".format(version)
+ else:
+ if not version:
+ return "extras/CUPTI/lib64/libcupti.so"
+ else:
+ return "extras/CUPTI/lib64/libcupti.so.{}".format(version)
+
+def readlink_command():
+ if PLATFORM == "Darwin":
+ return "greadlink"
+ else:
+ return "readlink"
diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl
new file mode 100644
index 0000000000..3682cb305d
--- /dev/null
+++ b/third_party/gpus/cuda_configure.bzl
@@ -0,0 +1,423 @@
+# -*- Python -*-
+"""Repository rule for CUDA autoconfiguration.
+
+`cuda_configure` depends on the following environment variables:
+
+ * `ENABLE_CUDA`: Whether to enable building with CUDA.
+ * `CC`: The GCC host compiler path
+ * `CUDA_TOOLKIT_PATH`: The path to the CUDA toolkit. Default is
+ `/usr/local/cuda`.
+ * `CUDA_VERSION`: The version of the CUDA toolkit. If this is blank, then
+ use the system default.
+ * `CUDNN_VERSION`: The version of the cuDNN library.
+ * `CUDNN_INSTALL_PATH`: The path to the cuDNN library. Default is
+ `/usr/local/cuda`.
+ * `CUDA_COMPUTE_CAPABILITIES`: The CUDA compute capabilities. Default is
+ `3.5,5.2`.
+"""
+
+
+_DEFAULT_CUDA_VERSION = ""
+_DEFAULT_CUDNN_VERSION = ""
+_DEFAULT_CUDA_TOOLKIT_PATH = "/usr/local/cuda"
+_DEFAULT_CUDNN_INSTALL_PATH = "/usr/local/cuda"
+_DEFAULT_CUDA_COMPUTE_CAPABILITIES = ["3.5", "5.2"]
+
+
+# TODO(dzc): Once these functions have been factored out of Bazel's
+# cc_configure.bzl, load them from @bazel_tools instead.
+# BEGIN cc_configure common functions.
+def find_cc(repository_ctx):
+ """Find the C++ compiler."""
+ cc_name = "gcc"
+ if "CC" in repository_ctx.os.environ:
+ cc_name = repository_ctx.os.environ["CC"].strip()
+ if not cc_name:
+ cc_name = "gcc"
+ if cc_name.startswith("/"):
+ # Absolute path, maybe we should make this suported by our which function.
+ return cc_name
+ cc = repository_ctx.which(cc_name)
+ if cc == None:
+ fail(
+ "Cannot find gcc, either correct your path or set the CC" +
+ " environment variable")
+ return cc
+
+
+_INC_DIR_MARKER_BEGIN = "#include <...>"
+
+
+# OSX add " (framework directory)" at the end of line, strip it.
+_OSX_FRAMEWORK_SUFFIX = " (framework directory)"
+_OSX_FRAMEWORK_SUFFIX_LEN = len(_OSX_FRAMEWORK_SUFFIX)
+def _cxx_inc_convert(path):
+ """Convert path returned by cc -E xc++ in a complete path."""
+ path = path.strip()
+ if path.endswith(_OSX_FRAMEWORK_SUFFIX):
+ path = path[:-_OSX_FRAMEWORK_SUFFIX_LEN].strip()
+ return path
+
+
+def get_cxx_inc_directories(repository_ctx, cc):
+ """Compute the list of default C++ include directories."""
+ result = repository_ctx.execute([cc, "-E", "-xc++", "-", "-v"])
+ index1 = result.stderr.find(_INC_DIR_MARKER_BEGIN)
+ if index1 == -1:
+ return []
+ index1 = result.stderr.find("\n", index1)
+ if index1 == -1:
+ return []
+ index2 = result.stderr.rfind("\n ")
+ if index2 == -1 or index2 < index1:
+ return []
+ index2 = result.stderr.find("\n", index2 + 1)
+ if index2 == -1:
+ inc_dirs = result.stderr[index1 + 1:]
+ else:
+ inc_dirs = result.stderr[index1 + 1:index2].strip()
+
+ return [repository_ctx.path(_cxx_inc_convert(p))
+ for p in inc_dirs.split("\n")]
+
+# END cc_configure common functions (see TODO above).
+
+
+def _enable_cuda(repository_ctx):
+ if "TF_NEED_CUDA" in repository_ctx.os.environ:
+ enable_cuda = repository_ctx.os.environ["TF_NEED_CUDA"].strip()
+ return enable_cuda == "1"
+ return False
+
+
+def _cuda_toolkit_path(repository_ctx):
+ """Finds the cuda toolkit directory."""
+ cuda_toolkit_path = _DEFAULT_CUDA_TOOLKIT_PATH
+ if "CUDA_TOOLKIT_PATH" in repository_ctx.os.environ:
+ cuda_toolkit_path = repository_ctx.os.environ["CUDA_TOOLKIT_PATH"].strip()
+ if not repository_ctx.path(cuda_toolkit_path).exists:
+ fail("Cannot find cuda toolkit path.")
+ return cuda_toolkit_path
+
+
+def _cudnn_install_basedir(repository_ctx):
+ """Finds the cudnn install directory."""
+ cudnn_install_path = _DEFAULT_CUDNN_INSTALL_PATH
+ if "CUDNN_INSTALL_PATH" in repository_ctx.os.environ:
+ cudnn_install_path = repository_ctx.os.environ["CUDNN_INSTALL_PATH"].strip()
+ if not repository_ctx.path(cudnn_install_path).exists:
+ fail("Cannot find cudnn install path.")
+ return cudnn_install_path
+
+
+def _cuda_version(repository_ctx):
+ """Detects the cuda version."""
+ if "CUDA_VERSION" in repository_ctx.os.environ:
+ return repository_ctx.os.environ["CUDA_VERSION"].strip()
+ else:
+ return ""
+
+
+def _cudnn_version(repository_ctx):
+ """Detects the cudnn version."""
+ if "CUDNN_VERSION" in repository_ctx.os.environ:
+ return repository_ctx.os.environ["CUDNN_VERSION"].strip()
+ else:
+ return ""
+
+
+def _compute_capabilities(repository_ctx):
+ """Returns a list of strings representing cuda compute capabilities."""
+ if "CUDA_COMPUTE_CAPABILITIES" not in repository_ctx.os.environ:
+ return _DEFAULT_CUDA_COMPUTE_CAPABILITIES
+ capabilities_str = repository_ctx.os.environ["CUDA_COMPUTE_CAPABILITIES"]
+ capabilities = capabilities_str.split(",")
+ for capability in capabilities:
+ # Workaround for Skylark's lack of support for regex. This check should
+ # be equivalent to checking:
+ # if re.match("[0-9]+.[0-9]+", capability) == None:
+ parts = capability.split(".")
+ if len(parts) != 2 or not parts[0].isdigit() or not parts[1].isdigit():
+ fail("Invalid compute capability: %s" % capability)
+ return capabilities
+
+
+def _cpu_value(repository_ctx):
+ result = repository_ctx.execute(["uname", "-s"])
+ return result.stdout.strip()
+
+
+def _cuda_symlink_files(cpu_value, cuda_version, cudnn_version):
+ """Returns a struct containing platform-specific paths.
+
+ Args:
+ cpu_value: The string representing the host OS.
+ cuda_version: The cuda version as returned by _cuda_version
+ cudnn_version: The cudnn version as returned by _cudnn_version
+ """
+ cuda_ext = ".%s" % cuda_version if cuda_version else ""
+ cudnn_ext = ".%s" % cudnn_version if cudnn_version else ""
+ if cpu_value == "Linux":
+ return struct(
+ cuda_lib_path = "lib64",
+ cuda_rt_lib = "lib64/libcudart.so%s" % cuda_ext,
+ cuda_rt_lib_static = "lib64/libcudart_static.a",
+ cuda_blas_lib = "lib64/libcublas.so%s" % cuda_ext,
+ cuda_dnn_lib = "lib64/libcudnn.so%s" % cudnn_ext,
+ cuda_dnn_lib_alt = "libcudnn.so%s" % cudnn_ext,
+ cuda_rand_lib = "lib64/libcurand.so%s" % cuda_ext,
+ cuda_fft_lib = "lib64/libcufft.so%s" % cuda_ext,
+ cuda_cupti_lib = "extras/CUPTI/lib64/libcupti.so%s" % cuda_ext)
+ elif cpu_value == "Darwin":
+ return struct(
+ cuda_lib_path = "lib",
+ cuda_rt_lib = "lib/libcudart%s.dylib" % cuda_ext,
+ cuda_rt_lib_static = "lib/libcudart_static.a",
+ cuda_blas_lib = "lib/libcublas%s.dylib" % cuda_ext,
+ cuda_dnn_lib = "lib/libcudnn%s.dylib" % cudnn_ext,
+ cuda_dnn_lib_alt = "libcudnn%s.dylib" % cudnn_ext,
+ cuda_rand_lib = "lib/libcurand%s.dylib" % cuda_ext,
+ cuda_fft_lib = "lib/libcufft%s.dylib" % cuda_ext,
+ cuda_cupti_lib = "extras/CUPTI/lib/libcupti%s.dylib" % cuda_ext)
+ else:
+ fail("Not supported CPU value %s" % cpu_value)
+
+
+def _check_lib(repository_ctx, cuda_toolkit_path, cuda_lib):
+ """Checks if cuda_lib exists under cuda_toolkit_path or fail if it doesn't.
+
+ Args:
+ repository_ctx: The repository context.
+ cuda_toolkit_path: The cuda toolkit directory containing the cuda libraries.
+ cuda_lib: The library to look for under cuda_toolkit_path.
+ """
+ lib_path = cuda_toolkit_path + "/" + cuda_lib
+ if not repository_ctx.path(lib_path).exists:
+ fail("Cannot find %s" % lib_path)
+
+
+def _check_dir(repository_ctx, directory):
+ """Checks whether the directory exists and fail if it does not.
+
+ Args:
+ repository_ctx: The repository context.
+ directory: The directory to check the existence of.
+ """
+ if not repository_ctx.path(directory).exists:
+ fail("Cannot find dir: %s" % directory)
+
+
+def _find_cudnn_header_dir(repository_ctx, cudnn_install_basedir):
+ """Returns the path to the directory containing cudnn.h
+
+ Args:
+ repository_ctx: The repository context.
+ cudnn_install_basedir: The cudnn install directory as returned by
+ _cudnn_install_basedir.
+
+ Returns:
+ The path of the directory containing the cudnn header.
+ """
+ if repository_ctx.path(cudnn_install_basedir + "/cudnn.h").exists:
+ return cudnn_install_basedir
+ if repository_ctx.path(cudnn_install_basedir + "/include/cudnn.h").exists:
+ return cudnn_install_basedir + "/include"
+ if repository_ctx.path("/usr/include/cudnn.h").exists:
+ return "/usr/include"
+ fail("Cannot find cudnn.h under %s" % cudnn_install_basedir)
+
+
+def _find_cudnn_lib_path(repository_ctx, cudnn_install_basedir, symlink_files):
+ """Returns the path to the directory containing libcudnn
+
+ Args:
+ repository_ctx: The repository context.
+ cudnn_install_basedir: The cudnn install dir as returned by
+ _cudnn_install_basedir.
+ symlink_files: The symlink files as returned by _cuda_symlink_files.
+
+ Returns:
+ The path of the directory containing the cudnn libraries.
+ """
+ lib_dir = cudnn_install_basedir + "/" + symlink_files.cuda_dnn_lib
+ if repository_ctx.path(lib_dir).exists:
+ return lib_dir
+ alt_lib_dir = cudnn_install_basedir + "/" + symlink_files.cuda_dnn_lib_alt
+ if repository_ctx.path(alt_lib_dir).exists:
+ return alt_lib_dir
+
+ fail("Cannot find %s or %s under %s" %
+ (symlink_files.cuda_dnn_lib, symlink_files.cuda_dnn_lib_alt,
+ cudnn_install_basedir))
+
+
+def _tpl(repository_ctx, tpl, substitutions={}, out=None):
+ if not out:
+ out = tpl.replace(":", "/")
+ repository_ctx.template(
+ out,
+ Label("//third_party/gpus/%s.tpl" % tpl),
+ substitutions)
+
+
+def _file(repository_ctx, label):
+ repository_ctx.template(
+ label.replace(":", "/"),
+ Label("//third_party/gpus/%s.tpl" % label),
+ {})
+
+
+def _create_dummy_repository(repository_ctx):
+ cpu_value = _cpu_value(repository_ctx)
+ symlink_files = _cuda_symlink_files(cpu_value, _DEFAULT_CUDA_VERSION,
+ _DEFAULT_CUDNN_VERSION)
+
+ # Set up BUILD file for cuda/.
+ _file(repository_ctx, "cuda:BUILD")
+ _file(repository_ctx, "cuda:build_defs.bzl")
+ _tpl(repository_ctx, "cuda:platform.bzl",
+ {
+ "%{cuda_version}": _DEFAULT_CUDA_VERSION,
+ "%{cudnn_version}": _DEFAULT_CUDNN_VERSION,
+ "%{platform}": cpu_value,
+ })
+
+ # Create dummy files for the CUDA toolkit since they are still required by
+ # tensorflow/core/platform/default/build_config:cuda.
+ repository_ctx.file("cuda/include/cuda.h", "")
+ repository_ctx.file("cuda/include/cublas.h", "")
+ repository_ctx.file("cuda/include/cudnn.h", "")
+ repository_ctx.file("cuda/extras/CUPTI/include/cupti.h", "")
+ repository_ctx.file("cuda/%s" % symlink_files.cuda_rt_lib, "")
+ repository_ctx.file("cuda/%s" % symlink_files.cuda_rt_lib_static, "")
+ repository_ctx.file("cuda/%s" % symlink_files.cuda_blas_lib, "")
+ repository_ctx.file("cuda/%s" % symlink_files.cuda_dnn_lib, "")
+ repository_ctx.file("cuda/%s" % symlink_files.cuda_rand_lib, "")
+ repository_ctx.file("cuda/%s" % symlink_files.cuda_fft_lib, "")
+ repository_ctx.file("cuda/%s" % symlink_files.cuda_cupti_lib, "")
+
+ # Set up cuda_config.h, which is used by
+ # tensorflow/stream_executor/dso_loader.cc.
+ _tpl(repository_ctx, "cuda:cuda_config.h",
+ {
+ "%{cuda_version}": _DEFAULT_CUDA_VERSION,
+ "%{cudnn_version}": _DEFAULT_CUDNN_VERSION,
+ "%{cuda_compute_capabilities}": ",".join([
+ "CudaVersion(\"%s\")" % c
+ for c in _DEFAULT_CUDA_COMPUTE_CAPABILITIES]),
+ })
+
+
+def _symlink_dir(repository_ctx, src_dir, dest_dir):
+ """Symlinks all the files in a directory.
+
+ Args:
+ repository_ctx: The repository context.
+ src_dir: The source directory.
+ dest_dir: The destination directory to create the symlinks in.
+ """
+ files = repository_ctx.path(src_dir).readdir()
+ for src_file in files:
+ repository_ctx.symlink(src_file, dest_dir + "/" + src_file.basename)
+
+
+def _create_cuda_repository(repository_ctx):
+ """Creates the repository containing files set up to build with CUDA."""
+ cuda_toolkit_path = _cuda_toolkit_path(repository_ctx)
+ cuda_version = _cuda_version(repository_ctx)
+ cudnn_install_basedir = _cudnn_install_basedir(repository_ctx)
+ cudnn_version = _cudnn_version(repository_ctx)
+ compute_capabilities = _compute_capabilities(repository_ctx)
+
+ cpu_value = _cpu_value(repository_ctx)
+ symlink_files = _cuda_symlink_files(cpu_value, cuda_version, cudnn_version)
+ _check_lib(repository_ctx, cuda_toolkit_path, symlink_files.cuda_rt_lib)
+ _check_lib(repository_ctx, cuda_toolkit_path, symlink_files.cuda_cupti_lib)
+ _check_dir(repository_ctx, cudnn_install_basedir)
+
+ cudnn_header_dir = _find_cudnn_header_dir(repository_ctx,
+ cudnn_install_basedir)
+ cudnn_lib_path = _find_cudnn_lib_path(repository_ctx, cudnn_install_basedir,
+ symlink_files)
+
+ # Set up symbolic links for the cuda toolkit. We link at the individual file
+ # level not at the directory level. This is because the external library may
+ # have a different file layout from our desired structure.
+ _symlink_dir(repository_ctx, cuda_toolkit_path + "/include", "cuda/include")
+ _symlink_dir(repository_ctx,
+ cuda_toolkit_path + "/" + symlink_files.cuda_lib_path,
+ "cuda/" + symlink_files.cuda_lib_path)
+ _symlink_dir(repository_ctx, cuda_toolkit_path + "/bin", "cuda/bin")
+ _symlink_dir(repository_ctx, cuda_toolkit_path + "/nvvm", "cuda/nvvm")
+ _symlink_dir(repository_ctx, cuda_toolkit_path + "/extras/CUPTI/include",
+ "cuda/extras/CUPTI/include")
+ repository_ctx.symlink(cuda_toolkit_path + "/" + symlink_files.cuda_cupti_lib,
+ "cuda/" + symlink_files.cuda_cupti_lib)
+
+ # Set up the symbolic links for cudnn if cudnn was was not installed to
+ # CUDA_TOOLKIT_PATH.
+ if not repository_ctx.path("cuda/include/cudnn.h").exists:
+ repository_ctx.symlink(cudnn_header_dir + "/cudnn.h",
+ "cuda/include/cudnn.h")
+ if not repository_ctx.path("cuda/" + symlink_files.cuda_dnn_lib).exists:
+ repository_ctx.symlink(cudnn_lib_path, "cuda/" + symlink_files.cuda_dnn_lib)
+
+ # Set up BUILD file for cuda/
+ _file(repository_ctx, "cuda:BUILD")
+ _file(repository_ctx, "cuda:build_defs.bzl")
+ _tpl(repository_ctx, "cuda:platform.bzl",
+ {
+ "%{cuda_version}": cuda_version,
+ "%{cudnn_version}": cudnn_version,
+ "%{platform}": cpu_value,
+ })
+
+ # Set up crosstool/
+ _file(repository_ctx, "crosstool:BUILD")
+ _tpl(repository_ctx, "crosstool:CROSSTOOL",
+ {
+ "%{cuda_version}": ("-%s" % cuda_version) if cuda_version else "",
+ })
+ _tpl(repository_ctx,
+ "crosstool:clang/bin/crosstool_wrapper_driver_is_not_gcc",
+ {
+ "%{cpu_compiler}": str(find_cc(repository_ctx)),
+ "%{gcc_host_compiler_path}": str(find_cc(repository_ctx)),
+ "%{cuda_compute_capabilities}": ", ".join(
+ ["\"%s\"" % c for c in compute_capabilities]),
+ })
+
+ # Set up cuda_config.h, which is used by
+ # tensorflow/stream_executor/dso_loader.cc.
+ _tpl(repository_ctx, "cuda:cuda_config.h",
+ {
+ "%{cuda_version}": cuda_version,
+ "%{cudnn_version}": cudnn_version,
+ "%{cuda_compute_capabilities}": ",".join(
+ ["CudaVersion(\"%s\")" % c for c in compute_capabilities]),
+ })
+
+
+def _cuda_autoconf_impl(repository_ctx):
+ """Implementation of the cuda_autoconf repository rule."""
+ if not _enable_cuda(repository_ctx):
+ _create_dummy_repository(repository_ctx)
+ else:
+ _create_cuda_repository(repository_ctx)
+
+
+cuda_configure = repository_rule(
+ implementation = _cuda_autoconf_impl,
+ local = True,
+)
+"""Detects and configures the local CUDA toolchain.
+
+Add the following to your WORKSPACE FILE:
+
+```python
+cuda_configure(name = "local_config_cuda")
+```
+
+Args:
+ name: A unique name for this workspace rule.
+"""
diff --git a/tools/bazel.rc.template b/tools/bazel.rc.template
index 02856822c9..9a69cac1f6 100644
--- a/tools/bazel.rc.template
+++ b/tools/bazel.rc.template
@@ -1,4 +1,4 @@
-build:cuda --crosstool_top=//third_party/gpus/crosstool
+build:cuda --crosstool_top=@local_config_cuda//crosstool
build:cuda --define=using_cuda=true --define=using_cuda_nvcc=true
build --force_python=py$PYTHON_MAJOR_VERSION