diff options
Diffstat (limited to 'tensorflow')
-rw-r--r-- | tensorflow/contrib/tensorrt/BUILD | 9 |
-rw-r--r-- | tensorflow/contrib/tensorrt/convert/convert_graph.cc | 6 |
-rw-r--r-- | tensorflow/contrib/tensorrt/convert/convert_nodes.cc | 18 |
3 files changed, 13 insertions, 20 deletions
diff --git a/tensorflow/contrib/tensorrt/BUILD b/tensorflow/contrib/tensorrt/BUILD index fd0f97f3af..e7b3fe38e5 100644 --- a/tensorflow/contrib/tensorrt/BUILD +++ b/tensorflow/contrib/tensorrt/BUILD @@ -87,6 +87,7 @@ cc_library( ":trt_plugins", ":trt_resources", ":trt_conversion", + ":utils", "//tensorflow/core:gpu_headers_lib", "//tensorflow/core:lib_proto_parsing", "//tensorflow/core:stream_executor_headers_lib", @@ -94,7 +95,7 @@ cc_library( ] + if_tensorrt([ "@local_config_tensorrt//:nv_infer", ]) + tf_custom_op_library_additional_deps(), - # TODO(laigd) + # TODO(laigd): fix this by merging header file in cc file. alwayslink = 1, # buildozer: disable=alwayslink-with-hdrs ) @@ -232,6 +233,7 @@ tf_cuda_library( ":trt_plugins", ":trt_logging", ":trt_resources", + ":utils", "//tensorflow/core/grappler/clusters:cluster", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer", "//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry", @@ -337,3 +339,8 @@ py_test( "//tensorflow/python:framework_test_lib", ], ) + +cc_library( + name = "utils", + hdrs = ["convert/utils.h"], +) diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc index bd6ed2d593..9f0b3ef5dd 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc @@ -423,10 +423,8 @@ tensorflow::Status CreateTRTNode(tensorflow::Graph* graph, info.precision_mode == INT8MODE) { // Create static engine and for int8 test validity of the engine. 
Logger trt_logger; - auto builder = std::unique_ptr< - nvinfer1::IBuilder, std::function<void(nvinfer1::IBuilder*)>>( - nvinfer1::createInferBuilder(trt_logger), - [](nvinfer1::IBuilder* p) { if (p) p->destroy(); }); + TrtUniquePtrType<nvinfer1::IBuilder> builder( + nvinfer1::createInferBuilder(trt_logger)); builder->setMaxBatchSize(max_batch_size); if (info.precision_mode == FP16MODE) builder->setHalf2Mode(true); builder->setMaxWorkspaceSize(info.max_workspace_size_bytes); diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc index a252ea67df..69d7b765fa 100644 --- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc +++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc @@ -420,20 +420,6 @@ void ReorderRSCKToKCRS(const TRT_ShapedWeights& iweights, } } -struct InferDeleter { - template <typename T> - void operator()(T* obj) const { - if (obj) { - obj->destroy(); - } - } -}; - -template <typename T> -inline std::shared_ptr<T> infer_object(T* obj) { - return std::shared_ptr<T>(obj, InferDeleter()); -} - class Converter; using OpConverter = @@ -2151,7 +2137,8 @@ tensorflow::Status ConvertSubGraphDefToEngine( bool* convert_successfully) { engine->reset(); if (convert_successfully) *convert_successfully = false; - auto trt_network = infer_object(builder->createNetwork()); + auto trt_network = + TrtUniquePtrType<nvinfer1::INetworkDefinition>(builder->createNetwork()); if (!trt_network) { return tensorflow::errors::Internal( "Failed to create TensorRT network object"); @@ -2207,6 +2194,7 @@ tensorflow::Status ConvertSubGraphDefToEngine( nvinfer1::ITensor* input_tensor = converter.network()->addInput( node_name.c_str(), dtype, input_dim_pseudo_chw); if (!input_tensor) { + // TODO(aaroey): remove StrCat when constructing errors. return tensorflow::errors::InvalidArgument( StrCat("Failed to create Input layer tensor ", node_name, " rank=", shape.dims() - 1)); |