diff options
| field | value |
|---|---|
| author | 2018-06-20 10:20:32 -0700 |
| committer | 2018-06-20 10:20:32 -0700 |
| commit | 5a8ff32bdb23b9ac4680f96b4b78493e3c4395ab (patch) |
| tree | f8e514b701eaa1c3d47b849b39c15247eaf73561 /tensorflow/contrib/tensorrt/resources |
| parent | 1bdcd6d624e4012cb9aec790a0d95076360bedb5 (diff) |
Move the builder creation logic into ConvertGraphDefToEngine(), use unique_ptr for TRTCalibrationResource, and fix comments
Diffstat (limited to 'tensorflow/contrib/tensorrt/resources')
-rw-r--r-- | tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc | 13 |
-rw-r--r-- | tensorflow/contrib/tensorrt/resources/trt_resources.h | 26 |
2 files changed, 14 insertions, 25 deletions
diff --git a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc index 9c1c306947..59ae860bc0 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc +++ b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc @@ -51,8 +51,8 @@ TRTInt8Calibrator::TRTInt8Calibrator(const string& calib_data) bool TRTInt8Calibrator::setBatch(const std::unordered_map<string, void*>& data, const cudaStream_t stream) { tensorflow::mutex_lock lock(cond_mtx_); - while ((calib_running_ || batch_is_set_) && - !done_) { // wait while calibration is running + // wait while calibration is running. + while ((calib_running_ || batch_is_set_) && !done_) { cond_.wait(lock); } if (done_) return false; @@ -66,8 +66,6 @@ bool TRTInt8Calibrator::setBatch(const std::unordered_map<string, void*>& data, } const auto& d = devptr->second; - // TODO(aaroey): we should not use sync copy on default stream. Make sure - // stream->ThenMemcpy() is used in future PRs. // TODO(sami,aaroey): Need to figure out a way to ensure synchronization // between stream, perhaps using a tensor? 
auto status = cudaMemcpyAsync(d.first, it.second, d.second, @@ -91,12 +89,11 @@ bool TRTInt8Calibrator::getBatch(void** bindings, const char** names, tensorflow::mutex_lock lock(cond_mtx_); calib_running_ = false; cond_.notify_all(); - while ((!batch_is_set_ && !done_)) { // wait until new batch arrives + // wait until new batch arrives + while ((!batch_is_set_ && !done_)) { cond_.wait(lock); } - if (done_) { - return false; - } + if (done_) return false; for (int i = 0; i < num_bindings; i++) { auto it = dev_buffers_.find(names[i]); diff --git a/tensorflow/contrib/tensorrt/resources/trt_resources.h b/tensorflow/contrib/tensorrt/resources/trt_resources.h index 43734bbdd8..76863503bd 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_resources.h +++ b/tensorflow/contrib/tensorrt/resources/trt_resources.h @@ -38,11 +38,6 @@ namespace tensorrt { class TRTCalibrationResource : public tensorflow::ResourceBase { public: - TRTCalibrationResource() - : calibrator_(nullptr), - logger_(nullptr), - thr_(nullptr) {} - ~TRTCalibrationResource() { VLOG(0) << "Destroying Calibration Resource " << std::endl << DebugString(); builder_.reset(); @@ -50,9 +45,6 @@ class TRTCalibrationResource : public tensorflow::ResourceBase { // We need to manually destroy the builder and engine before the allocator // is destroyed. 
allocator_.reset(); - delete thr_; - delete logger_; - delete calibrator_; } string DebugString() override { @@ -60,22 +52,22 @@ class TRTCalibrationResource : public tensorflow::ResourceBase { using std::hex; using std::dec; using std::endl; - oss << " Calibrator = " << hex << calibrator_ << dec << endl - << " Builder = " << hex << builder_.get() << dec << endl - << " Engine = " << hex << engine_.get() << dec << endl - << " Logger = " << hex << logger_ << dec << endl - << " Allocator = " << hex << allocator_.get() << dec << endl - << " Thread = " << hex << thr_ << dec << endl; + oss << " Calibrator = " << hex << calibrator_.get() << dec << endl + << " Builder = " << hex << builder_.get() << dec << endl + << " Engine = " << hex << engine_.get() << dec << endl + << " Logger = " << hex << &logger_ << dec << endl + << " Allocator = " << hex << allocator_.get() << dec << endl + << " Thread = " << hex << thr_.get() << dec << endl; return oss.str(); } - TRTInt8Calibrator* calibrator_; + std::unique_ptr<TRTInt8Calibrator> calibrator_; TrtUniquePtrType<nvinfer1::IBuilder> builder_; TrtUniquePtrType<nvinfer1::ICudaEngine> engine_; std::unique_ptr<nvinfer1::IGpuAllocator> allocator_; - tensorflow::tensorrt::Logger* logger_; + tensorflow::tensorrt::Logger logger_; // TODO(sami): Use threadpool threads! - std::thread* thr_; + std::unique_ptr<std::thread> thr_; }; class TRTWeightStore { |