aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--tensorflow/contrib/tensorrt/convert/convert_graph.cc28
-rw-r--r--tensorflow/contrib/tensorrt/convert/convert_nodes.cc37
-rw-r--r--tensorflow/contrib/tensorrt/convert/convert_nodes.h10
-rw-r--r--tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc93
-rw-r--r--tensorflow/contrib/tensorrt/kernels/trt_engine_op.h26
-rw-r--r--tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc13
-rw-r--r--tensorflow/contrib/tensorrt/resources/trt_resources.h26
7 files changed, 113 insertions, 120 deletions
diff --git a/tensorflow/contrib/tensorrt/convert/convert_graph.cc b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
index 7dcd30b0b2..ba7d3b5f86 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_graph.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_graph.cc
@@ -424,31 +424,25 @@ tensorflow::Status CreateTRTNode(tensorflow::Graph* graph,
string segment_string;
if (info.engine_type == EngineInfo::EngineType::TRTStatic ||
info.precision_mode == INT8MODE) {
- // Create static engine and for int8 test validity of the engine. We can not
- // allow engine to fail at the calibration time. So we are constructing a
- // FP32 engine here to check its validity. If it is a valid engine then we
- // put the serialized graphdef to the op. Otherwise we skip node creation
- // for this engine.
+ // Create static engine for fp32/fp16 mode, and test validity of the engine
+ // for int8 mode. We don't want the engine to fail at calibration time.
+ // So we are constructing a FP32 engine here to check its validity, and if
+ // it is a valid engine then we put the serialized graphdef to the op.
+ // Otherwise we skip node creation for this engine.
Logger trt_logger;
- TrtUniquePtrType<nvinfer1::IBuilder> builder(
- nvinfer1::createInferBuilder(trt_logger));
- builder->setMaxBatchSize(max_batch_size);
- if (info.precision_mode == FP16MODE) builder->setHalf2Mode(true);
- builder->setMaxWorkspaceSize(info.max_workspace_size_bytes);
-#if NV_TENSORRT_MAJOR > 3
- builder->setGpuAllocator(alloc);
-#endif
TrtUniquePtrType<nvinfer1::ICudaEngine> engine;
// TODO(sami): What happens if 1st dim is not batch?
TF_RETURN_IF_ERROR(ConvertGraphDefToEngine(
- info.segment_graph_def, info.precision_mode, shapes, builder.get(),
- &engine, /*convert_successfully=*/nullptr));
+ info.segment_graph_def,
+ info.precision_mode == INT8MODE ? FP32MODE : info.precision_mode,
+ max_batch_size, info.max_workspace_size_bytes, shapes, &trt_logger,
+ alloc, /*calibrator=*/nullptr, &engine,
+ /*convert_successfully=*/nullptr));
TrtUniquePtrType<nvinfer1::IHostMemory> engine_data(engine->serialize());
segment_string =
string((const char*)engine_data->data(), engine_data->size());
if (info.precision_mode == INT8MODE) {
- // See above comment on the reason why not putting this inside the 'else'
- // branch.
+ // See above comment about why not putting this inside the 'else' branch.
segment_string = info.segment_graph_def.SerializeAsString();
}
} else {
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
index 5608761206..b5214b461a 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.cc
@@ -433,7 +433,7 @@ class Converter {
OpConverter plugin_converter_;
nvinfer1::INetworkDefinition* trt_network_;
std::list<std::vector<uint8_t>> temp_bufs_;
- tensorflow::tensorrt::TRTWeightStore* weight_store_;
+ TRTWeightStore* weight_store_;
bool fp16_;
void register_op_converters();
tensorflow::Status get_inputs(const tensorflow::NodeDef& node_def,
@@ -475,11 +475,11 @@ class Converter {
public:
explicit Converter(nvinfer1::INetworkDefinition* trt_network,
- tensorflow::tensorrt::TRTWeightStore* ws, bool fp16)
+ TRTWeightStore* ws, bool fp16)
: trt_network_(trt_network), weight_store_(ws), fp16_(fp16) {
this->register_op_converters();
}
- tensorflow::tensorrt::TRTWeightStore* weight_store() { return weight_store_; }
+ TRTWeightStore* weight_store() { return weight_store_; }
TRT_ShapedWeights get_temp_weights(tensorflow::DataType type,
nvinfer1::Dims shape) {
TRT_ShapedWeights weights(type, nullptr, shape);
@@ -2130,21 +2130,44 @@ void Converter::register_op_converters() {
} // namespace
tensorflow::Status ConvertGraphDefToEngine(
- const tensorflow::GraphDef& gdef, int precision_mode,
+ const tensorflow::GraphDef& gdef,
+ int precision_mode,
+ int max_batch_size,
+ size_t max_workspace_size_bytes,
const std::vector<tensorflow::PartialTensorShape>& input_shapes,
- nvinfer1::IBuilder* builder,
+ Logger* logger,
+ nvinfer1::IGpuAllocator* allocator,
+ TRTInt8Calibrator* calibrator,
TrtUniquePtrType<nvinfer1::ICudaEngine>* engine,
bool* convert_successfully) {
engine->reset();
if (convert_successfully) *convert_successfully = false;
+
+ // Create the builder.
+ TrtUniquePtrType<nvinfer1::IBuilder> builder(
+ nvinfer1::createInferBuilder(*logger));
+ builder->setMaxBatchSize(max_batch_size);
+ // TODO(aaroey): use the allocator to allocate the TRT workspace.
+ builder->setMaxWorkspaceSize(max_workspace_size_bytes);
+#if NV_TENSORRT_MAJOR > 3
+ builder->setGpuAllocator(allocator);
+#endif
+ if (precision_mode == FP16MODE) {
+ builder->setHalf2Mode(true);
+ } else if (precision_mode == INT8MODE) {
+ builder->setInt8Mode(true);
+ builder->setInt8Calibrator(calibrator);
+ }
+
+ // Create the network.
auto trt_network =
TrtUniquePtrType<nvinfer1::INetworkDefinition>(builder->createNetwork());
if (!trt_network) {
return tensorflow::errors::Internal(
"Failed to create TensorRT network object");
}
- auto ws = std::unique_ptr<tensorflow::tensorrt::TRTWeightStore>(
- new TRTWeightStore());
+ auto ws = std::unique_ptr<TRTWeightStore>(new TRTWeightStore());
+
// Build the network
VLOG(1) << "Starting engine conversion ";
Converter converter(trt_network.get(), ws.get(), precision_mode == FP16MODE);
diff --git a/tensorflow/contrib/tensorrt/convert/convert_nodes.h b/tensorflow/contrib/tensorrt/convert/convert_nodes.h
index b357da0d84..2da4edf7f5 100644
--- a/tensorflow/contrib/tensorrt/convert/convert_nodes.h
+++ b/tensorflow/contrib/tensorrt/convert/convert_nodes.h
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/contrib/tensorrt/convert/utils.h"
#include "tensorflow/contrib/tensorrt/resources/trt_allocator.h"
+#include "tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
@@ -119,9 +120,14 @@ tensorflow::Status ConvertSegmentToGraphDef(
// is successful. This is different than successfully building the engine:
// building can still fail afterwards.
tensorflow::Status ConvertGraphDefToEngine(
- const tensorflow::GraphDef& gdef, int precision_mode,
+ const tensorflow::GraphDef& gdef,
+ int precision_mode,
+ int max_batch_size,
+ size_t max_workspace_size_bytes,
const std::vector<tensorflow::PartialTensorShape>& input_shapes,
- nvinfer1::IBuilder* builder,
+ Logger* logger,
+ nvinfer1::IGpuAllocator* allocator,
+ TRTInt8Calibrator* calibrator,
TrtUniquePtrType<nvinfer1::ICudaEngine>* engine,
bool* convert_successfully);
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
index 4b45281f51..d12f738ac5 100644
--- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
+++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
@@ -36,7 +36,6 @@ namespace tensorflow {
namespace tensorrt {
static Logger logger;
using ::nvinfer1::IRuntime;
-using ::nvinfer1::Dims;
using ::tensorflow::strings::StrAppend;
using ::tensorflow::strings::StrCat;
@@ -441,6 +440,7 @@ TRTEngineOp::EngineCtxPair& TRTEngineOp::GetEngine(int batch_size,
#if NV_TENSORRT_MAJOR > 3
auto allocator = GetAllocator(ctx);
if (allocator == nullptr) {
+ // GetAllocator already set the Status.
return null_pair;
};
infer->setGpuAllocator(allocator);
@@ -464,39 +464,27 @@ TRTEngineOp::EngineCtxPair& TRTEngineOp::GetEngine(int batch_size,
auto engine_it = engine_map_.find(batch_size);
if (engine_it == engine_map_.end() &&
engine_map_.size() < (size_t)max_cached_engines_) {
- TrtUniquePtrType<nvinfer1::IBuilder> builder(
- nvinfer1::createInferBuilder(logger));
+ nvinfer1::IGpuAllocator* allocator = nullptr;
#if NV_TENSORRT_MAJOR > 3
- auto allocator = GetAllocator(ctx);
+ allocator = GetAllocator(ctx);
if (allocator == nullptr) {
// GetAllocator already set the Status.
return null_pair;
}
- builder->setGpuAllocator(allocator);
#endif
- VLOG(0) << name() << " Constructing a new engine with batch size "
- << batch_size;
- builder->setMaxBatchSize(batch_size);
- if (precision_mode_ == convert::FP16MODE) {
- builder->setHalf2Mode(true);
- } else if (precision_mode_ == convert::INT8MODE) {
- builder->setInt8Mode(true);
- // Up to this point, calibrator_ can never be empty, since otherwise it
- // means calibration_mode_ is true and this path won't get executed.
- builder->setInt8Calibrator(calibrator_.get());
- }
- // TODO(aaroey): use the allocator to allocate the TRT workspace.
- builder->setMaxWorkspaceSize(workspace_size_);
std::vector<tensorflow::PartialTensorShape> shapes;
for (int i = 0; i < ctx->num_inputs(); ++i) {
shapes.emplace_back(ctx->input(i).shape());
}
TrtUniquePtrType<nvinfer1::ICudaEngine> engine;
bool convert_successfully = false;
- VLOG(1) << "Calling conversion for " << batch_size << " " << name();
+ VLOG(0) << name() << " Constructing a new engine with batch size "
+ << batch_size;
+ // Up to this point, calibrator_ can never be empty, since otherwise it
+ // means calibration_mode_ is true and this path won't get executed.
auto status = convert::ConvertGraphDefToEngine(
- segment_graph_, precision_mode_, shapes, builder.get(), &engine,
- &convert_successfully);
+ segment_graph_, precision_mode_, batch_size, workspace_size_, shapes,
+ &logger, allocator, calibrator_.get(), &engine, &convert_successfully);
if (!status.ok()) {
if (convert_successfully) {
// This means it fail to build the engine even when the network is built
@@ -522,9 +510,7 @@ tensorflow::Status TRTEngineOp::AllocateCalibrationResources(
TRTCalibrationResource** cr) {
auto cres = new TRTCalibrationResource();
*cr = cres;
- cres->logger_ = new Logger();
-
-#if NV_TENSORRT_MAJOR > 3
+ // Get the allocator.
auto alloc = ctx->device()->GetAllocator(tensorflow::AllocatorAttributes());
if (!alloc) {
LOG(WARNING) << "Can't get device allocator will not be able to "
@@ -533,11 +519,10 @@ tensorflow::Status TRTEngineOp::AllocateCalibrationResources(
} else {
cres->allocator_.reset(new TRTDeviceAllocator(alloc));
}
-#endif
- int batch_size = ctx->input(0).dim_size(0);
+ // Get the input shapes.
+ const int batch_size = ctx->input(0).dim_size(0);
+ const int num_inputs = ctx->num_inputs();
std::vector<tensorflow::PartialTensorShape> shapes;
- int num_inputs = ctx->num_inputs();
- // first run instantiate calibrator
dev_tensors_.resize(num_inputs);
VLOG(1) << " Constructing calibrator";
for (int i = 0; i < num_inputs; i++) {
@@ -557,51 +542,45 @@ tensorflow::Status TRTEngineOp::AllocateCalibrationResources(
StrCat(kInputPHName, i),
std::pair<void*, size_t>(device_address, device_tensor->TotalBytes()));
}
- cres->calibrator_ =
- new TRTInt8Calibrator(device_buffers_, batch_size, name());
- string label(name());
+ cres->calibrator_.reset(
+ new TRTInt8Calibrator(device_buffers_, batch_size, name()));
+ const string label(name());
auto segment_graph = &segment_graph_;
- int cuda_device = ctx->device()->tensorflow_gpu_device_info()->gpu_id;
- if (cuda_device < 0) {
+ const int cuda_gpu_id = ctx->device()->tensorflow_gpu_device_info()->gpu_id;
+ if (cuda_gpu_id < 0) {
LOG(ERROR) << "Can't get gpu_device_info from context->device()";
return tensorflow::errors::InvalidArgument(
"Context->device doesn't contain device info!");
}
- int workspace_size = workspace_size_;
- cres->thr_ = new std::thread([cres, label, segment_graph, shapes, cuda_device,
- batch_size, workspace_size]() {
- VLOG(0) << "Starting calibration thread on device " << cuda_device
+ const int64 workspace_size_bytes = workspace_size_;
+ cres->thr_.reset(new std::thread([cres, label, segment_graph, shapes,
+ cuda_gpu_id, workspace_size_bytes]() {
+ VLOG(0) << "Starting calibration thread on device " << cuda_gpu_id
<< ", Calibration Resource @ " << cres;
- auto err = cudaSetDevice(cuda_device);
+ auto err = cudaSetDevice(cuda_gpu_id);
if (err != cudaSuccess) {
- VLOG(0) << "Couldn't set cuda device to " << cuda_device
- << " in calibration thread";
+ // TODO(aaroey): should return error here.
+ LOG(ERROR) << "Couldn't set cuda device to " << cuda_gpu_id
+ << " in calibration thread";
}
- // initialize builder here
- cres->builder_.reset(nvinfer1::createInferBuilder(*(cres->logger_)));
- // TODO(aaroey): maybe setting the max batch size using the python
- // calibration wrapper class.
- cres->builder_->setMaxBatchSize(batch_size);
-#if NV_TENSORRT_MAJOR > 3
- cres->builder_->setGpuAllocator(cres->allocator_.get());
-#endif
- cres->builder_->setInt8Mode(true);
- cres->builder_->setMaxWorkspaceSize(workspace_size);
- cres->builder_->setInt8Calibrator(cres->calibrator_);
// ConvertGraphDefToEngine() will try to build the engine. This thread
// will loop inside buildCudaEngine() consuming the calibration data
// that is set by the TF op, and drive the builder until calibrator returns
// false. Engine is discarded after the calibration table is generated.
+ //
+ // TODO(aaroey): maybe set the max batch size using the python
+ // calibration wrapper class.
auto s = convert::ConvertGraphDefToEngine(
- *segment_graph, convert::INT8MODE, shapes, cres->builder_.get(),
- &cres->engine_, /*convert_successfully=*/nullptr);
+ *segment_graph, convert::INT8MODE, cres->calibrator_->getBatchSize(),
+ workspace_size_bytes, shapes, &cres->logger_, cres->allocator_.get(),
+ cres->calibrator_.get(), &cres->engine_,
+ /*convert_successfully=*/nullptr);
if (!s.ok()) {
- LOG(ERROR)
- << "Calibration failed. Engine will not be calibrated! Error is" << s;
- cres->calibrator_->setDone(); // ignore further pushes
+ LOG(ERROR) << "Calibration failed: " << s;
+ cres->calibrator_->setDone(); // Ignore further pushes
}
VLOG(1) << "Calibration loop terminated " << label;
- });
+ }));
VLOG(1) << "initialized calibrator resource";
return tensorflow::Status::OK();
}
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h
index cb43403130..0d2f9e8a9d 100644
--- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h
+++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.h
@@ -20,6 +20,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/contrib/tensorrt/convert/utils.h"
+#include "tensorflow/contrib/tensorrt/log/trt_logger.h"
#include "tensorflow/contrib/tensorrt/resources/trt_allocator.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
@@ -46,25 +47,24 @@ class TRTEngineOp : public AsyncOpKernel {
explicit TRTEngineOp(OpKernelConstruction* context);
void ComputeAsync(OpKernelContext* context,
- tensorflow::AsyncOpKernel::DoneCallback done) override;
+ AsyncOpKernel::DoneCallback done) override;
~TRTEngineOp();
private:
// Execute calibration
- void ExecuteCalibration(tensorflow::OpKernelContext* ctx,
+ void ExecuteCalibration(OpKernelContext* ctx,
AsyncHelper* helper);
// Construct a function handle for executing native funcdef graph
- tensorflow::Status ConstructFunctionHandle(tensorflow::OpKernelContext* ctx);
+ Status ConstructFunctionHandle(OpKernelContext* ctx);
// Execute replaced native segment as function Op.
- void ExecuteNativeSegment(tensorflow::OpKernelContext* ctx,
+ void ExecuteNativeSegment(OpKernelContext* ctx,
AsyncHelper* helper);
// Allocate necessary resources for calibration
- tensorflow::Status AllocateCalibrationResources(
- tensorflow::OpKernelContext* ctx,
- tensorflow::tensorrt::TRTCalibrationResource** cr);
+ Status AllocateCalibrationResources(
+ OpKernelContext* ctx, TRTCalibrationResource** cr);
// TODO(samikama): context should go to a resource manager!
typedef std::pair<TrtUniquePtrType<nvinfer1::ICudaEngine>,
@@ -92,13 +92,13 @@ class TRTEngineOp : public AsyncOpKernel {
string funcdef_name_;
// GraphDef representation of the segment.
- tensorflow::GraphDef segment_graph_;
+ GraphDef segment_graph_;
// Lookup table for temporary staging areas of input tensors for calibration.
std::unordered_map<string, std::pair<void*, size_t>> device_buffers_;
// Temporary staging areas for calibration inputs.
- std::vector<tensorflow::PersistentTensor> dev_tensors_;
+ std::vector<PersistentTensor> dev_tensors_;
// Engine Precision mode.
int precision_mode_;
@@ -120,9 +120,11 @@ class TRTEngineOp : public AsyncOpKernel {
// Maximum number of cached engines
int max_cached_engines_;
- tensorflow::int64 workspace_size_;
- tensorflow::mutex engine_mutex_;
- tensorflow::FunctionLibraryRuntime::Handle native_func_;
+ int64 workspace_size_;
+ mutex engine_mutex_;
+ FunctionLibraryRuntime::Handle native_func_;
+
+ // The finalized calibrator for inference.
std::unique_ptr<TRTInt8Calibrator> calibrator_;
};
diff --git a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc
index 9c1c306947..59ae860bc0 100644
--- a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc
+++ b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc
@@ -51,8 +51,8 @@ TRTInt8Calibrator::TRTInt8Calibrator(const string& calib_data)
bool TRTInt8Calibrator::setBatch(const std::unordered_map<string, void*>& data,
const cudaStream_t stream) {
tensorflow::mutex_lock lock(cond_mtx_);
- while ((calib_running_ || batch_is_set_) &&
- !done_) { // wait while calibration is running
+ // wait while calibration is running.
+ while ((calib_running_ || batch_is_set_) && !done_) {
cond_.wait(lock);
}
if (done_) return false;
@@ -66,8 +66,6 @@ bool TRTInt8Calibrator::setBatch(const std::unordered_map<string, void*>& data,
}
const auto& d = devptr->second;
- // TODO(aaroey): we should not use sync copy on default stream. Make sure
- // stream->ThenMemcpy() is used in future PRs.
// TODO(sami,aaroey): Need to figure out a way to ensure synchronization
// between stream, perhaps using a tensor?
auto status = cudaMemcpyAsync(d.first, it.second, d.second,
@@ -91,12 +89,11 @@ bool TRTInt8Calibrator::getBatch(void** bindings, const char** names,
tensorflow::mutex_lock lock(cond_mtx_);
calib_running_ = false;
cond_.notify_all();
- while ((!batch_is_set_ && !done_)) { // wait until new batch arrives
+ // wait until new batch arrives
+ while ((!batch_is_set_ && !done_)) {
cond_.wait(lock);
}
- if (done_) {
- return false;
- }
+ if (done_) return false;
for (int i = 0; i < num_bindings; i++) {
auto it = dev_buffers_.find(names[i]);
diff --git a/tensorflow/contrib/tensorrt/resources/trt_resources.h b/tensorflow/contrib/tensorrt/resources/trt_resources.h
index 43734bbdd8..76863503bd 100644
--- a/tensorflow/contrib/tensorrt/resources/trt_resources.h
+++ b/tensorflow/contrib/tensorrt/resources/trt_resources.h
@@ -38,11 +38,6 @@ namespace tensorrt {
class TRTCalibrationResource : public tensorflow::ResourceBase {
public:
- TRTCalibrationResource()
- : calibrator_(nullptr),
- logger_(nullptr),
- thr_(nullptr) {}
-
~TRTCalibrationResource() {
VLOG(0) << "Destroying Calibration Resource " << std::endl << DebugString();
builder_.reset();
@@ -50,9 +45,6 @@ class TRTCalibrationResource : public tensorflow::ResourceBase {
// We need to manually destroy the builder and engine before the allocator
// is destroyed.
allocator_.reset();
- delete thr_;
- delete logger_;
- delete calibrator_;
}
string DebugString() override {
@@ -60,22 +52,22 @@ class TRTCalibrationResource : public tensorflow::ResourceBase {
using std::hex;
using std::dec;
using std::endl;
- oss << " Calibrator = " << hex << calibrator_ << dec << endl
- << " Builder = " << hex << builder_.get() << dec << endl
- << " Engine = " << hex << engine_.get() << dec << endl
- << " Logger = " << hex << logger_ << dec << endl
- << " Allocator = " << hex << allocator_.get() << dec << endl
- << " Thread = " << hex << thr_ << dec << endl;
+ oss << " Calibrator = " << hex << calibrator_.get() << dec << endl
+ << " Builder = " << hex << builder_.get() << dec << endl
+ << " Engine = " << hex << engine_.get() << dec << endl
+ << " Logger = " << hex << &logger_ << dec << endl
+ << " Allocator = " << hex << allocator_.get() << dec << endl
+ << " Thread = " << hex << thr_.get() << dec << endl;
return oss.str();
}
- TRTInt8Calibrator* calibrator_;
+ std::unique_ptr<TRTInt8Calibrator> calibrator_;
TrtUniquePtrType<nvinfer1::IBuilder> builder_;
TrtUniquePtrType<nvinfer1::ICudaEngine> engine_;
std::unique_ptr<nvinfer1::IGpuAllocator> allocator_;
- tensorflow::tensorrt::Logger* logger_;
+ tensorflow::tensorrt::Logger logger_;
// TODO(sami): Use threadpool threads!
- std::thread* thr_;
+ std::unique_ptr<std::thread> thr_;
};
class TRTWeightStore {