aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/tensorflow.bzl
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/tensorflow.bzl')
-rw-r--r--tensorflow/tensorflow.bzl16
1 files changed, 12 insertions, 4 deletions
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index 383c97344a..7fe9c98726 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -11,6 +11,10 @@ load(
"if_static",
)
load(
+ "@local_config_tensorrt//:build_defs.bzl",
+ "if_tensorrt",
+)
+load(
"@local_config_cuda//cuda:build_defs.bzl",
"if_cuda",
"cuda_default_copts",
@@ -197,6 +201,7 @@ def tf_copts(android_optimization_level_override="-O2", is_external=False):
"-fno-exceptions",
"-ftemplate-depth=900"])
+ if_cuda(["-DGOOGLE_CUDA=1"])
+ + if_tensorrt(["-DGOOGLE_TENSORRT=1"])
+ if_mkl(["-DINTEL_MKL=1", "-DEIGEN_USE_VML", "-fopenmp",])
+ if_android_arm(["-mfpu=neon"])
+ if_linux_x86_64(["-msse3"])
@@ -861,9 +866,11 @@ def tf_cuda_library(deps=None, cuda_deps=None, copts=tf_copts(), **kwargs):
When the library is built with --config=cuda:
- - both deps and cuda_deps are used as dependencies
- - the cuda runtime is added as a dependency (if necessary)
- - The library additionally passes -DGOOGLE_CUDA=1 to the list of copts
+ - Both deps and cuda_deps are used as dependencies.
+ - The cuda runtime is added as a dependency (if necessary).
+ - The library additionally passes -DGOOGLE_CUDA=1 to the list of copts.
+ - In addition, when the library is also built with TensorRT enabled, it
+ additionally passes -DGOOGLE_TENSORRT=1 to the list of copts.
Args:
- cuda_deps: BUILD dependencies which will be linked if and only if:
@@ -882,7 +889,8 @@ def tf_cuda_library(deps=None, cuda_deps=None, copts=tf_copts(), **kwargs):
clean_dep("//tensorflow/core:cuda"),
"@local_config_cuda//cuda:cuda_headers"
]),
- copts=copts + if_cuda(["-DGOOGLE_CUDA=1"]) + if_mkl(["-DINTEL_MKL=1"]),
+ copts=(copts + if_cuda(["-DGOOGLE_CUDA=1"]) + if_mkl(["-DINTEL_MKL=1"]) +
+ if_tensorrt(["-DGOOGLE_TENSORRT=1"])),
**kwargs)
register_extension_info(