diff options
Diffstat (limited to 'tensorflow/contrib/tensorrt/BUILD')
-rw-r--r-- | tensorflow/contrib/tensorrt/BUILD | 204 |
1 files changed, 203 insertions, 1 deletions
diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index 28f571e1f0..65a0e903a7 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ b/tensorflow/contrib/tensorrt/BUILD @@ -1,5 +1,6 @@ # Description: -# Wrap NVIDIA TensorRT (http://developer.nvidia.com/tensorrt) with tensorflow. +# Wrap NVIDIA TensorRT (http://developer.nvidia.com/tensorrt) with tensorflow +# and provide TensorRT operators and converter package. # APIs are meant to change over time. package(default_visibility = ["//tensorflow:__subpackages__"]) @@ -8,7 +9,19 @@ licenses(["notice"]) # Apache 2.0 exports_files(["LICENSE"]) +load( + "//tensorflow:tensorflow.bzl", + "tf_cc_test", + "tf_copts", + "tf_cuda_library", + "tf_custom_op_library", + "tf_custom_op_library_additional_deps", + "tf_gen_op_libs", + "tf_gen_op_wrapper_py", +) load("//tensorflow:tensorflow.bzl", "tf_cuda_cc_test") +load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library") +load("//tensorflow:tensorflow.bzl", "tf_py_wrap_cc") load( "@local_config_tensorrt//:build_defs.bzl", "if_tensorrt", @@ -32,6 +45,195 @@ tf_cuda_cc_test( ]), ) +tf_custom_op_library( + name = "python/ops/_trt_engine_op.so", + srcs = ["ops/trt_engine_op.cc"], + deps = [ + ":trt_engine_op_kernel", + ":trt_shape_function", + "//tensorflow/core:lib_proto_parsing", + ] + if_tensorrt([ + "@local_config_tensorrt//:nv_infer", + ]), +) + +tf_cuda_library( + name = "trt_shape_function", + srcs = ["shape_fn/trt_shfn.cc"], + hdrs = ["shape_fn/trt_shfn.h"], + visibility = ["//visibility:public"], + deps = [ + ":trt_logging", + ] + if_tensorrt([ + "@local_config_tensorrt//:nv_infer", + ]) + tf_custom_op_library_additional_deps(), +) + +cc_library( + name = "trt_engine_op_kernel", + srcs = ["kernels/trt_engine_op.cc"], + hdrs = ["kernels/trt_engine_op.h"], + copts = tf_copts(), + deps = [ + ":trt_logging", + "//tensorflow/core:gpu_headers_lib", + "//tensorflow/core:lib_proto_parsing", + "//tensorflow/core:stream_executor_headers_lib", + ] + if_tensorrt([ + "@local_config_tensorrt//:nv_infer", + ]) + tf_custom_op_library_additional_deps(), + # TODO(laigd) + alwayslink = 1, # buildozer: disable=alwayslink-with-hdrs +) + +tf_gen_op_libs( + op_lib_names = ["trt_engine_op"], + deps = if_tensorrt([ + "@local_config_tensorrt//:nv_infer", + ]), +) + +tf_cuda_library( + name = "trt_logging", + srcs = ["log/trt_logger.cc"], + hdrs = ["log/trt_logger.h"], + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/core:lib_proto_parsing", + ] + if_tensorrt([ + "@local_config_tensorrt//:nv_infer", + ]), +) + +tf_gen_op_wrapper_py( + name = "trt_engine_op", + deps = [ + ":trt_engine_op_op_lib", + ":trt_logging", + ":trt_shape_function", + ], +) + +tf_custom_op_py_library( + name = "trt_engine_op_loader", + srcs = ["python/ops/trt_engine_op.py"], + dso = [ + ":python/ops/_trt_engine_op.so", + ] + if_tensorrt([ + "@local_config_tensorrt//:nv_infer", + ]), + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:framework_for_generated_wrappers", + "//tensorflow/python:resources", + ], +) + +py_library( + name = "init_py", + srcs = [ + "__init__.py", + "python/__init__.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":trt_convert_py", + ":trt_ops_py", + ], +) + +py_library( + name = "trt_ops_py", + srcs_version = "PY2AND3", + deps = [ + ":trt_engine_op", + ":trt_engine_op_loader", + ], +) + +py_library( + name = "trt_convert_py", + srcs = ["python/trt_convert.py"], + srcs_version = "PY2AND3", + deps = [ + ":wrap_conversion", + ], +) + +tf_py_wrap_cc( + name = "wrap_conversion", + srcs = ["trt_conversion.i"], + copts = tf_copts(), + deps = [ + ":trt_conversion", + "//tensorflow/core:framework_lite", + "//util/python:python_headers", + ], +) + +# Library for the node-level conversion portion of TensorRT operation creation +tf_cuda_library( + name = "trt_conversion", + srcs = [ + "convert/convert_graph.cc", + "convert/convert_nodes.cc", + ], + hdrs = [ + "convert/convert_graph.h", + "convert/convert_nodes.h", + ], + deps = [ + ":segment", + ":trt_logging", + "//tensorflow/core/grappler:grappler_item", + "//tensorflow/core/grappler:utils", + "//tensorflow/core:framework", + "//tensorflow/core:framework_lite", + "//tensorflow/core:graph", + "//tensorflow/core:lib", + "//tensorflow/core:lib_internal", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/grappler:devices", + "//tensorflow/core/grappler/clusters:virtual_cluster", + "//tensorflow/core/grappler/costs:graph_properties", + "//tensorflow/core/grappler/optimizers:constant_folding", + "//tensorflow/core/grappler/optimizers:layout_optimizer", + ] + if_tensorrt([ + "@local_config_tensorrt//:nv_infer", + ]) + tf_custom_op_library_additional_deps(), +) + +# Library for the segmenting portion of TensorRT operation creation +cc_library( + name = "segment", + srcs = ["segment/segment.cc"], + hdrs = [ + "segment/segment.h", + "segment/union_find.h", + ], + linkstatic = 1, + deps = [ + "//tensorflow/core:graph", + "//tensorflow/core:lib_proto_parsing", + "//tensorflow/core:protos_all_cc", + "@protobuf_archive//:protobuf_headers", + ], +) + +tf_cc_test( + name = "segment_test", + size = "small", + srcs = ["segment/segment_test.cc"], + deps = [ + ":segment", + "//tensorflow/c:c_api", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + ], +) + filegroup( name = "all_files", srcs = glob( |