about | summary | refs | log | tree | commit | diff | homepage
path: root/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc')
-rw-r--r-- tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc 37
1 files changed, 25 insertions, 12 deletions
diff --git a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc
index 32e81858b9..dab1dd9343 100644
--- a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc
+++ b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.cc
@@ -36,13 +36,14 @@ TRTInt8Calibrator::TRTInt8Calibrator(
: batch_size_(batch_size),
done_(false),
dev_buffers_(dev_buffers),
+ // Make sure setBatch() waits until getBatch() is called (the first time).
calib_running_(true),
batch_is_set_(false),
engine_name_(engine_name) {}
TRTInt8Calibrator::TRTInt8Calibrator(const string& calib_data)
: batch_size_(0),
- done_(false),
+ done_(true),
calib_running_(false),
batch_is_set_(false),
calibration_table_(calib_data) {}
@@ -50,13 +51,14 @@ TRTInt8Calibrator::TRTInt8Calibrator(const string& calib_data)
bool TRTInt8Calibrator::setBatch(const std::unordered_map<string, void*>& data,
const cudaStream_t stream) {
tensorflow::mutex_lock lock(cond_mtx_);
- // wait while calibration is running.
- while ((calib_running_ || batch_is_set_) && !done_) {
- cond_.wait(lock);
- }
+
+ // Wait while the queue is full or calibration is running.
+ while ((calib_running_ || batch_is_set_) && !done_) cond_.wait(lock);
if (done_) return false;
CHECK(!calib_running_ && !batch_is_set_);
VLOG(1) << "Set Batch Waiting finished";
+
+ // Sets the batch.
for (const auto it : data) {
auto devptr = dev_buffers_.find(it.first);
if (devptr == dev_buffers_.end()) {
@@ -76,8 +78,8 @@ bool TRTInt8Calibrator::setBatch(const std::unordered_map<string, void*>& data,
}
// TODO(Sami, aaorey): Find an alternative way!
- cudaStreamSynchronize(
- stream); // we have to wait for the stream before returning!
+ // we have to wait for the stream before returning!
+ cudaStreamSynchronize(stream);
batch_is_set_ = true;
cond_.notify_all();
return true;
@@ -86,21 +88,21 @@ bool TRTInt8Calibrator::setBatch(const std::unordered_map<string, void*>& data,
bool TRTInt8Calibrator::getBatch(void** bindings, const char** names,
int num_bindings) {
tensorflow::mutex_lock lock(cond_mtx_);
+ // Notify finish of last round of calibration.
calib_running_ = false;
cond_.notify_all();
- // wait until new batch arrives
- while ((!batch_is_set_ && !done_)) {
- cond_.wait(lock);
- }
+
+ // Wait until new batch arrives
+ while ((!batch_is_set_ && !done_)) cond_.wait(lock);
if (done_) return false;
+ // Gets the batch
for (int i = 0; i < num_bindings; i++) {
auto it = dev_buffers_.find(names[i]);
if (it == dev_buffers_.end()) {
LOG(FATAL) << "Calibration engine asked for unknown tensor name '"
<< names[i] << "' at position " << i;
}
-
bindings[i] = it->second.first;
}
batch_is_set_ = false;
@@ -108,6 +110,17 @@ bool TRTInt8Calibrator::getBatch(void** bindings, const char** names,
return true;
}
+void TRTInt8Calibrator::waitAndSetDone() {
+ tensorflow::mutex_lock lock(cond_mtx_);
+ // Wait while the queue is full or calibration is running, so we don't miss
+ // the last batch.
+ while ((calib_running_ || batch_is_set_) && !done_) cond_.wait(lock);
+ if (!done_) {
+ done_ = true;
+ cond_.notify_all();
+ }
+}
+
const void* TRTInt8Calibrator::readCalibrationCache(std::size_t& length) {
if (calibration_table_.empty()) return nullptr;
length = calibration_table_.size();