about | summary | refs | log | tree | commit | diff | homepage
path: root/third_party/tensorrt/build_defs.bzl
blob: 392c5e06214c953821fc96aabdf4780b71b37e20 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
# -*- python -*-
"""Repository rule that configures TensorRT for the TensorFlow build.

`trt_repository` generates the local_config_tensorrt repository. When
TF_NEED_TENSORRT is "1" it links against a local TensorRT installation;
otherwise it emits a placeholder repository with empty headers.
"""

# Name of the environment variable giving the directory that holds libnvinfer.
_TENSORRT_INSTALLATION_PATH = "TENSORRT_INSTALL_PATH"

# Name of the environment variable giving the libnvinfer.so version suffix.
_TF_TENSORRT_VERSION = "TF_TENSORRT_VERSION"

def _is_trt_enabled(repo_ctx):
    """Reports whether TensorRT support was requested.

    Args:
      repo_ctx: the repository context whose environment is inspected.

    Returns:
      True iff TF_NEED_TENSORRT is set and, ignoring surrounding whitespace,
      equals "1". An unset variable counts as disabled.
    """
    return repo_ctx.os.environ.get("TF_NEED_TENSORRT", "").strip() == "1"

def _dummy_repo(repo_ctx):
    """Populates a placeholder repository for builds without TensorRT.

    BUILD is instantiated from BUILD.tpl with empty `tensorrt_lib` and
    `tensorrt_genrules` substitutions, build_defs.bzl is instantiated with
    `trt_configured` set to False, and empty NvUtils.h / NvInfer.h files are
    written under include/.

    Args:
      repo_ctx: the repository context to write into.
    """
    build_subs = {
        "%{tensorrt_lib}": "",
        "%{tensorrt_genrules}": "",
    }
    repo_ctx.template("BUILD", Label("//third_party/tensorrt:BUILD.tpl"), build_subs, False)
    repo_ctx.template(
        "build_defs.bzl",
        Label("//third_party/tensorrt:build_defs.bzl.tpl"),
        {"%{trt_configured}": "False"},
        False,
    )

    # Empty stub headers; presumably so includes of them still resolve when
    # TensorRT is disabled -- confirm against BUILD.tpl.
    for header in ["NvUtils.h", "NvInfer.h"]:
        repo_ctx.file("include/" + header, "", False)

def _trt_include_paths(repo_ctx, trt_libdir):
    """Locates the TensorRT headers for the install in `trt_libdir`.

    Args:
      repo_ctx: the repository context (used to resolve real paths).
      trt_libdir: directory containing libnvinfer (TENSORRT_INSTALL_PATH).

    Returns:
      A tuple `(inc_path, inc_name)`: the header directory and the full path
      of NvInfer.h inside it.
    """
    # A deb installation puts headers in the multiarch include directory;
    # otherwise they are expected in ../include next to the lib directory.
    # Once tar and deb installations are standardized this special case can go.
    if trt_libdir == "/usr/lib/x86_64-linux-gnu":
        inc_path = "/usr/include/x86_64-linux-gnu"
    else:
        inc_path = str(repo_ctx.path("%s/../include" % trt_libdir).realpath)
    return inc_path, inc_path + "/NvInfer.h"

def _trt_target_lib(repo_ctx, orig_lib, trt_ver):
    """Chooses the repo-relative path under which libnvinfer is exposed.

    Prefers the SONAME that `objdump -p` reports for `orig_lib`. Falls back
    to a name built from the major version in `trt_ver` when objdump is not
    on PATH, exits non-zero, or reports no SONAME.

    Args:
      repo_ctx: the repository context (used to find and run objdump).
      orig_lib: absolute path of the installed libnvinfer shared object.
      trt_ver: version suffix of libnvinfer ("" for the unversioned library).

    Returns:
      A path of the form "lib/<library file name>".
    """
    # Fallback name. The major version is everything before the first "." so
    # multi-digit majors work ("10.0.1" -> "10"; the old trt_ver[0] gave "1").
    if trt_ver:
        target_lib = "lib/libnvinfer.so.%s" % trt_ver.split(".")[0]
    else:
        target_lib = "lib/libnvinfer.so"

    objdump = repo_ctx.which("objdump")
    if objdump == None:
        return target_lib
    result = repo_ctx.execute([objdump, "-p", orig_lib])
    if result.return_code != 0:
        return target_lib
    for line in result.stdout.splitlines():
        if "SONAME" in line:
            # objdump prints e.g. "  SONAME   libnvinfer.so.4"; keep the last field.
            return "lib/%s" % line.strip().split(" ")[-1]
    return target_lib

def _trt_links_genrule(orig_lib, target_lib, inc_name, inc_path):
    """Returns the text of the `trtlinks` genrule for the generated BUILD.

    The genrule symlinks the TensorRT library, NvInfer.h and NvUtils.h into
    its output directory.
    """
    return ('genrule(\n    name = "trtlinks",\n' +
            '    outs = [\n    "%s",\n    "include/NvInfer.h",\n    "include/NvUtils.h",\n     ],\n' % target_lib +
            '    cmd="""ln -sf %s $(@D)/%s ' % (orig_lib, target_lib) +
            '&&\n    ln -sf %s $(@D)/include/NvInfer.h ' % inc_name +
            '&&\n    ln -sf %s/NvUtils.h $(@D)/include/NvUtils.h""",\n)\n' % inc_path)

def _trt_repo_impl(repo_ctx):
    """Implements local_config_tensorrt.

    If TensorRT is not enabled (TF_NEED_TENSORRT != "1"), writes a
    placeholder repository via `_dummy_repo`. Otherwise:
    - symlinks the installed libnvinfer into lib/;
    - instantiates BUILD from BUILD.tpl with that library and a `trtlinks`
      genrule;
    - instantiates build_defs.bzl with `trt_configured` set to True.

    Environment:
      TENSORRT_INSTALL_PATH: required when enabled; directory containing
        libnvinfer.so[.<version>].
      TF_TENSORRT_VERSION: optional version suffix of libnvinfer.so; unset or
        empty selects the unversioned library.

    Args:
      repo_ctx: the repository context.
    """
    if not _is_trt_enabled(repo_ctx):
        _dummy_repo(repo_ctx)
        return

    env = repo_ctx.os.environ
    if _TENSORRT_INSTALLATION_PATH not in env:
        fail("TF_NEED_TENSORRT=1 but %s is not set" % _TENSORRT_INSTALLATION_PATH)
    trt_libdir = env[_TENSORRT_INSTALLATION_PATH].strip()

    # An unset version is treated like an empty one, which the code below
    # already supports (links against the unversioned libnvinfer.so).
    trt_ver = env.get(_TF_TENSORRT_VERSION, "").strip()

    inc_path, inc_name = _trt_include_paths(repo_ctx, trt_libdir)

    if trt_ver:
        orig_lib = "%s/libnvinfer.so.%s" % (trt_libdir, trt_ver)
    else:
        orig_lib = "%s/libnvinfer.so" % trt_libdir

    target_lib = _trt_target_lib(repo_ctx, orig_lib, trt_ver)
    repo_ctx.symlink(orig_lib, target_lib)

    grule = _trt_links_genrule(orig_lib, target_lib, inc_name, inc_path)
    repo_ctx.template("BUILD", Label("//third_party/tensorrt:BUILD.tpl"),
                      {"%{tensorrt_lib}": '"%s"' % target_lib, "%{tensorrt_genrules}": grule},
                      False)
    repo_ctx.template("build_defs.bzl", Label("//third_party/tensorrt:build_defs.bzl.tpl"),
                      {"%{trt_configured}": "True"}, False)

# Repository rule backed by _trt_repo_impl. Per the repository_rule docs,
# `local = True` re-evaluates it on every fetch, and `environ` refetches it
# whenever one of the listed environment variables changes.
trt_repository = repository_rule(
    implementation = _trt_repo_impl,
    environ = [
        "TF_NEED_TENSORRT",
        _TF_TENSORRT_VERSION,
        _TENSORRT_INSTALLATION_PATH,
    ],
    local = True,
)