aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h')
-rw-r--r--tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h8
1 files changed, 8 insertions, 0 deletions
diff --git a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h
index 994312d7c3..65466c9741 100644
--- a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h
+++ b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h
@@ -36,10 +36,13 @@ namespace tensorrt {
struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator {
public:
+ // Construct a calibrator for future calibration.
TRTInt8Calibrator(
const std::unordered_map<string, std::pair<void*, size_t>>& dev_buffers,
int batch_size, string engine_name);
+ // Construct a finalized calibrator where we don't need to run calibration any
+ // more, as the calibration data is provided.
TRTInt8Calibrator(const string& calibration_data);
~TRTInt8Calibrator();
@@ -52,6 +55,11 @@ struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator {
bool setBatch(const std::unordered_map<string, void*>& data,
const cudaStream_t stream);
+ // Wait until the last batch is consumed by the calibrator and set done.
+ void waitAndSetDone();
+
+ // Notify that calibration is done and future batches provided by setBatch()
+ // will be ignored.
void setDone();
// If not null, calibration is skipped.