aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--configure.py5
-rw-r--r--tensorflow/contrib/tensorrt/BUILD1
-rw-r--r--tensorflow/contrib/tensorrt/log/trt_logger.cc7
-rw-r--r--third_party/gpus/cuda_configure.bzl23
-rw-r--r--third_party/tensorrt/BUILD.tpl31
-rw-r--r--third_party/tensorrt/build_defs.bzl.tpl1
-rw-r--r--third_party/tensorrt/tensorrt_configure.bzl4
7 files changed, 25 insertions, 47 deletions
diff --git a/configure.py b/configure.py
index 1567ed697f..f136577005 100644
--- a/configure.py
+++ b/configure.py
@@ -1422,6 +1422,11 @@ def main():
if not is_windows():
set_gcc_host_compiler_path(environ_cp)
set_other_cuda_vars(environ_cp)
+ # superceeded by call above
+ # enable tensorrt if desired. Disabled on non-linux
+ # set_action_env_var(environ_cp, 'TF_NEED_TENSORRT', 'TensorRT', False)
+ # if environ_cp.get('TF_NEED_TENSORRT') == '1':
+ # set_tf_trt_version(environ_cp)
set_build_var(environ_cp, 'TF_NEED_MPI', 'MPI', 'with_mpi_support', False)
if environ_cp.get('TF_NEED_MPI') == '1':
diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD
index b179e815c8..57c106c812 100644
--- a/tensorflow/contrib/tensorrt/BUILD
+++ b/tensorflow/contrib/tensorrt/BUILD
@@ -77,6 +77,7 @@ cc_library(
"@nsync//:nsync_headers",
"@protobuf_archive//:protobuf",
],
+ visibility = ["//visibility:public"], #for c/c++ linking
)
tf_kernel_library(
diff --git a/tensorflow/contrib/tensorrt/log/trt_logger.cc b/tensorflow/contrib/tensorrt/log/trt_logger.cc
index 3910ed1201..d8a0310828 100644
--- a/tensorflow/contrib/tensorrt/log/trt_logger.cc
+++ b/tensorflow/contrib/tensorrt/log/trt_logger.cc
@@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#if GOOGLE_CUDA
-#if GOOGLE_TENSORRT
#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
// Use TF logging for TensorRT informations
@@ -30,7 +28,7 @@ void Logger::log(Severity severity, const char* msg) {
// suppress info-level messages
switch (severity) {
case Severity::kINFO: { // mark TRT info messages as debug!
- LOG(DEBUG) << msg;
+ VLOG(-1) << msg;
break;
}
case Severity::kWARNING: {
@@ -57,6 +55,3 @@ void Logger::log(Severity severity, const char* msg) {
} // namespace tensorrt
} // namespace tensorflow
-
-#endif // GOOGLE_TENSORRT
-#endif // GOOGLE_CUDA
diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl
index 8e1dd8a54f..7504154735 100644
--- a/third_party/gpus/cuda_configure.bzl
+++ b/third_party/gpus/cuda_configure.bzl
@@ -358,8 +358,8 @@ def find_cuda_define(repository_ctx, header_dir, header_file, define):
if not h_path.exists:
auto_configure_fail("Cannot find %s at %s" % (header_file, str(h_path)))
result = repository_ctx.execute(
- # Grep one more lines as some #defines are splitted into two lines.
- ["grep", "--color=never", "-A1", "-E", define, str(h_path)])
+ # Grep one more lines as some #defines are splitted into two lines.
+ ["grep", "--color=never", "-A1", "-E", define, str(h_path)])
if result.stderr:
auto_configure_fail("Error reading %s: %s" % (str(h_path), result.stderr))
@@ -367,11 +367,20 @@ def find_cuda_define(repository_ctx, header_dir, header_file, define):
if result.stdout.find(define) == -1:
auto_configure_fail("Cannot find line containing '%s' in %s" %
(define, h_path))
- version = result.stdout
- # Remove the new line and '\' character if any.
- version = version.replace("\\", " ")
- version = version.replace("\n", " ")
- version = version.replace(define, "").lstrip()
+ #split results to lines
+ lines=result.stdout.split('\n')
+ lenLines=len(lines)
+ for l in range(lenLines):
+ line=lines[l]
+ if define in line: # find the line with define
+ version=line
+ if l != lenLines-1 and line[-1] == '\\': # add next line, if multiline
+ version=version[:-1]+lines[l+1]
+ break
+ #remove any comments
+ version = version.split("//")[0]
+ # remove define name
+ version = version.replace(define, "").strip()
# Remove the code after the version number.
version_end = version.find(" ")
if version_end != -1:
diff --git a/third_party/tensorrt/BUILD.tpl b/third_party/tensorrt/BUILD.tpl
index 99c0e89498..dc7fe0c8c8 100644
--- a/third_party/tensorrt/BUILD.tpl
+++ b/third_party/tensorrt/BUILD.tpl
@@ -34,37 +34,6 @@ cc_library(
visibility = ["//visibility:public"],
)
-cc_library(
- name = "nv_infer_plugin",
- srcs = [%{nv_infer_plugin}],
- data = [%{nv_infer_plugin}],
- includes = [
- "include",
- ],
- copts= cuda_default_copts(),
- deps = [
- "@local_config_cuda//cuda:cuda",
- ":nv_infer",
- ":tensorrt_headers",
- ],
- linkstatic = 1,
- visibility = ["//visibility:public"],
-)
-
-cc_library(
- name = "nv_parsers",
- srcs = [%{nv_parsers}],
- data = [%{nv_parsers}],
- includes = [
- "include",
- ],
- copts= cuda_default_copts(),
- deps = [
- ":tensorrt_headers",
- ],
- linkstatic = 1,
- visibility = ["//visibility:public"],
-)
%{tensorrt_genrules}
diff --git a/third_party/tensorrt/build_defs.bzl.tpl b/third_party/tensorrt/build_defs.bzl.tpl
index f5348a7c06..0dc3a7ba2d 100644
--- a/third_party/tensorrt/build_defs.bzl.tpl
+++ b/third_party/tensorrt/build_defs.bzl.tpl
@@ -5,4 +5,3 @@ def if_tensorrt(if_true, if_false=[]):
if %{tensorrt_is_configured}:
return if_true
return if_false
-
diff --git a/third_party/tensorrt/tensorrt_configure.bzl b/third_party/tensorrt/tensorrt_configure.bzl
index 8aa0f28f39..4a1441500a 100644
--- a/third_party/tensorrt/tensorrt_configure.bzl
+++ b/third_party/tensorrt/tensorrt_configure.bzl
@@ -19,9 +19,9 @@ load(
_TENSORRT_INSTALL_PATH = "TENSORRT_INSTALL_PATH"
_TF_TENSORRT_VERSION = "TF_TENSORRT_VERSION"
-_TF_TENSORRT_LIBS = ["nvinfer", "nvinfer_plugin", "nvparsers"]
+_TF_TENSORRT_LIBS = ["nvinfer"]
_TF_TENSORRT_HEADERS = [
- "NvInfer.h", "NvInferPlugin.h", "NvCaffeParser.h", "NvUffParser.h",
+ "NvInfer.h",
"NvUtils.h"
]