aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensorrt/BUILD
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/tensorrt/BUILD')
-rw-r--r--tensorflow/contrib/tensorrt/BUILD44
1 files changed, 43 insertions, 1 deletions
diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD
index 675f0b1fd6..7a8a71ac7f 100644
--- a/tensorflow/contrib/tensorrt/BUILD
+++ b/tensorflow/contrib/tensorrt/BUILD
@@ -67,6 +67,7 @@ tf_cuda_library(
visibility = ["//visibility:public"],
deps = [
":trt_logging",
+ ":trt_plugins",
] + if_tensorrt([
"@local_config_tensorrt//:nv_infer",
]) + tf_custom_op_library_additional_deps(),
@@ -86,6 +87,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
":trt_logging",
+ ":trt_plugins",
":trt_resources",
"//tensorflow/core:gpu_headers_lib",
"//tensorflow/core:lib_proto_parsing",
@@ -232,6 +234,7 @@ tf_cuda_library(
],
deps = [
":segment",
+ ":trt_plugins",
":trt_logging",
":trt_resources",
"//tensorflow/core/grappler/clusters:cluster",
@@ -263,7 +266,6 @@ cc_library(
"segment/segment.h",
"segment/union_find.h",
],
- linkstatic = 1,
deps = [
"//tensorflow/core:graph",
"//tensorflow/core:lib_proto_parsing",
@@ -286,6 +288,46 @@ tf_cc_test(
],
)
+# Library for the plugin factory
+tf_cuda_library(
+ name = "trt_plugins",
+ srcs = [
+ "plugin/trt_plugin.cc",
+ "plugin/trt_plugin_factory.cc",
+ "plugin/trt_plugin_utils.cc",
+ ],
+ hdrs = [
+ "plugin/trt_plugin.h",
+ "plugin/trt_plugin_factory.h",
+ "plugin/trt_plugin_utils.h",
+ ],
+ deps = [
+ "//tensorflow/core:framework_lite",
+ "//tensorflow/core:platform_base",
+ ] + if_tensorrt([
+ "@local_config_tensorrt//:nv_infer",
+ ]),
+)
+
+tf_cuda_cc_test(
+ name = "trt_plugin_factory_test",
+ size = "small",
+ srcs = ["plugin/trt_plugin_factory_test.cc"],
+ tags = [
+ "manual",
+ "notap",
+ ],
+ deps = [
+ ":trt_plugins",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ] + if_tensorrt([
+ "@local_config_cuda//cuda:cuda_headers",
+ "@local_config_tensorrt//:nv_infer",
+ ]),
+)
+
py_test(
name = "tf_trt_integration_test",
srcs = ["test/tf_trt_integration_test.py"],