diff options
Diffstat (limited to 'tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc')
-rw-r--r-- | tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc | 37 |
1 files changed, 25 insertions, 12 deletions
diff --git a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc index 32e81858b9..dab1dd9343 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc +++ b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc @@ -36,13 +36,14 @@ TRTInt8Calibrator::TRTInt8Calibrator( : batch_size_(batch_size), done_(false), dev_buffers_(dev_buffers), + // Make sure setBatch() waits until getBatch() is called (the first time). calib_running_(true), batch_is_set_(false), engine_name_(engine_name) {} TRTInt8Calibrator::TRTInt8Calibrator(const string& calib_data) : batch_size_(0), - done_(false), + done_(true), calib_running_(false), batch_is_set_(false), calibration_table_(calib_data) {} @@ -50,13 +51,14 @@ TRTInt8Calibrator::TRTInt8Calibrator(const string& calib_data) bool TRTInt8Calibrator::setBatch(const std::unordered_map<string, void*>& data, const cudaStream_t stream) { tensorflow::mutex_lock lock(cond_mtx_); - // wait while calibration is running. - while ((calib_running_ || batch_is_set_) && !done_) { - cond_.wait(lock); - } + + // Wait while the queue is full or calibration is running. + while ((calib_running_ || batch_is_set_) && !done_) cond_.wait(lock); if (done_) return false; CHECK(!calib_running_ && !batch_is_set_); VLOG(1) << "Set Batch Waiting finished"; + + // Sets the batch. for (const auto it : data) { auto devptr = dev_buffers_.find(it.first); if (devptr == dev_buffers_.end()) { @@ -76,8 +78,8 @@ bool TRTInt8Calibrator::setBatch(const std::unordered_map<string, void*>& data, } // TODO(Sami, aaorey): Find an alternative way! - cudaStreamSynchronize( - stream); // we have to wait for the stream before returning! + // we have to wait for the stream before returning! + cudaStreamSynchronize(stream); batch_is_set_ = true; cond_.notify_all(); return true; @@ -86,21 +88,21 @@ bool TRTInt8Calibrator::setBatch(const std::unordered_map<string, void*>& data, bool TRTInt8Calibrator::getBatch(void** bindings, const char** names, int num_bindings) { tensorflow::mutex_lock lock(cond_mtx_); + // Notify finish of last round of calibration. calib_running_ = false; cond_.notify_all(); - // wait until new batch arrives - while ((!batch_is_set_ && !done_)) { - cond_.wait(lock); - } + + // Wait until new batch arrives + while ((!batch_is_set_ && !done_)) cond_.wait(lock); if (done_) return false; + // Gets the batch for (int i = 0; i < num_bindings; i++) { auto it = dev_buffers_.find(names[i]); if (it == dev_buffers_.end()) { LOG(FATAL) << "Calibration engine asked for unknown tensor name '" << names[i] << "' at position " << i; } - bindings[i] = it->second.first; } batch_is_set_ = false; @@ -108,6 +110,17 @@ bool TRTInt8Calibrator::getBatch(void** bindings, const char** names, return true; } +void TRTInt8Calibrator::waitAndSetDone() { + tensorflow::mutex_lock lock(cond_mtx_); + // Wait while the queue is full or calibration is running, so we don't miss + // the last batch. + while ((calib_running_ || batch_is_set_) && !done_) cond_.wait(lock); + if (!done_) { + done_ = true; + cond_.notify_all(); + } +} + const void* TRTInt8Calibrator::readCalibrationCache(std::size_t& length) { if (calibration_table_.empty()) return nullptr; length = calibration_table_.size(); |