diff options
Diffstat (limited to 'tensorflow/contrib/tensorrt/BUILD')
-rw-r--r-- | tensorflow/contrib/tensorrt/BUILD | 44 |
1 files changed, 43 insertions, 1 deletions
diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index 675f0b1fd6..7a8a71ac7f 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ b/tensorflow/contrib/tensorrt/BUILD @@ -67,6 +67,7 @@ tf_cuda_library( visibility = ["//visibility:public"], deps = [ ":trt_logging", + ":trt_plugins", ] + if_tensorrt([ "@local_config_tensorrt//:nv_infer", ]) + tf_custom_op_library_additional_deps(), @@ -86,6 +87,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":trt_logging", + ":trt_plugins", ":trt_resources", "//tensorflow/core:gpu_headers_lib", "//tensorflow/core:lib_proto_parsing", @@ -232,6 +234,7 @@ tf_cuda_library( ], deps = [ ":segment", + ":trt_plugins", ":trt_logging", ":trt_resources", "//tensorflow/core/grappler/clusters:cluster", @@ -263,7 +266,6 @@ cc_library( "segment/segment.h", "segment/union_find.h", ], - linkstatic = 1, deps = [ "//tensorflow/core:graph", "//tensorflow/core:lib_proto_parsing", @@ -286,6 +288,46 @@ tf_cc_test( ], ) +# Library for the plugin factory +tf_cuda_library( + name = "trt_plugins", + srcs = [ + "plugin/trt_plugin.cc", + "plugin/trt_plugin_factory.cc", + "plugin/trt_plugin_utils.cc", + ], + hdrs = [ + "plugin/trt_plugin.h", + "plugin/trt_plugin_factory.h", + "plugin/trt_plugin_utils.h", + ], + deps = [ + "//tensorflow/core:framework_lite", + "//tensorflow/core:platform_base", + ] + if_tensorrt([ + "@local_config_tensorrt//:nv_infer", + ]), +) + +tf_cuda_cc_test( + name = "trt_plugin_factory_test", + size = "small", + srcs = ["plugin/trt_plugin_factory_test.cc"], + tags = [ + "manual", + "notap", + ], + deps = [ + ":trt_plugins", + "//tensorflow/core:lib", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ] + if_tensorrt([ + "@local_config_cuda//cuda:cuda_headers", + "@local_config_tensorrt//:nv_infer", + ]), +) + py_test( name = "tf_trt_integration_test", srcs = ["test/tf_trt_integration_test.py"], |