diff options
author | gracehoney <31743510+aaroey@users.noreply.github.com> | 2018-01-29 11:15:32 -0800 |
---|---|---|
committer | gracehoney <31743510+aaroey@users.noreply.github.com> | 2018-01-29 11:15:32 -0800 |
commit | 98ffa7b2bb5b273b4ae2b052e82fb0d8c054d96b (patch) | |
tree | b5f25ab2c3f86adb339485f2be2c65a91f5f6def | |
parent | 75adab6104362d71ce28b0269bf31fd30471b1b6 (diff) |
Fix compilation, there are still linking issues which will be fixed in a followup commit.
-rw-r--r-- | configure.py | 124 | ||||
-rw-r--r-- | tensorflow/contrib/BUILD | 4 | ||||
-rw-r--r-- | tensorflow/contrib/tensorrt/BUILD | 111 | ||||
-rw-r--r-- | tensorflow/contrib/tensorrt/convert/convert_graph.cc | 4 | ||||
-rw-r--r-- | tensorflow/contrib/tensorrt/convert/convert_nodes.cc | 2 | ||||
-rw-r--r-- | tensorflow/contrib/tensorrt/kernels/trt_engine_op.h | 3 | ||||
-rw-r--r-- | tensorflow/contrib/tensorrt/log/trt_logger.h | 3 | ||||
-rw-r--r-- | tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc | 5 | ||||
-rw-r--r-- | tensorflow/tools/pip_package/BUILD | 4 | ||||
-rw-r--r-- | third_party/tensorrt/BUILD.tpl | 3 | ||||
-rw-r--r-- | third_party/tensorrt/build_defs.bzl | 85 | ||||
-rw-r--r-- | third_party/tensorrt/build_defs.bzl.tpl | 3 |
12 files changed, 84 insertions, 267 deletions
diff --git a/configure.py b/configure.py index b621b1bc1b..1567ed697f 100644 --- a/configure.py +++ b/configure.py @@ -37,7 +37,6 @@ _TF_BAZELRC = os.path.join(os.path.dirname(os.path.abspath(__file__)), _TF_WORKSPACE = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'WORKSPACE') _DEFAULT_CUDA_VERSION = '9.0' -_DEFAULT_TENSORRT_VERSION = '4' _DEFAULT_CUDNN_VERSION = '7' _DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,5.2' _DEFAULT_CUDA_PATH = '/usr/local/cuda' @@ -384,12 +383,13 @@ def set_build_var(environ_cp, var_name, query_item, option_name, var = str(int(get_var(environ_cp, var_name, query_item, enabled_by_default))) environ_cp[var_name] = var - # TODO(mikecase): Migrate all users of configure.py to use --config Bazel - # options and not to set build configs through environment variables. - if var=='1': - setting='true' - confname=":%s"%(bazel_config_name) if bazel_config_name is not None else "" - write_to_bazelrc('build%s --define %s=%s' % (confname,option_name,setting)) + if var == '1': + write_to_bazelrc('build --define %s=true' % option_name) + elif bazel_config_name is not None: + # TODO(mikecase): Migrate all users of configure.py to use --config Bazel + # options and not to set build configs through environment variables. + write_to_bazelrc('build:%s --define %s=true' + % (bazel_config_name, option_name)) def set_action_env_var(environ_cp, @@ -439,6 +439,7 @@ def convert_version_to_int(version): for seg in version_segments: if not seg.isdigit(): return None + version_str = ''.join(['%03d' % int(seg) for seg in version_segments]) return int(version_str) @@ -1169,108 +1170,6 @@ def set_other_cuda_vars(environ_cp): write_to_bazelrc('test --config=cuda') -def set_tf_trt_version(environ_cp): - """Set TENSORRT_INSTALL_PATH and TF_TENSORRT_VERSION.""" - ask_trt_version = ( - 'Please specify the TensorRT (libnvinfer) version you want to use. ' - '[Leave empty to default to libnvinfer %s]: ') % _DEFAULT_TENSORRT_VERSION - - while True: - tf_trt_version = get_from_env_or_user_or_default( - environ_cp, 'TF_TENSORRT_VERSION', ask_trt_version, - _DEFAULT_TENSORRT_VERSION) - # if library version is passed and known - default_trt_path = environ_cp.get('TENSORRT_INSTALL_PATH',_DEFAULT_TENSORRT_PATH_LINUX) - ask_trt_path = (r'Please specify the location where libnvinfer %s library is ' - 'installed. Refer to README.md for more details. [Default' - ' is %s]:') % (tf_trt_version, default_trt_path) - trt_install_path = get_from_env_or_user_or_default( - environ_cp, 'TENSORRT_INSTALL_PATH', ask_trt_path, default_trt_path) - - # Result returned from "read" will be used unexpanded. That make "~" - # unusable. Going through one more level of expansion to handle that. - trt_install_path = os.path.realpath( - os.path.expanduser(trt_install_path)) - # Simple function to search for libnvinfer in install path - # it will find all libnvinfer.so* in user defined install path - # and lib64 subdirectory and return absolute paths - def find_libs(search_path): - fl=set() - if os.path.exists(search_path) and os.path.isdir(search_path): - fl.update([os.path.realpath(os.path.join(search_path,x)) \ - for x in os.listdir(search_path) if 'libnvinfer.so' in x]) - return fl - possible_files=find_libs(trt_install_path) - possible_files.update(find_libs(os.path.join(trt_install_path,'lib64'))) - if is_linux(): - cudnnpatt=re.compile(".*libcudnn.so\.?(.*) =>.*$") - cudapatt =re.compile(".*libcudart.so\.?(.*) =>.*$") - def is_compatible(lib,cudaver,cudnnver): - ldd_bin=which('ldd') or '/usr/bin/ldd' - ldd_out=run_shell([ldd_bin,lib]).split(os.linesep) - for l in ldd_out: - if 'libcudnn.so' in l: - cudnn=cudnnpatt.search(l) - elif 'libcudart.so' in l: - cudart=cudapatt.search(l) - if cudnn: - cudnn=convert_version_to_int(cudnn.group(1)) if len(cudnn.group(1)) else 0 - if cudart: - cudart=convert_version_to_int(cudart.group(1)) if len(cudart.group(1)) else 0 - return (cudnn==cudnnver) and (cudart==cudaver) - cudaver=convert_version_to_int(environ_cp['TF_CUDA_VERSION']) - cudnnver=convert_version_to_int(environ_cp['TF_CUDNN_VERSION']) - valid_libs=[] - vfinder=re.compile('.*libnvinfer.so.?(.*)$') - highest_ver=[0,None,None] - - for l in possible_files: - if is_compatible(l,cudaver,cudnnver): - valid_libs.append(l) - vstr=vfinder.search(l).group(1) - currver=convert_version_to_int(vstr) if len(vstr) else 0 - if currver > highest_ver[0]: - highest_ver= [currver,vstr,l] - if highest_ver[1] is not None: - trt_install_path=os.path.dirname(highest_ver[2]) - tf_trt_version=highest_ver[1] - break - ldconfig_bin = which('ldconfig') or '/sbin/ldconfig' - libnvinfer_path_from_ldconfig = run_shell([ldconfig_bin, '-p']) - libnvinfer_path_from_ldconfig = re.search('.*libnvinfer.so.* => (.*)', - libnvinfer_path_from_ldconfig) - if libnvinfer_path_from_ldconfig: - libnvinfer_path_from_ldconfig = libnvinfer_path_from_ldconfig.group(1) - if os.path.exists('%s.%s' % (libnvinfer_path_from_ldconfig, - tf_trt_version)): - trt_install_path = os.path.dirname(libnvinfer_path_from_ldconfig) - break - - # Reset and Retry - if len(possible_files): - print( - 'Invalid path to TensorRT %s. libnvinfer.so* files found are for incompatible cuda versions ' - % tf_trt_version) - print(trt_install_path) - print(os.path.join(trt_install_path,'lib64')) - else: - print( - 'Invalid path to TensorRT %s. No libnvinfer.so* files found in ' - 'found:' % tf_trt_version) - print(trt_install_path) - print(os.path.join(trt_install_path,'lib64')) - if is_linux(): - print('%s.%s' % (libnvinfer_path_from_ldconfig, tf_trt_version)) - - environ_cp['TF_TENSORRT_VERSION'] = '' - - # Set TENSORRT_INSTALL_PATH and TENSORRT_CUDNN_VERSION - environ_cp['TENSORRT_INSTALL_PATH'] = trt_install_path - write_action_env_to_bazelrc('TENSORRT_INSTALL_PATH', trt_install_path) - environ_cp['TF_TENSORRT_VERSION'] = tf_trt_version - write_action_env_to_bazelrc('TF_TENSORRT_VERSION', tf_trt_version) - write_to_bazelrc('build:tensorrt --define using_tensorrt=true') - def set_host_cxx_compiler(environ_cp): """Set HOST_CXX_COMPILER.""" default_cxx_host_compiler = which('g++') or '' @@ -1455,6 +1354,7 @@ def main(): environ_cp['TF_NEED_GCP'] = '0' environ_cp['TF_NEED_HDFS'] = '0' environ_cp['TF_NEED_JEMALLOC'] = '0' + environ_cp['TF_NEED_KAFKA'] = '0' environ_cp['TF_NEED_OPENCL_SYCL'] = '0' environ_cp['TF_NEED_COMPUTECPP'] = '0' environ_cp['TF_NEED_OPENCL'] = '0' @@ -1473,6 +1373,8 @@ def main(): 'with_hdfs_support', True, 'hdfs') set_build_var(environ_cp, 'TF_NEED_S3', 'Amazon S3 File System', 'with_s3_support', True, 's3') + set_build_var(environ_cp, 'TF_NEED_KAFKA', 'Apache Kafka Platform', + 'with_kafka_support', False, 'kafka') set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support', False, 'xla') set_build_var(environ_cp, 'TF_NEED_GDR', 'GDR', 'with_gdr_support', @@ -1520,10 +1422,6 @@ def main(): if not is_windows(): set_gcc_host_compiler_path(environ_cp) set_other_cuda_vars(environ_cp) - # enable tensorrt if desired. Disabled on non-linux - set_action_env_var(environ_cp, 'TF_NEED_TENSORRT', 'TensorRT', False) - if environ_cp.get('TF_NEED_TENSORRT') == '1': - set_tf_trt_version(environ_cp) set_build_var(environ_cp, 'TF_NEED_MPI', 'MPI', 'with_mpi_support', False) if environ_cp.get('TF_NEED_MPI') == '1': diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD index d4c0660285..f745c175b1 100644 --- a/tensorflow/contrib/BUILD +++ b/tensorflow/contrib/BUILD @@ -7,7 +7,7 @@ package(default_visibility = ["//tensorflow:__subpackages__"]) load("//third_party/mpi:mpi.bzl", "if_mpi") load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") -load("@local_config_tensorrt//:build_defs.bzl", "if_trt") +load("@local_config_tensorrt//:build_defs.bzl", "if_tensorrt") py_library( name = "contrib_py", @@ -106,7 +106,7 @@ py_library( "//tensorflow/contrib/util:util_py", "//tensorflow/python:util", ] + if_mpi(["//tensorflow/contrib/mpi_collectives:mpi_ops_py"]) - + if_trt(["//tensorflow/contrib/tensorrt:init_py"]), + + if_tensorrt(["//tensorflow/contrib/tensorrt:init_py"]), ) diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index eeb308fee8..b179e815c8 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ b/tensorflow/contrib/tensorrt/BUILD @@ -49,34 +49,34 @@ tf_custom_op_library( name = "python/ops/_trt_engine_op.so", srcs = [ "kernels/trt_engine_op.cc", - "ops/trt_engine_op.cc", "kernels/trt_engine_op.h", + "ops/trt_engine_op.cc", ], gpu_srcs = [], deps = [ - "@local_config_tensorrt//:tensorrt", ":trt_shape_function", "//tensorflow/core:lib_proto_parsing", "//tensorflow/core/kernels:bounds_check_lib", "//tensorflow/core/kernels:ops_util_hdrs", + "@local_config_tensorrt//:nv_infer", ], ) cc_library( name = "trt_shape_function", - srcs=[ + srcs = [ "shape_fn/trt_shfn.cc", ], - hdrs=["shape_fn/trt_shfn.h"], - copts=tf_copts(), - deps=[ + hdrs = ["shape_fn/trt_shfn.h"], + copts = tf_copts(), + deps = [ ":trt_logging", + "//tensorflow/core:framework_headers_lib", "//third_party/eigen3", - "@local_config_tensorrt//:tensorrt", - "@protobuf_archive//:protobuf", + "@local_config_tensorrt//:nv_infer", "@nsync//:nsync_headers", - "//tensorflow/core:framework_headers_lib", - ] + "@protobuf_archive//:protobuf", + ], ) tf_kernel_library( @@ -84,7 +84,7 @@ tf_kernel_library( srcs = [ "kernels/trt_engine_op.cc", ], - hdrs=[ + hdrs = [ "kernels/trt_engine_op.h", ], gpu_srcs = [ @@ -93,37 +93,37 @@ tf_kernel_library( ":trt_logging", ":trt_shape_function", "//tensorflow/core:framework", + "//tensorflow/core:gpu_headers_lib", "//tensorflow/core:lib", + "//tensorflow/core:lib_proto_parsing", "//third_party/eigen3", - "//tensorflow/core:gpu_headers_lib", - "@local_config_tensorrt//:tensorrt", - "//tensorflow/core:lib_proto_parsing", + "@local_config_tensorrt//:nv_infer", ], - alwayslink=1, + alwayslink = 1, ) tf_gen_op_libs( - op_lib_names = [ - "trt_engine_op", - ], - deps=[ - "@local_config_tensorrt//:tensorrt", - ] + op_lib_names = [ + "trt_engine_op", + ], + deps = [ + "@local_config_tensorrt//:nv_infer", + ], ) cc_library( - name="trt_logging", + name = "trt_logging", srcs = [ - "log/trt_logger.cc", + "log/trt_logger.cc", ], - hdrs=[ - "log/trt_logger.h", + hdrs = [ + "log/trt_logger.h", ], - deps=[ - "@local_config_tensorrt//:tensorrt", + visibility = ["//visibility:public"], + deps = [ "//tensorflow/core:lib_proto_parsing", + "@local_config_tensorrt//:nv_infer", ], - visibility = ["//visibility:public"], ) tf_gen_op_wrapper_py( @@ -137,8 +137,9 @@ tf_gen_op_wrapper_py( tf_custom_op_py_library( name = "trt_engine_op_loader", srcs = ["python/ops/trt_engine_op.py"], - dso = [":python/ops/_trt_engine_op.so", - "@local_config_tensorrt//:tensorrt", + dso = [ + ":python/ops/_trt_engine_op.so", + "@local_config_tensorrt//:nv_infer", ], srcs_version = "PY2AND3", deps = [ @@ -155,33 +156,33 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + ":trt_convert_py", ":trt_ops_py", - ":trt_convert_py", ], ) py_library( - name="trt_ops_py", + name = "trt_ops_py", srcs_version = "PY2AND3", - deps=[":trt_engine_op", - ":trt_engine_op_loader", + deps = [ + ":trt_engine_op", + ":trt_engine_op_loader", ], - ) py_library( - name="trt_convert_py", - srcs=["python/trt_convert.py"], + name = "trt_convert_py", + srcs = ["python/trt_convert.py"], srcs_version = "PY2AND3", - deps=[ - ":wrap_conversion" + deps = [ + ":wrap_conversion", ], ) tf_py_wrap_cc( - name="wrap_conversion", - srcs=["trt_conversion.i"], - deps=[ + name = "wrap_conversion", + srcs = ["trt_conversion.i"], + deps = [ ":trt_conversion", "//tensorflow/core:framework_lite", "//util/python:python_headers", @@ -189,20 +190,20 @@ tf_py_wrap_cc( ) cc_library( - name= "trt_conversion", - srcs=[ - "convert/convert_nodes.cc", + name = "trt_conversion", + srcs = [ "convert/convert_graph.cc", + "convert/convert_nodes.cc", "segment/segment.cc", ], - hdrs=[ - "convert/convert_nodes.h", + hdrs = [ "convert/convert_graph.h", + "convert/convert_nodes.h", "segment/segment.h", "segment/union_find.h", ], - deps=[ - "@local_config_tensorrt//:tensorrt", + deps = [ + "@local_config_tensorrt//:nv_infer", "@protobuf_archive//:protobuf_headers", "@nsync//:nsync_headers", ":trt_logging", @@ -225,7 +226,7 @@ tf_custom_op_library( "ops/tensorrt_ops.cc", ], deps = [ - "@local_config_tensorrt//:tensorrt", + "@local_config_tensorrt//:nv_infer", ], ) @@ -236,16 +237,16 @@ cc_library( "segment/segment.cc", ], hdrs = [ - "segment/union_find.h", "segment/segment.h", + "segment/union_find.h", ], + linkstatic = 1, deps = [ - "@protobuf_archive//:protobuf_headers", "//tensorflow/core:core_cpu", "//tensorflow/core:lib_proto_parsing", "//third_party/eigen3", + "@protobuf_archive//:protobuf_headers", ], - linkstatic = 1, ) tf_cc_test( @@ -265,13 +266,13 @@ tf_cc_test( filegroup( name = "cppfiles", srcs = glob(["**/*.cc"]), - visibility=["//visibility:private"], + visibility = ["//visibility:private"], ) filegroup( name = "headers", srcs = glob(["**/*.h"]), - visibility=["//visibility:private"], + visibility = ["//visibility:private"], ) filegroup( diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc index fa5ed28060..3ade4ec356 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc @@ -25,8 +25,6 @@ limitations under the License. #include <map> #include <utility> -#include "NvInfer.h" - #include "tensorflow/contrib/tensorrt/convert/convert_nodes.h" #include "tensorflow/contrib/tensorrt/segment/segment.h" #include "tensorflow/core/framework/graph.pb.h" @@ -46,8 +44,8 @@ limitations under the License. #include "tensorflow/core/protobuf/device_properties.pb.h" #include "tensorflow/core/grappler/grappler_item.h" #include "tensorflow/core/grappler/utils.h" - #include "tensorflow/core/grappler/costs/graph_properties.h" +#include "tensorrt/include/NvInfer.h" //------------------------------------------------------------------------------ namespace tensorflow { diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index 2fd7f659d5..19242cd944 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -26,7 +26,6 @@ limitations under the License. #include <unordered_map> #include <utility> #include <vector> -#include "NvInfer.h" #include "tensorflow/contrib/tensorrt/log/trt_logger.h" #include "tensorflow/core/framework/graph.pb.h" @@ -39,6 +38,7 @@ limitations under the License. #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/platform/logging.h" +#include "tensorrt/include/NvInfer.h" #define _TF_LOG_DEBUG ::tensorflow::internal::LogMessage(__FILE__, __LINE__, -1) // Check if the types are equal. Cast to int first so that failure log message diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h index 2f2d453dda..ac188addd7 100644 --- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h +++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h @@ -16,11 +16,12 @@ limitations under the License. #ifndef TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_ENGINE_OP_H_ #define TENSORFLOW_CONTRIB_TENSORRT_KERNELS_TRT_ENGINE_OP_H_ -#include <NvInfer.h> #include <cuda_runtime_api.h> #include <memory> #include <string> #include <vector> + +#include "tensorrt/include/NvInfer.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/op_kernel.h" diff --git a/tensorflow/contrib/tensorrt/log/trt_logger.h b/tensorflow/contrib/tensorrt/log/trt_logger.h index 3a3a29516a..8737813bed 100644 --- a/tensorflow/contrib/tensorrt/log/trt_logger.h +++ b/tensorflow/contrib/tensorrt/log/trt_logger.h @@ -18,9 +18,10 @@ limitations under the License. #define TENSORFLOW_CONTRIB_TENSORRT_LOG_TRT_LOGGER_H_ // Use TF logging f -#include <NvInfer.h> #include <string> +#include "tensorrt/include/NvInfer.h" + //------------------------------------------------------------------------------ namespace tensorflow { diff --git a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc index 16a4dc2134..7951397b7e 100644 --- a/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc +++ b/tensorflow/contrib/tensorrt/shape_fn/trt_shfn.cc @@ -13,11 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/contrib/tensorrt/shape_fn/trt_shfn.h" #include <string> #include <vector> -#include "NvInfer.h" + #include "tensorflow/contrib/tensorrt/log/trt_logger.h" +#include "tensorflow/contrib/tensorrt/shape_fn/trt_shfn.h" +#include "tensorrt/include/NvInfer.h" namespace tensorflow { namespace shape_inference { diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD index d864d09d8f..63bceaa8a6 100644 --- a/tensorflow/tools/pip_package/BUILD +++ b/tensorflow/tools/pip_package/BUILD @@ -11,7 +11,7 @@ load( ) load("//third_party/mkl:build_defs.bzl", "if_mkl") load("//tensorflow:tensorflow.bzl", "if_cuda") -load("@local_config_tensorrt//:build_defs.bzl", "if_trt") +load("@local_config_tensorrt//:build_defs.bzl", "if_tensorrt") load("//tensorflow/core:platform/default/build_config_root.bzl", "tf_additional_license_deps") # This returns a list of headers of all public header libraries (e.g., @@ -183,7 +183,7 @@ sh_binary( "//tensorflow/tools/dist_test/server:grpc_tensorflow_server", ], }) + if_mkl(["//third_party/mkl:intel_binary_blob"]) - + if_trt(["//tensorflow/contrib/tensorrt:init_py"]), + + if_tensorrt(["//tensorflow/contrib/tensorrt:init_py"]), ) # A genrule for generating a marker file for the pip package on Windows diff --git a/third_party/tensorrt/BUILD.tpl b/third_party/tensorrt/BUILD.tpl index 6cb7db7e90..99c0e89498 100644 --- a/third_party/tensorrt/BUILD.tpl +++ b/third_party/tensorrt/BUILD.tpl @@ -66,4 +66,5 @@ cc_library( visibility = ["//visibility:public"], ) -%{tensorrt_genrules}
\ No newline at end of file +%{tensorrt_genrules} + diff --git a/third_party/tensorrt/build_defs.bzl b/third_party/tensorrt/build_defs.bzl deleted file mode 100644 index 392c5e0621..0000000000 --- a/third_party/tensorrt/build_defs.bzl +++ /dev/null @@ -1,85 +0,0 @@ -# -*- python -*- -""" - add a repo_generator rule for tensorrt - -""" - -_TENSORRT_INSTALLATION_PATH="TENSORRT_INSTALL_PATH" -_TF_TENSORRT_VERSION="TF_TENSORRT_VERSION" - -def _is_trt_enabled(repo_ctx): - if "TF_NEED_TENSORRT" in repo_ctx.os.environ: - enable_trt = repo_ctx.os.environ["TF_NEED_TENSORRT"].strip() - return enable_trt == "1" - return False - -def _dummy_repo(repo_ctx): - - repo_ctx.template("BUILD",Label("//third_party/tensorrt:BUILD.tpl"), - {"%{tensorrt_lib}":"","%{tensorrt_genrules}":""}, - False) - repo_ctx.template("build_defs.bzl",Label("//third_party/tensorrt:build_defs.bzl.tpl"), - {"%{trt_configured}":"False"},False) - repo_ctx.file("include/NvUtils.h","",False) - repo_ctx.file("include/NvInfer.h","",False) - -def _trt_repo_impl(repo_ctx): - """ - Implements local_config_tensorrt - """ - - if not _is_trt_enabled(repo_ctx): - _dummy_repo(repo_ctx) - return - trt_libdir=repo_ctx.os.environ[_TENSORRT_INSTALLATION_PATH] - trt_ver=repo_ctx.os.environ[_TF_TENSORRT_VERSION] -# if deb installation -# once a standardized installation between tar and deb -# is done, we don't need this - if trt_libdir == '/usr/lib/x86_64-linux-gnu': - incPath='/usr/include/x86_64-linux-gnu' - incname='/usr/include/x86_64-linux-gnu/NvInfer.h' - else: - incPath=str(repo_ctx.path("%s/../include"%trt_libdir).realpath) - incname=incPath+'/NvInfer.h' - if len(trt_ver)>0: - origLib="%s/libnvinfer.so.%s"%(trt_libdir,trt_ver) - else: - origLib="%s/libnvinfer.so"%trt_libdir - objdump=repo_ctx.which("objdump") - if objdump == None: - if len(trt_ver)>0: - targetlib="lib/libnvinfer.so.%s"%(trt_ver[0]) - else: - targetlib="lib/libnvinfer.so" - else: - soname=repo_ctx.execute([objdump,"-p",origLib]) - for l in soname.stdout.splitlines(): - if "SONAME" in l: - lib=l.strip().split(" ")[-1] - targetlib="lib/%s"%(lib) - - if len(trt_ver)>0: - repo_ctx.symlink(origLib,targetlib) - else: - repo_ctx.symlink(origLib,targetlib) - grule=('genrule(\n name = "trtlinks",\n'+ - ' outs = [\n "%s",\n "include/NvInfer.h",\n "include/NvUtils.h",\n ],\n'%targetlib + - ' cmd="""ln -sf %s $(@D)/%s '%(origLib,targetlib) + - '&&\n ln -sf %s $(@D)/include/NvInfer.h '%(incname) + - '&&\n ln -sf %s/NvUtils.h $(@D)/include/NvUtils.h""",\n)\n'%(incPath)) - repo_ctx.template("BUILD",Label("//third_party/tensorrt:BUILD.tpl"), - {"%{tensorrt_lib}":'"%s"'%targetlib,"%{tensorrt_genrules}":grule}, - False) - repo_ctx.template("build_defs.bzl",Label("//third_party/tensorrt:build_defs.bzl.tpl"), - {"%{trt_configured}":"True"},False) - -trt_repository=repository_rule( - implementation= _trt_repo_impl, - local=True, - environ=[ - "TF_NEED_TENSORRT", - _TF_TENSORRT_VERSION, - _TENSORRT_INSTALLATION_PATH, - ], - ) diff --git a/third_party/tensorrt/build_defs.bzl.tpl b/third_party/tensorrt/build_defs.bzl.tpl index 8a89b59bc8..f5348a7c06 100644 --- a/third_party/tensorrt/build_defs.bzl.tpl +++ b/third_party/tensorrt/build_defs.bzl.tpl @@ -4,4 +4,5 @@ def if_tensorrt(if_true, if_false=[]): """Tests whether TensorRT was enabled during the configure process.""" if %{tensorrt_is_configured}: return if_true - return if_false
\ No newline at end of file + return if_false + |