aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensorrt/resources
diff options
context:
space:
mode:
authorGravatar Sami Kama <skama@nvidia.com>2018-06-11 21:04:03 -0700
committerGravatar Sami Kama <skama@nvidia.com>2018-06-11 21:04:03 -0700
commitae13b0560666df62967d87072e85619083a2f44b (patch)
tree1b5b403148d42ad359723a843ac86abc17860791 /tensorflow/contrib/tensorrt/resources
parentd1120e1334aae84bff40b3ee7cf0a3849936fe4b (diff)
Review changes
Diffstat (limited to 'tensorflow/contrib/tensorrt/resources')
-rw-r--r--tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc17
-rw-r--r--tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h7
-rw-r--r--tensorflow/contrib/tensorrt/resources/trt_resources.h6
3 files changed, 11 insertions, 19 deletions
diff --git a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc
index 5adffdc3d1..695394156c 100644
--- a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc
+++ b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc
@@ -47,13 +47,11 @@ TRTInt8Calibrator::TRTInt8Calibrator(const string& calib_data)
done_(false),
calib_running_(false),
batch_is_set_(false),
- calibration_table(calib_data) {}
+ calibration_table_(calib_data) {}
bool TRTInt8Calibrator::setBatch(const std::unordered_map<string, void*>& data,
- const cudaStream_t stream,
- tensorflow::core::RefCounted* rc) {
+ const cudaStream_t stream) {
tensorflow::mutex_lock lock(cond_mtx_);
- tensorflow::core::ScopedUnref SC(rc);
while ((calib_running_ || batch_is_set_) &&
!done_) { // wait while calibration is running
cond_.wait(lock);
@@ -116,9 +114,9 @@ bool TRTInt8Calibrator::getBatch(void** bindings, const char** names,
}
const void* TRTInt8Calibrator::readCalibrationCache(std::size_t& length) {
- if (calibration_table.empty()) return nullptr;
- length = calibration_table.size();
- return calibration_table.data();
+ if (calibration_table_.empty()) return nullptr;
+ length = calibration_table_.size();
+ return calibration_table_.data();
}
void TRTInt8Calibrator::setDone() {
@@ -129,8 +127,9 @@ void TRTInt8Calibrator::setDone() {
void TRTInt8Calibrator::writeCalibrationCache(const void* ptr,
std::size_t length) {
- calibration_table = string((const char*)ptr, length);
- VLOG(1) << "Got calibration data for "<<engine_name_<<" @"<<ptr<<" length="<<length;
+ calibration_table_ = string((const char*)ptr, length);
+ VLOG(1) << "Got calibration data for " << engine_name_ << " @" << ptr
+ << " length=" << length;
}
TRTInt8Calibrator::~TRTInt8Calibrator() {
VLOG(1) << "Destroying calibrator for " << engine_name_;
diff --git a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h
index eec9571418..6b59d52c70 100644
--- a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h
+++ b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h
@@ -47,12 +47,11 @@ struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator {
bool getBatch(void* bindings[], const char* names[],
int num_bindings) override;
bool setBatch(const std::unordered_map<string, void*>& data,
- const cudaStream_t stream,
- tensorflow::core::RefCounted* helper);
+ const cudaStream_t stream);
void setDone();
const void* readCalibrationCache(std::size_t& length) override;
void writeCalibrationCache(const void* ptr, std::size_t length) override;
- const string& getCalibrationTableAsString(){return calibration_table;}
+ const string& getCalibrationTableAsString() { return calibration_table_; }
~TRTInt8Calibrator();
private:
@@ -68,7 +67,7 @@ struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator {
bool calib_running_;
bool batch_is_set_;
string engine_name_;
- string calibration_table;
+ string calibration_table_;
};
} // namespace tensorrt
diff --git a/tensorflow/contrib/tensorrt/resources/trt_resources.h b/tensorflow/contrib/tensorrt/resources/trt_resources.h
index 584d6baee5..022639dc01 100644
--- a/tensorflow/contrib/tensorrt/resources/trt_resources.h
+++ b/tensorflow/contrib/tensorrt/resources/trt_resources.h
@@ -47,17 +47,11 @@ class TRTCalibrationResource : public tensorflow::ResourceBase {
~TRTCalibrationResource() {
VLOG(0) << "Destroying Calibration Resource " << std::endl << DebugString();
builder_->destroy();
- builder_ = nullptr;
network_->destroy();
- network_ = nullptr;
engine_->destroy();
- engine_ = nullptr;
delete thr_;
- thr_ = nullptr;
delete logger_;
- logger_ = nullptr;
delete calibrator_;
- calibrator_ = nullptr;
}
string DebugString() override {