aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensorrt/resources
diff options
context:
space:
mode:
authorGravatar gracehoney <31743510+aaroey@users.noreply.github.com>2018-06-20 10:20:32 -0700
committerGravatar gracehoney <31743510+aaroey@users.noreply.github.com>2018-06-20 10:20:32 -0700
commit5a8ff32bdb23b9ac4680f96b4b78493e3c4395ab (patch)
treef8e514b701eaa1c3d47b849b39c15247eaf73561 /tensorflow/contrib/tensorrt/resources
parent1bdcd6d624e4012cb9aec790a0d95076360bedb5 (diff)
Move the builder creation logic into ConvertGraphDefToEngine(), use unique_ptr for TRTCalibrationResource, and fix comments
Diffstat (limited to 'tensorflow/contrib/tensorrt/resources')
-rw-r--r--tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc13
-rw-r--r--tensorflow/contrib/tensorrt/resources/trt_resources.h26
2 files changed, 14 insertions, 25 deletions
diff --git a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc
index 9c1c306947..59ae860bc0 100644
--- a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc
+++ b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc
@@ -51,8 +51,8 @@ TRTInt8Calibrator::TRTInt8Calibrator(const string& calib_data)
bool TRTInt8Calibrator::setBatch(const std::unordered_map<string, void*>& data,
const cudaStream_t stream) {
tensorflow::mutex_lock lock(cond_mtx_);
- while ((calib_running_ || batch_is_set_) &&
- !done_) { // wait while calibration is running
+ // wait while calibration is running.
+ while ((calib_running_ || batch_is_set_) && !done_) {
cond_.wait(lock);
}
if (done_) return false;
@@ -66,8 +66,6 @@ bool TRTInt8Calibrator::setBatch(const std::unordered_map<string, void*>& data,
}
const auto& d = devptr->second;
- // TODO(aaroey): we should not use sync copy on default stream. Make sure
- // stream->ThenMemcpy() is used in future PRs.
// TODO(sami,aaroey): Need to figure out a way to ensure synchronization
// between stream, perhaps using a tensor?
auto status = cudaMemcpyAsync(d.first, it.second, d.second,
@@ -91,12 +89,11 @@ bool TRTInt8Calibrator::getBatch(void** bindings, const char** names,
tensorflow::mutex_lock lock(cond_mtx_);
calib_running_ = false;
cond_.notify_all();
- while ((!batch_is_set_ && !done_)) { // wait until new batch arrives
+ // wait until new batch arrives
+ while ((!batch_is_set_ && !done_)) {
cond_.wait(lock);
}
- if (done_) {
- return false;
- }
+ if (done_) return false;
for (int i = 0; i < num_bindings; i++) {
auto it = dev_buffers_.find(names[i]);
diff --git a/tensorflow/contrib/tensorrt/resources/trt_resources.h b/tensorflow/contrib/tensorrt/resources/trt_resources.h
index 43734bbdd8..76863503bd 100644
--- a/tensorflow/contrib/tensorrt/resources/trt_resources.h
+++ b/tensorflow/contrib/tensorrt/resources/trt_resources.h
@@ -38,11 +38,6 @@ namespace tensorrt {
class TRTCalibrationResource : public tensorflow::ResourceBase {
public:
- TRTCalibrationResource()
- : calibrator_(nullptr),
- logger_(nullptr),
- thr_(nullptr) {}
-
~TRTCalibrationResource() {
VLOG(0) << "Destroying Calibration Resource " << std::endl << DebugString();
builder_.reset();
@@ -50,9 +45,6 @@ class TRTCalibrationResource : public tensorflow::ResourceBase {
// We need to manually destroy the builder and engine before the allocator
// is destroyed.
allocator_.reset();
- delete thr_;
- delete logger_;
- delete calibrator_;
}
string DebugString() override {
@@ -60,22 +52,22 @@ class TRTCalibrationResource : public tensorflow::ResourceBase {
using std::hex;
using std::dec;
using std::endl;
- oss << " Calibrator = " << hex << calibrator_ << dec << endl
- << " Builder = " << hex << builder_.get() << dec << endl
- << " Engine = " << hex << engine_.get() << dec << endl
- << " Logger = " << hex << logger_ << dec << endl
- << " Allocator = " << hex << allocator_.get() << dec << endl
- << " Thread = " << hex << thr_ << dec << endl;
+ oss << " Calibrator = " << hex << calibrator_.get() << dec << endl
+ << " Builder = " << hex << builder_.get() << dec << endl
+ << " Engine = " << hex << engine_.get() << dec << endl
+ << " Logger = " << hex << &logger_ << dec << endl
+ << " Allocator = " << hex << allocator_.get() << dec << endl
+ << " Thread = " << hex << thr_.get() << dec << endl;
return oss.str();
}
- TRTInt8Calibrator* calibrator_;
+ std::unique_ptr<TRTInt8Calibrator> calibrator_;
TrtUniquePtrType<nvinfer1::IBuilder> builder_;
TrtUniquePtrType<nvinfer1::ICudaEngine> engine_;
std::unique_ptr<nvinfer1::IGpuAllocator> allocator_;
- tensorflow::tensorrt::Logger* logger_;
+ tensorflow::tensorrt::Logger logger_;
// TODO(sami): Use threadpool threads!
- std::thread* thr_;
+ std::unique_ptr<std::thread> thr_;
};
class TRTWeightStore {