diff options
Diffstat (limited to 'tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h')
-rw-r--r-- | tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h | 8 |
1 files changed, 8 insertions, 0 deletions
diff --git a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h index 994312d7c3..65466c9741 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h +++ b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h @@ -36,10 +36,13 @@ namespace tensorrt { struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator { public: + // Construct a calibrator for future calibration. TRTInt8Calibrator( const std::unordered_map<string, std::pair<void*, size_t>>& dev_buffers, int batch_size, string engine_name); + // Construct a finalized calibrator where we don't need to run calibration any + // more, as the calibration data is provided. TRTInt8Calibrator(const string& calibration_data); ~TRTInt8Calibrator(); @@ -52,6 +55,11 @@ struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator { bool setBatch(const std::unordered_map<string, void*>& data, const cudaStream_t stream); + // Wait until the last batch is consumed by the calibrator and set done. + void waitAndSetDone(); + + // Notify that calibration is done and future batches provided by setBatch() + // will be ignored. void setDone(); // If not null, calibration is skipped. |