diff options
author | 2018-06-11 21:04:03 -0700 | |
---|---|---|
committer | 2018-06-11 21:04:03 -0700 | |
commit | ae13b0560666df62967d87072e85619083a2f44b (patch) | |
tree | 1b5b403148d42ad359723a843ac86abc17860791 /tensorflow/contrib/tensorrt/resources | |
parent | d1120e1334aae84bff40b3ee7cf0a3849936fe4b (diff) |
Review changes
Diffstat (limited to 'tensorflow/contrib/tensorrt/resources')
3 files changed, 11 insertions, 19 deletions
diff --git a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc index 5adffdc3d1..695394156c 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc +++ b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc @@ -47,13 +47,11 @@ TRTInt8Calibrator::TRTInt8Calibrator(const string& calib_data) done_(false), calib_running_(false), batch_is_set_(false), - calibration_table(calib_data) {} + calibration_table_(calib_data) {} bool TRTInt8Calibrator::setBatch(const std::unordered_map<string, void*>& data, - const cudaStream_t stream, - tensorflow::core::RefCounted* rc) { + const cudaStream_t stream) { tensorflow::mutex_lock lock(cond_mtx_); - tensorflow::core::ScopedUnref SC(rc); while ((calib_running_ || batch_is_set_) && !done_) { // wait while calibration is running cond_.wait(lock); @@ -116,9 +114,9 @@ bool TRTInt8Calibrator::getBatch(void** bindings, const char** names, } const void* TRTInt8Calibrator::readCalibrationCache(std::size_t& length) { - if (calibration_table.empty()) return nullptr; - length = calibration_table.size(); - return calibration_table.data(); + if (calibration_table_.empty()) return nullptr; + length = calibration_table_.size(); + return calibration_table_.data(); } void TRTInt8Calibrator::setDone() { @@ -129,8 +127,9 @@ void TRTInt8Calibrator::setDone() { void TRTInt8Calibrator::writeCalibrationCache(const void* ptr, std::size_t length) { - calibration_table = string((const char*)ptr, length); - VLOG(1) << "Got calibration data for "<<engine_name_<<" @"<<ptr<<" length="<<length; + calibration_table_ = string((const char*)ptr, length); + VLOG(1) << "Got calibration data for " << engine_name_ << " @" << ptr + << " length=" << length; } TRTInt8Calibrator::~TRTInt8Calibrator() { VLOG(1) << "Destroying calibrator for " << engine_name_; diff --git a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h index eec9571418..6b59d52c70 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h +++ b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h @@ -47,12 +47,11 @@ struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator { bool getBatch(void* bindings[], const char* names[], int num_bindings) override; bool setBatch(const std::unordered_map<string, void*>& data, - const cudaStream_t stream, - tensorflow::core::RefCounted* helper); + const cudaStream_t stream); void setDone(); const void* readCalibrationCache(std::size_t& length) override; void writeCalibrationCache(const void* ptr, std::size_t length) override; - const string& getCalibrationTableAsString(){return calibration_table;} + const string& getCalibrationTableAsString() { return calibration_table_; } ~TRTInt8Calibrator(); private: @@ -68,7 +67,7 @@ struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator { bool calib_running_; bool batch_is_set_; string engine_name_; - string calibration_table; + string calibration_table_; }; } // namespace tensorrt diff --git a/tensorflow/contrib/tensorrt/resources/trt_resources.h b/tensorflow/contrib/tensorrt/resources/trt_resources.h index 584d6baee5..022639dc01 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_resources.h +++ b/tensorflow/contrib/tensorrt/resources/trt_resources.h @@ -47,17 +47,11 @@ class TRTCalibrationResource : public tensorflow::ResourceBase { ~TRTCalibrationResource() { VLOG(0) << "Destroying Calibration Resource " << std::endl << DebugString(); builder_->destroy(); - builder_ = nullptr; network_->destroy(); - network_ = nullptr; engine_->destroy(); - engine_ = nullptr; delete thr_; - thr_ = nullptr; delete logger_; - logger_ = nullptr; delete calibrator_; - calibrator_ = nullptr; } string DebugString() override { |