aboutsummaryrefslogtreecommitdiffhomepage
path: root/third_party/tensorrt/build_defs.bzl
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/tensorrt/build_defs.bzl')
-rw-r--r--third_party/tensorrt/build_defs.bzl85
1 files changed, 85 insertions, 0 deletions
diff --git a/third_party/tensorrt/build_defs.bzl b/third_party/tensorrt/build_defs.bzl
new file mode 100644
index 0000000000..392c5e0621
--- /dev/null
+++ b/third_party/tensorrt/build_defs.bzl
@@ -0,0 +1,85 @@
+# -*- python -*-
+"""
+ add a repo_generator rule for tensorrt
+
+"""
+
+_TENSORRT_INSTALLATION_PATH="TENSORRT_INSTALL_PATH"
+_TF_TENSORRT_VERSION="TF_TENSORRT_VERSION"
+
+def _is_trt_enabled(repo_ctx):
+ if "TF_NEED_TENSORRT" in repo_ctx.os.environ:
+ enable_trt = repo_ctx.os.environ["TF_NEED_TENSORRT"].strip()
+ return enable_trt == "1"
+ return False
+
+def _dummy_repo(repo_ctx):
+
+ repo_ctx.template("BUILD",Label("//third_party/tensorrt:BUILD.tpl"),
+ {"%{tensorrt_lib}":"","%{tensorrt_genrules}":""},
+ False)
+ repo_ctx.template("build_defs.bzl",Label("//third_party/tensorrt:build_defs.bzl.tpl"),
+ {"%{trt_configured}":"False"},False)
+ repo_ctx.file("include/NvUtils.h","",False)
+ repo_ctx.file("include/NvInfer.h","",False)
+
+def _trt_repo_impl(repo_ctx):
+ """
+ Implements local_config_tensorrt
+ """
+
+ if not _is_trt_enabled(repo_ctx):
+ _dummy_repo(repo_ctx)
+ return
+ trt_libdir=repo_ctx.os.environ[_TENSORRT_INSTALLATION_PATH]
+ trt_ver=repo_ctx.os.environ[_TF_TENSORRT_VERSION]
+# if deb installation
+# once a standardized installation between tar and deb
+# is done, we don't need this
+ if trt_libdir == '/usr/lib/x86_64-linux-gnu':
+ incPath='/usr/include/x86_64-linux-gnu'
+ incname='/usr/include/x86_64-linux-gnu/NvInfer.h'
+ else:
+ incPath=str(repo_ctx.path("%s/../include"%trt_libdir).realpath)
+ incname=incPath+'/NvInfer.h'
+ if len(trt_ver)>0:
+ origLib="%s/libnvinfer.so.%s"%(trt_libdir,trt_ver)
+ else:
+ origLib="%s/libnvinfer.so"%trt_libdir
+ objdump=repo_ctx.which("objdump")
+ if objdump == None:
+ if len(trt_ver)>0:
+ targetlib="lib/libnvinfer.so.%s"%(trt_ver[0])
+ else:
+ targetlib="lib/libnvinfer.so"
+ else:
+ soname=repo_ctx.execute([objdump,"-p",origLib])
+ for l in soname.stdout.splitlines():
+ if "SONAME" in l:
+ lib=l.strip().split(" ")[-1]
+ targetlib="lib/%s"%(lib)
+
+ if len(trt_ver)>0:
+ repo_ctx.symlink(origLib,targetlib)
+ else:
+ repo_ctx.symlink(origLib,targetlib)
+ grule=('genrule(\n name = "trtlinks",\n'+
+ ' outs = [\n "%s",\n "include/NvInfer.h",\n "include/NvUtils.h",\n ],\n'%targetlib +
+ ' cmd="""ln -sf %s $(@D)/%s '%(origLib,targetlib) +
+ '&&\n ln -sf %s $(@D)/include/NvInfer.h '%(incname) +
+ '&&\n ln -sf %s/NvUtils.h $(@D)/include/NvUtils.h""",\n)\n'%(incPath))
+ repo_ctx.template("BUILD",Label("//third_party/tensorrt:BUILD.tpl"),
+ {"%{tensorrt_lib}":'"%s"'%targetlib,"%{tensorrt_genrules}":grule},
+ False)
+ repo_ctx.template("build_defs.bzl",Label("//third_party/tensorrt:build_defs.bzl.tpl"),
+ {"%{trt_configured}":"True"},False)
+
+trt_repository=repository_rule(
+ implementation= _trt_repo_impl,
+ local=True,
+ environ=[
+ "TF_NEED_TENSORRT",
+ _TF_TENSORRT_VERSION,
+ _TENSORRT_INSTALLATION_PATH,
+ ],
+ )