diff options
Diffstat (limited to 'tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h')
-rw-r--r-- | tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h | 15 |
1 files changed, 11 insertions, 4 deletions
diff --git a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h index 8830f7efe7..d77aa2c5ab 100644 --- a/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h +++ b/tensorflow/contrib/tensorrt/resources/trt_int8_calibrator.h @@ -24,7 +24,10 @@ limitations under the License. #if GOOGLE_CUDA #if GOOGLE_TENSORRT + +#include "cuda/include/cuda_runtime_api.h" #include "tensorrt/include/NvInfer.h" + namespace tensorflow { namespace tensorrt { // This class provides a 1 element queue to match TFs push model to @@ -39,8 +42,9 @@ struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator { int getBatchSize() const override; bool getBatch(void* bindings[], const char* names[], int num_bindings) override; - bool setBatch(const std::unordered_map<string, void*>& data); - void setDone() { done_ = true; } + bool setBatch(const std::unordered_map<string, void*>& data, + const cudaStream_t stream); + void setDone(); const void* readCalibrationCache(std::size_t& length) override; void writeCalibrationCache(const void* ptr, std::size_t length) override; ~TRTInt8Calibrator(); @@ -55,11 +59,14 @@ struct TRTInt8Calibrator : public nvinfer1::IInt8EntropyCalibrator { const std::unordered_map<string, std::pair<void*, size_t>> dev_buffers_; // map to keep tensorrt input buffers and sizes keyed with // buffer names - std::atomic_bool calib_running_; + bool calib_running_; + bool batch_is_set_; string engine_name_; }; + } // namespace tensorrt } // namespace tensorflow -#endif // TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_INT8_CALIBRATOR_H_ + #endif #endif +#endif // TENSORFLOW_CONTRIB_TENSORRT_RESOURCES_TRT_INT8_CALIBRATOR_H_ |