Diffstat (limited to 'tensorflow')
-rw-r--r--  tensorflow/contrib/tensorrt/BUILD                      9
-rw-r--r--  tensorflow/contrib/tensorrt/convert/convert_graph.cc   6
-rw-r--r--  tensorflow/contrib/tensorrt/convert/convert_nodes.cc  18
3 files changed, 13 insertions(+), 20 deletions(-)
diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD
index fd0f97f3af..e7b3fe38e5 100644
--- a/tensorflow/contrib/tensorrt/BUILD
+++ b/tensorflow/contrib/tensorrt/BUILD
@@ -87,6 +87,7 @@ cc_library(
":trt_plugins",
":trt_resources",
":trt_conversion",
+ ":utils",
"//tensorflow/core:gpu_headers_lib",
"//tensorflow/core:lib_proto_parsing",
"//tensorflow/core:stream_executor_headers_lib",
@@ -94,7 +95,7 @@ cc_library(
] + if_tensorrt([
"@local_config_tensorrt//:nv_infer",
]) + tf_custom_op_library_additional_deps(),
- # TODO(laigd)
+ # TODO(laigd): fix this by merging header file in cc file.
alwayslink = 1, # buildozer: disable=alwayslink-with-hdrs
)
@@ -232,6 +233,7 @@ tf_cuda_library(
":trt_plugins",
":trt_logging",
":trt_resources",
+ ":utils",
"//tensorflow/core/grappler/clusters:cluster",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
@@ -337,3 +339,8 @@ py_test(
"//tensorflow/python:framework_test_lib",
],
)
+
+cc_library(
+ name = "utils",
+ hdrs = ["convert/utils.h"],
+)
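
The new ":utils" target exports only convert/utils.h, which is not shown in this diff. A minimal sketch of what that header presumably provides, assuming the TrtUniquePtrType alias used by the .cc changes below is a std::unique_ptr with a destroy()-calling deleter:

    // Sketch only (convert/utils.h itself is not part of this diff):
    // a unique_ptr alias whose deleter calls the TensorRT object's
    // destroy() method instead of operator delete.
    #include <memory>

    template <typename T>
    struct TrtDestroyer {
      void operator()(T* t) {
        if (t) t->destroy();
      }
    };

    template <typename T>
    using TrtUniquePtrType = std::unique_ptr<T, TrtDestroyer<T>>;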
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
index bd6ed2d593..9f0b3ef5dd 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
@@ -423,10 +423,8 @@ tensorflow::Status CreateTRTNode(tensorflow::Graph* graph,
info.precision_mode == INT8MODE) {
// Create static engine and for int8 test validity of the engine.
Logger trt_logger;
- auto builder = std::unique_ptr<
- nvinfer1::IBuilder, std::function<void(nvinfer1::IBuilder*)>>(
- nvinfer1::createInferBuilder(trt_logger),
- [](nvinfer1::IBuilder* p) { if (p) p->destroy(); });
+ TrtUniquePtrType<nvinfer1::IBuilder> builder(
+ nvinfer1::createInferBuilder(trt_logger));
builder->setMaxBatchSize(max_batch_size);
if (info.precision_mode == FP16MODE) builder->setHalf2Mode(true);
builder->setMaxWorkspaceSize(info.max_workspace_size_bytes);
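
With that alias available, the builder creation in CreateTRTNode no longer needs a per-call std::function deleter; a minimal usage sketch under the same assumption about convert/utils.h:

    // Minimal sketch, assuming TrtUniquePtrType as sketched above. The
    // builder is released via IBuilder::destroy() when the pointer leaves
    // scope, matching the old lambda-based deleter without the
    // std::function indirection.
    Logger trt_logger;
    TrtUniquePtrType<nvinfer1::IBuilder> builder(
        nvinfer1::createInferBuilder(trt_logger));
    builder->setMaxBatchSize(max_batch_size);
    builder->setMaxWorkspaceSize(info.max_workspace_size_bytes);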
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
index a252ea67df..69d7b765fa 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
@@ -420,20 +420,6 @@ void ReorderRSCKToKCRS(const TRT_ShapedWeights& iweights,
}
}
-struct InferDeleter {
- template <typename T>
- void operator()(T* obj) const {
- if (obj) {
- obj->destroy();
- }
- }
-};
-
-template <typename T>
-inline std::shared_ptr<T> infer_object(T* obj) {
- return std::shared_ptr<T>(obj, InferDeleter());
-}
-
class Converter;
using OpConverter =
@@ -2151,7 +2137,8 @@ tensorflow::Status ConvertSubGraphDefToEngine(
bool* convert_successfully) {
engine->reset();
if (convert_successfully) *convert_successfully = false;
- auto trt_network = infer_object(builder->createNetwork());
+ auto trt_network =
+ TrtUniquePtrType<nvinfer1::INetworkDefinition>(builder->createNetwork());
if (!trt_network) {
return tensorflow::errors::Internal(
"Failed to create TensorRT network object");
@@ -2207,6 +2194,7 @@ tensorflow::Status ConvertSubGraphDefToEngine(
nvinfer1::ITensor* input_tensor = converter.network()->addInput(
node_name.c_str(), dtype, input_dim_pseudo_chw);
if (!input_tensor) {
+ // TODO(aaroey): remove StrCat when constructing errors.
return tensorflow::errors::InvalidArgument(
StrCat("Failed to create Input layer tensor ", node_name,
" rank=", shape.dims() - 1));